Fonction de perte de notes de l'étude Tensorflow

1. Fonction de perte d'erreur quadratique moyenne
La dérivée partielle de la fonction de perte par rapport au poids est proportionnelle au gradient de la fonction d'activation. Si la fonction d'activation est linéaire, cette fonction de perte peut être utilisée. Si la fonction d'activation est une fonction sigmoïde, cette fonction de perte ne convient pas pour les raisons suivantes:
si l'on s'attend à ce que la valeur de sortie soit 1, A est loin de 1 , Le gradient de la fonction d'activation est également plus grand et l'étape de réglage de l'optimiseur est également plus grande; B est plus proche de 1, le gradient de la fonction d'activation est également plus petit et l'étape de réglage de l'optimiseur est également plus petite, ce qui est raisonnable.
Si nous nous attendons à ce que la valeur de sortie soit 0, A est plus éloigné de 0, le gradient de la fonction d'activation est également plus grand et le pas de réglage de l'optimiseur est également plus grand; B est plus éloigné de 0, le gradient de la fonction d'activation est plus petit et l'ajustement de l'optimiseur Le pas est également S'il est plus petit, il faudra beaucoup de temps pour que la valeur de sortie s'ajuste à 0, ce qui est déraisonnable.
Insérez la description de l'image ici
2. Fonction de perte d'entropie croisée
La dérivée partielle de la fonction de perte au poids n'a rien à voir avec le gradient de la fonction d'activation et est proportionnelle à la différence entre la valeur réelle de la valeur prédite. Cette fonction de perte peut être utilisée indépendamment du fait que la fonction d'activation soit une fonction linéaire ou sigmoïde. Lorsque l'écart entre la valeur prédite et la valeur vraie est important, l'optimiseur ajuste le pas plus grand, et lorsque l'écart entre la valeur prédite et la valeur vraie est faible, l'optimiseur ajuste le pas plus petit, ce qui est raisonnable.
3. Fonction de perte de log-vraisemblance
Pour les problèmes de classification, le neurone de sortie est la fonction softmax. À ce stade, la fonction de perte couramment utilisée est la fonction de perte de log-vraisemblance. La fonction de perte log-vraisemblable est combinée avec la fonction softmax et la fonction de perte d'entropie croisée est combinée avec la fonction sigmoïde. Ces deux combinaisons sont très similaires. Pour deux problèmes de classification, la fonction de perte log-vraisemblable peut être simplifiée à la fonction de perte d'entropie croisée.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 获取数据集
# one_hot设置为True,将标签数据转化为0/1,如[1,0,0,0,0,0,0,0,0,0]
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)

# 定义一个批次的大小
batch_size=100
n_batch=mnist.train.num_examples//batch_size

# 定义两个placeholder
# 行数值为None,None可以取任意数,本例中将取值100,即取决于pitch_size
# 列数值为784,因为输入图像尺寸已由28*28转换为1*784
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

# 定义两个变量
w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))

# 定义一个神经网络
# softmax的作用是将tf.matmul(x,w)+b的结果转换为概率值,举例如下:
# [9,2,1,1,2,1,1,2,1,1]
# [0.99527,0.00091,0.00033,0.00033,0.00091,0.00033,0.00033,0.00091,0.00033,0.00033]
prediction=tf.nn.softmax(tf.matmul(x,w)+b)

# 定义损失函数
# 由于输出神经元使用softmax函数,交叉熵损失函数比均方误差损失函数收敛速度更快
# loss=tf.reduce_mean(tf.square(y-prediction))
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))

# 定义优化器
optimizer=tf.train.GradientDescentOptimizer(0.2)

# 定义模型,优化器通过调整loss里的参数,使loss不断减小
train=optimizer.minimize(loss)

# 统计准确率
# tf.argmax返回第一个参数中最大值的下标
# tf.equal比较两个参数是否相等,返回True或False
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# tf.cast将布尔类型转换为浮点类型
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	# epoch为周期数,所有批次训练完为一个周期
	for epoch in range(20):
		for batch in range(n_batch):
			# 每次取出batch_size条数据进行训练
			batch_xs,batch_ys=mnist.train.next_batch(batch_size)
			sess.run(train,feed_dict={
    
    x:batch_xs,y:batch_ys})
		acc = sess.run(accuracy,feed_dict={
    
    x:mnist.test.images,y:mnist.test.labels})
		print('epoch=',epoch,' ','acc=',acc)

résultat de l'opération:
Insérez la description de l'image ici

Je suppose que tu aimes

Origine blog.csdn.net/wxsy024680/article/details/114535778
conseillé
Classement