版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/bqw18744018044/article/details/83218796
关于交叉熵的文章
关于softmax函数的文章
import numpy as np
import tensorflow as tf
sess = tf.InteractiveSession()
一、交叉熵
1.多分类中的Softmax函数
在多分类的神经网络中,通常在最后一层接一个softmax层。对于n分类问题,softmax层就有n个结点,每个结点输出的就是该类别的概率.
例如5分类的问题,神经网络可能会输出[0,0.6,0.3,0,0.1],由于该样本属于第2类的概率最大,为0.6,故属于2类。
2.交叉熵损失函数
交叉熵是用来衡量两个概率分布p和q的差异程度的,其定义为:
交叉熵通过和softmax函数连在一起使用,TensorFlow对这两个功能进行了统一封装,提供了tf.nn.softmax_cross_entropy_with_logits
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y)
print(cross_entropy.eval())
1.2369158210025892
二、均方误差
回归问题通常使用均方误差损失函数,其定义为:
mse = tf.reduce_mean(tf.square(y_-y))
print(mse.eval())
0.052000000000000005
三、自定义损失函数
1.自定义的交叉熵损失函数
# 产生模拟数据
y_ = np.array([0,1,0,0,0]) # 真实标签
y = np.array([0,0.6,0.3,0,0.1]) # 模型输出的预测值
# 计算交叉熵
# tf.reduce_mean():按默认轴求均值
# tf.clip_by_value(t, clip_value_min, clip_value_max):将t值卡在clip_value_min到clip_value_max的区间内
cross_entropy = -tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,1e-10,1.0)))
print(cross_entropy.eval())
2.自定义损失函数
import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32,shape=(None,2),name='x-input')
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32,shape=(None,1),name='y-input')
# 定义单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2,1],stddev=1,seed=1))
y = tf.matmul(x,w1)
######定义自己的loss#######
loss_less = 10
loss_more = 1
# tf.greater()函数相当于max()函数,tf.where()函数相当于二元运算符“:?”
loss = tf.reduce_sum(tf.where(tf.greater(y,y_),(y-y_)*loss_more,(y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
# 通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size,2)
Y = [[x1+x2+rdm.rand()/10.0-0.05] for (x1,x2) in X]
# 训练神经网络
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size)%dataset_size
end = min(start+batch_size,dataset_size)
sess.run(train_step,feed_dict={x:X[start:end],y_:Y[start:end]})
print(sess.run(w1))
[[1.019347 ]
[1.0428089]]