Función de pérdida de notas de estudio de Tensorflow

1. Función de pérdida de error cuadrático medio
La derivada parcial de la función de pérdida al peso es proporcional al gradiente de la función de activación. Si la función de activación es lineal, se puede utilizar esta función de pérdida. Si la función de activación es una función sigmoidea, esta función de pérdida no es adecuada por las siguientes razones:
si esperamos que el valor de salida sea 1, A está lejos de 1 , El gradiente de la función de activación también es más grande y el paso de ajuste del optimizador también es más grande; B está más cerca de 1, el gradiente de la función de activación también es más pequeño y el paso de ajuste del optimizador también es más pequeño, lo cual es razonable.
Si esperamos que el valor de salida sea 0, A está más lejos de 0, el gradiente de la función de activación también es mayor y el paso de ajuste del optimizador también es mayor; B está más lejos de 0, el gradiente de la función de activación es más pequeño y el ajuste del optimizador El paso también es Si es más pequeño, tomará mucho tiempo para que el valor de salida se ajuste a 0, lo cual no es razonable.
Inserte la descripción de la imagen aquí
2. Función de pérdida de entropía cruzada
La derivada parcial de la función de pérdida al peso no tiene nada que ver con el gradiente de la función de activación y es proporcional a la diferencia entre el valor real del valor predicho. Esta función de pérdida se puede utilizar independientemente de si la función de activación es lineal o sigmoidea. Cuando la desviación entre el valor predicho y el valor real es grande, el optimizador ajusta el paso más grande, y cuando la desviación entre el valor predicho y el valor real es pequeña, el optimizador ajusta el paso más pequeño, lo cual es razonable.
3. Función de pérdida de probabilidad
logarítmica Para problemas de clasificación, la neurona de salida es la función softmax. En este momento, la función de pérdida comúnmente utilizada es la función de pérdida de probabilidad logarítmica. La función de pérdida de probabilidad logarítmica se combina con la función softmax, y la función de pérdida de entropía cruzada se combina con la función sigmoide Estas dos combinaciones son muy similares. Para dos problemas de clasificación, la función de pérdida de probabilidad logarítmica se puede simplificar a la función de pérdida de entropía cruzada.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 获取数据集
# one_hot设置为True,将标签数据转化为0/1,如[1,0,0,0,0,0,0,0,0,0]
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)

# 定义一个批次的大小
batch_size=100
n_batch=mnist.train.num_examples//batch_size

# 定义两个placeholder
# 行数值为None,None可以取任意数,本例中将取值100,即取决于pitch_size
# 列数值为784,因为输入图像尺寸已由28*28转换为1*784
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

# 定义两个变量
w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))

# 定义一个神经网络
# softmax的作用是将tf.matmul(x,w)+b的结果转换为概率值,举例如下:
# [9,2,1,1,2,1,1,2,1,1]
# [0.99527,0.00091,0.00033,0.00033,0.00091,0.00033,0.00033,0.00091,0.00033,0.00033]
prediction=tf.nn.softmax(tf.matmul(x,w)+b)

# 定义损失函数
# 由于输出神经元使用softmax函数,交叉熵损失函数比均方误差损失函数收敛速度更快
# loss=tf.reduce_mean(tf.square(y-prediction))
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))

# 定义优化器
optimizer=tf.train.GradientDescentOptimizer(0.2)

# 定义模型,优化器通过调整loss里的参数,使loss不断减小
train=optimizer.minimize(loss)

# 统计准确率
# tf.argmax返回第一个参数中最大值的下标
# tf.equal比较两个参数是否相等,返回True或False
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# tf.cast将布尔类型转换为浮点类型
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	# epoch为周期数,所有批次训练完为一个周期
	for epoch in range(20):
		for batch in range(n_batch):
			# 每次取出batch_size条数据进行训练
			batch_xs,batch_ys=mnist.train.next_batch(batch_size)
			sess.run(train,feed_dict={
    
    x:batch_xs,y:batch_ys})
		acc = sess.run(accuracy,feed_dict={
    
    x:mnist.test.images,y:mnist.test.labels})
		print('epoch=',epoch,' ','acc=',acc)

resultado de la operación:
Inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/wxsy024680/article/details/114535778
Recomendado
Clasificación