Tensorflow之神经网络学习率

1. 学习率引入

在梯度下降中,有一个系数,叫 学习率(learning rage)。它决定了沿着让代价函数下降程度最大的方向向下走的步伐有多大。学习率的设置,对神经网络训练效果至关重要。如果学习速率太小,结果就是只能这样像小宝宝一样一点点地挪动,去努力接近最低点,这样就需要很多步才能到达最低点,所以如果太小的话,可能会很慢,因为它会一点点挪动,它会需要很多步才能到达全局最低点。
如果太大,那么梯度下降法可能会越过最低点,甚至可能无法收敛,下一次迭代又移动了一大步,越过一次,又越过一次,一次次越过最低点,直到你发现实际上离最低点越来越远,所以,如果太大,它会导致无法收敛,甚至发散。【此处引用某爱好者的机器学习个人笔记】
因此,学习率既不能过大,也不能过小。为了解决设置学习率的问题,Tensorflow提供了一种灵活的设置方法,叫 指数衰减法(exponential decay method);其接口为tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)。

2. exponential_decay接口定义及实现

2.1接口定义

def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
                      staircase=False, name=None):

1)learning_rate:设定的初始学习率;
2)global_step:整型值,当前迭代次数;
3)decay_steps:衰减速度;
4)decay_rate:衰减系数;
5)staircase:不同的衰减方式;默认值为False,此时学习率随迭代轮次变化呈连续变化的趋势;当设置为True时,global_step/decay_step被取整,此时学习率将成为分段函数形式变化趋势。

2.1接口实现

函数表达式:
在这里插入图片描述

简单的说,指数衰减学习率就是变量为global_step的函数。一般learning_rate、decay_steps、decay_rate在训练全程都被固定。

3. exponential_decay示例验证

在示例中,取learning_rate = 0.1,decay_steps = 1, decay_rate=0.96

import math
import tensorflow as tf

TRAINING_STEPS = 100
global_step = tf.Variable(0)
LEARNING_RATE = tf.train.exponential_decay(0.1, global_step, 1, 0.96, staircase=True)

x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
y = tf.square(x)
train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(y, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(TRAINING_STEPS):
        sess.run(train_op)
        if i % 10 == 0:
            LEARNING_RATE_value = sess.run(LEARNING_RATE)
            x_value = sess.run(x)
            steps = sess.run(global_step)
            learning_rate_basicAlg = 0.1 * (math.pow(0.96, steps/1))
            print("After %s iteration(s): x%s is %f, learning rate is %f, \
                   global_step %d, learning_rate_basicAlg %f."% (i+1, i+1, x_value, LEARNING_RATE_value, steps, learning_rate_basicAlg))

结果打印:
在这里插入图片描述
从打印结果看,在每次迭代中,基本算法learning_rate_basicAlg 的值和exponential_decay接口的值一致。

猜你喜欢

转载自blog.csdn.net/duanyuwangyuyan/article/details/108483827