tensoflow随笔——softmax和交叉熵

softmax函数

softmax函数接收一个N维向量作为输入,然后把每一维的值转换到(0, 1)之间的一个实数。假设模型全连接网络输出为a,有C个类别,则输出为a1,a2,...,aC,对于每个样本,属于类别i的输出概率为:

属于各个类别的概率和为1。

贴一张形象的说明图:

如图将原来输入的3,1,-3通过softmax函数的作用,映射成为(0,1)的值,而这些值的累和为1(满足概率的性质),我们可以将它理解成概率,在最后选取输出结点的时候,我们就可以选取概率值最大的结点,作为我们的预测目标。

softmax导数

对softmax求导即:

当i = j 时:

当i ≠ j时:

softmax数值稳定性

传入数据[1, 2, 3, 4, 5]时

传入数据[1000, 2000, 3000, 4000, 5000]时

导致输出是nan的原因是exp(x)对较大的数求指数溢出的问题。

一般的做法是额外加上一个非零常数,使所有的输入在0的附近。

比如:

def softmax(x):
    shift_x = x - np.max(x)
    exp_x = np.exp(shift_x)
    return exp_x / np.sum(exp_x)

交叉熵:用来判定实际的输出与期望的输出的接近程度!

刻画的是实际输出与期望输出的距离,也就是交叉熵的值越小,两个概率分布就越接近,假设概率分布p为期望输出,概率分布q为实际输出,H(p,q)为交叉熵,则:

或者:

Tensorflow中对交叉熵的计算可以采用两种方式

1.手动实现:

import tensorflow as tf

input = tf.placeholder(dtype=tf.float32, shape=[None, 28*28])
output = tf.placeholder(dtype=tf.float32, shape=[None, 10])

w_fc1 = tf.Variable(tf.truncated_normal([28*28, 1024], stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
h_fc1 = tf.matmul(input, w_fc1) + b_fc1

w_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.nn.softmax(tf.matmul(h_fc1, w_fc2) + b_fc2)

cross_entropy = -tf.reduce_sum(output * tf.log(logits))

output是one-hot类型的实际输出,logits是对全连接的输出用softmax进行转换为概率值的预测,最后通过cross_entropy = -tf.reduce_sum(label * tf.log(y))求出交叉熵的。

2.tf.nn.softmax_cross_entropy_with_logits:

tensorflow已经对softmax和交叉熵进行了封装

import tensorflow as tf

input = tf.placeholder(dtype=tf.float32, shape=[None, 28*28])
output = tf.placeholder(dtype=tf.float32, shape=[None, 10])

w_fc1 = tf.Variable(tf.truncated_normal([28*28, 1024], stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
h_fc1 = tf.matmul(input, w_fc1) + b_fc1

w_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.matmul(h_fc1, w_fc2) + b_fc2

cross_entropy = -tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=output, logits=logits))

函数的参数logits在函数内会用softmax进行处理,所以传进来时不能是softmax的输出。

官方的封装函数会在内部处理数值不稳定等问题,如果选择方法1,需要自己在softmax函数里面添加trick。

猜你喜欢

转载自blog.csdn.net/neil3611244/article/details/81279501