Tensorflow中的学习率

Tensorflow中的学习率


学习率(learning_rate): 表示了每次参数更新的幅度大小。学习率过大,会导致待优化的参数在最小值附近波动,不收敛;学习率过小,会导致待优化的参数收敛缓慢。在训练过程中,参数的更新向着损失函数梯度下降的方向。
参数的更新公式为: w n + 1 = w n l e a r n i n g _ r a t e {w_{n + 1}} = {w_n} - learning\_rate\nabla
  假设损失函数为 l o s s = ( w + 1 ) 2 loss = {(w + 1)^2} 。梯度是损失函数loss的导数为 = 2 w + 2 \nabla = 2w + 2 。如参数初值为5,学习率为0.2,则参数和损失函数更新如下:

在这里插入图片描述
损失函数 l o s s = ( w + 1 ) 2 loss = {(w + 1)^2} 的图像为:
在这里插入图片描述
  图中,损失函数loss的最小值会在(-1,0)处得到,此时损失函数的倒数为0,得到最终参数w=-1。代码如下:

import tensorflow as tf

w = tf.Variable(tf.constant(5, dtype = tf.float32))
loss = tf.square(w+1)
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    for i in range(40):
        sess.run(train_step)
        w_val = sess.run(w)
        loss_val = sess.run(loss)
        print("After %s steps:w is %f, loss is %f."%(i, w_val, loss_val))

运行结果:
在这里插入图片描述
由结果可知,随着损失函数值的减小,w无限趋近于-1,模型计算推测出最有参数w=-1。
注意: 当学习率过大时,会导致待优化的参数在最小值附近波动,不收敛;学习率过小,会导致待优化的参数收敛缓慢。
①将上述程序的学习率改为1时,结果如下:
在这里插入图片描述
  运行结果中,损失函数loss值并没有收敛,而是在5和-7之间波动。
②当把学习率修改为0.0001时,实验结果如下:
在这里插入图片描述
  结果中,损失函数loss值缓慢下降,w值也在小幅度变化,收敛缓慢。
指数衰减学习率: 为了解决设定学习率的问题,Tensorflow提供了一种更加灵活的学习率设置方法——指数衰减法。tf.train.exponential_decay函数实现了指数衰减学习率。通过这个函数可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。exponential_decay函数会指数级地减小学习率,学习率计算公式如下: L e a r n i n g _ r a t e = L E A R N I N G _ R A T E _ B A S E × L E A N I N G _ R A T E _ D E C A Y × g l o b a l _ s t e p L E A R N I N G _ R A T E _ B A T C H _ S I Z E Learning\_rate = LEARNING\_RATE\_BASE \times LEANING\_RATE\_DECAY \times \frac{{global\_step}}{{LEARNING\_RATE\_BATCH\_SIZE}} exponential_decay函数实现了以下代码的功能:

decayed_learning_rate=learning_rate*decay_rate^(global_step/decay_steps)

  其中decayed_learning_rate为每一轮优化时使用的学习率,learning_rate为实现设定的初始学习率,decay_rate为衰减系数,decay_steps为衰减速度。下图显示出随着迭代次数的增加,学习率逐步降低的过程。tf.train.exponential_decay函数可以通过设置参数staircase选择不同的的衰减方式。staircase的默认值为False,这时学习率随迭代论数变化的趋势如下灰色曲线,当staircase的值被设置为True时,global_step/decay_steps会被转化为整数。这使得学习率成为一个阶梯函数。下图中黑色曲线显示了阶梯状的学习率。
decay_steps通常代表了完整的使用一遍训练数据所需要的迭代论述。这个迭代论数也就是总训练样本除以每一个batch中的训练样本数。这种设置的常用场景是每完整地过完一遍训练数据,学习率就减小一次。这可以使得训练数据集中的所有数据对模型训练有相等的作用。当使用连续的指数衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响就很小了。
在这里插入图片描述
下面的代码示范如何在Tensirflow中使用t.train.exponential_decay函数:

global_step = tf.Variable(0)
# 通过exponential_decay函数生成学习率。
learning_rate = tf.train.exponential_decay(0.1, global_step, 100, 0.96, staircase=True)

# 使用指数衰减的学习率。在minimize函数中传入global_step将自动更新。
# global_step参数,从而使得学习率得到相应的更新。
learning_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(...my loss...,global_step=global_step)

  程序中设定了初始学习率为0.1,因为指定了staircase=True,所以每训练100轮后学习率乘以0.96。一般来说初始学习率、衰减系数和衰减速度都是根据经验设置的。而且损失函数下降的速度和迭代结束之后总损失的大小没有必然的联系。
  下面的程序中模型训练过程不设定固定的学习率,使用指数衰减学习率进行训练。其中,学习率初值设置为0.1,学习率衰减率设置为0.99,VATCH_SIZE设置为1.

import tensorflow as tf

LEARNING_RATE_BASE = 0.1 # 最初学习率
LEARNING_RATE_DECAY = 0.99 # 学习率衰减率
LEARNING_RATE_STEP = 1 # 喂入多少论BATCH_SIZE后,更新一次学习率,一般设为:总样本数//BATCH_SIZE

# 运行了几轮BATCH_SIZE的计数器,初值给0,设为不被训练
global_step = tf.Variable(0, trainable=False)
# 定义指数下降学习率
learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, LEARNING_RATE_STEP, LEARNING_RATE_DECAY, staircase=True)
# 定义待优化参数,初值给10
w = tf.Variable(tf.constant(5, dtype=tf.float32))
# 定义损失函数loss
loss = tf.square(w+1)
# 定义反向传播方法
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step = global_step)

# 生成会话,训练40轮
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    for i in range(40):
        sess.run(train_step)
        learning_rate_val = sess.run(learning_rate)
        global_step_val = sess.run(global_step)
        w_val = sess.run(w)
        loss_val = sess.run(loss)
        print("After %s steps:global_step is %f, w is %f,learning rate is %f, loss is %f"%(i, global_step_val, w_val, learning_rate_val, loss_val))

运算结果如下:
在这里插入图片描述

发布了19 篇原创文章 · 获赞 25 · 访问量 2440

猜你喜欢

转载自blog.csdn.net/fly975247003/article/details/100586721
今日推荐