tensorflow中的梯度下降函数

tensorflow中的梯度下降函数如下:

     在训练过程中先实例化一个优化函数如tf.train.GradientDescentOptimizer,并基于一定的学习率进行梯度优化训练。
                        optimizer = tf.train.GradientDescentOptimizer(learning_rate)

       接着使用一个minimize()操作,里面传入损失值节点loss,再启动一个外层的循环,优化器就会按照循环的次数一次次沿着loss最小值的方向优化参数了。整个过程中的求导和反向传播操作,都是在优化器里自动完成的。目前比较常用的优化器为Adam优化器。

退化学习率---在训练的速度与精度之间找到平衡    

    前面提到的 learning_rate 就代表是学习率。

    设置学习率的大小是在训练的速度与精度之间一个平衡。

    1)如果  learning_rate 的值较大,则训练速度会提升,但是精度不够

   2)如果  learning_rate 的值较小,精度提升了,训练又太耗时间

退化学习率又叫做学习率衰减,是在训练过程中对于学习率的大和小的优点都能够为我们所用,也就是训练刚开始使用大的学习率加快速度,训练到一定程度后,使用小的学习率来提高精度。

def exponential_decay(learning_rate, 
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False,     # 为 True 时没有衰减功能

                      name=None)

学习率的衰减速度有 global_step 和 decay_steps来决定。

decay_learning_rate = learning_rate * decay_rate ^ ( global_step / decay_steps)

learn_rate = tf.train.exponential_decay(initial_learning_rate, 
                                        global_step=global_step, 
                                        decay_rate=0.9, decay_steps=10)

上面代码的意思是:当前迭代达到 global_step 步,学习率每一步都按照每 decay_steps 步缩小到 decay_rate 的速度衰减。

退化学习率举例:

import tensorflow as tf

global_step = tf.Variable(0, trainable=False)

initial_learning_rate = .1
learn_rate = tf.train.exponential_decay(initial_learning_rate,
                                        global_step=global_step,
                                        decay_rate=0.9, decay_steps=10)
opt = tf.train.GradientDescentOptimizer(learn_rate)
add_global = global_step.assign_add(1)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(learn_rate))
    for _ in range(20):
        g, rate = sess.run([add_global, learn_rate])
        print(g, rate)

运行结果:

0.1
1 0.1
2 0.09791484
3 0.09791484
4 0.095873155
5 0.095873155
6 0.09387404
7 0.092890166
8 0.092890166
9 0.09191661
10 0.089999996
11 0.089999996
12 0.08905673
13 0.087199755
14 0.08628584
15 0.0853815
16 0.08448663
17 0.08360115
18 0.08272495
19 0.08185793
20 0.08099999


猜你喜欢

转载自blog.csdn.net/qq_42413820/article/details/80939805
今日推荐