tf神经网络优化常用函数范式一览(损失函数loss,学习率learning_rate,滑动平均ema,正则化regularization)

来源: 学习北大曹健老师Tensorflow课程后做的学习笔记,整理后在此分享给大家. 之前是抄在本子上的, 可能存在一些抄写错误或是标记缺失, 若发现了可以评论告诉我.谢谢.

一. 损失函数(loss):预测值y与已知答案y_的差距

神经网络优化目标即找到适合的w以减小loss, 有三种减小loss的方法

1. 均方误差mse(Mean Squared Error)
2. 自定义损失函数
3. 交叉熵ce(Cross Entropy)

1. 均方误差mse
模型 : mse
使用:

loss_mse = tf.reduce_mean(tf.square(y-y_))

+

**默认预测结果偏低偏高时结果相同,解决利益最大问题时无法求最优**

2. 自定义损失函数

**如预测商品销量,若利润!=成本,则mse产生的loss无法利益最大**

使用:

loss=tf.reduce_sum(tf.where(tf.greater(y,y_),COST(y-y_),PROFIT(y_-y)))

3. 交叉熵 :表示两个概率分布之间的距离
模型 : ce

使用:

ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_; 1))
cem = tf.reduce_mean(ce)

二. 学习率(learning_rate):参数每次更新幅度

W_(n+1) = W_(n) - learning_rate ▽

W_(n+1)更新后参数 
W_(n)当前参数                
learning_rate学习率         
▽损失函数的梯度(导数)

介绍一下指数衰减学习率, 可以根据轮数动态更新学习率(先快后慢,效率更高)

#指数衰减学习率
#定义计数器
global_step = tf.Variable(0, trainable=False)
#定义指数衰减学习率
learning_rate = tf.train_exponential_decay(LEARNING_RATE_BASE, global_step, LEARNING_RATE_STEP, LEARNING_RATE_DECAY, staircase = True)
#staircase=True即global_stop/LEARNING_RATE_STEP取整,学习率阶梯衰减/False学习率平滑下降

三. 滑动平均ema(影子)

模型: 影子 = 衰减率 * 影子 + (1-衰减率) * 参数

#衰减率 = Min{MOVING_AVERAGE_DECAY,  (1+轮数)/(10+轮数)}

使用:

#定义ema
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
#定义ema节点,把待训练参数汇总成列表
ema_op = ema.apply(tf.trainable_variables())
#工程中常把计算滑动平均和训练过程绑定运行,使他们合成一个训练节点
with tf.contral_dependencies([train_step, ema_op]):
    train_op = tf.no_op(name='train')

四. 正则化缓解过拟合

正则化在损失函数中引入模型复杂度指标, 利用给W加权值, 弱化了训练数据的噪声(一般不正则化b)

loss = loss(y与y_) + 超参数REGULARIZER * loss(w)

loss()为参数的损失函数
超参数REGULARIZER给出参数w在总loss中的比例(正则化权重)
w为需要正则化的参数
loss(w)的两种算法

L1正则化

loss(w) = tf.contrib.layers.l1_regularizer(REGULARIZER)(w)

L2正则化

loss(w) = tf.contrib.layers.l2_regularizer(REGULARIZER)(w)

使用:

#把内容加到集合对应位置做加法
tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
#所有值相加
loss = cem + tf.add_n(tf.get_collection('losses'))

                                Copyright:dolor_059

猜你喜欢

转载自blog.csdn.net/dolor_059/article/details/82086407