tf.train.exponential_decay() & tf.train.piecewise_constant()

1. decayed_learning_rate = tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)

During model training the learning rate needs to be decayed. This function applies exponential decay to an initial learning rate; the computation is:

decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

If staircase=True, global_step / decay_steps is an integer division, so the learning rate decays as a staircase function.
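The formula and the effect of the staircase flag can be sketched in plain Python (a hypothetical helper mirroring the formula, not TensorFlow's implementation):

```python
def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
                      staircase=False):
    """Compute learning_rate * decay_rate ^ (global_step / decay_steps)."""
    if staircase:
        # Integer division: the exponent only changes every decay_steps steps.
        p = global_step // decay_steps
    else:
        # True division: the rate shrinks a little at every step.
        p = global_step / decay_steps
    return learning_rate * decay_rate ** p

# Continuous decay at step 50000: 0.1 * 0.96 ** 0.5, roughly 0.098.
lr_smooth = exponential_decay(0.1, 50000, 100000, 0.96)
# Staircase decay: unchanged until step 100000 is reached, so still 0.1.
lr_stair = exponential_decay(0.1, 50000, 100000, 0.96, staircase=True)
```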

Example: decay with base 0.96 every 100000 steps:

    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = 0.1
    learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                               100000, 0.96, staircase=True)
    # Passing global_step to minimize() will increment it at each step.
    learning_step = (
        tf.train.GradientDescentOptimizer(learning_rate)
        .minimize(...my loss..., global_step=global_step)
    )

Arguments: learning_rate: the initial learning rate; a scalar 'float32' or 'float64' 'Tensor'

         global_step: used in the decay computation; must not be negative; a scalar 'int32' or 'int64' 'Tensor'

         decay_steps: used in the decay computation; must be positive; a scalar 'int32' or 'int64' 'Tensor'

         decay_rate: the decay rate; a scalar 'float32' or 'float64' 'Tensor'

         staircase: Boolean; True means the learning rate decays at discrete intervals, False means it decays at every step

         name: name of the operation; a String, defaulting to 'ExponentialDecay'.

The global_step value is updated by minimize() in train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step).
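The interplay between minimize() and the schedule can be illustrated with a plain-Python loop standing in for repeated training steps (a sketch with the example's numbers, not TensorFlow code):

```python
# Stand-in for the TF graph: global_step starts at 0 and minimize()
# increments it once per training step.
global_step = 0
starter_learning_rate = 0.1
lr_first_seen = {}

for _ in range(300000):
    # staircase=True: integer division, so lr changes every 100000 steps.
    lr = starter_learning_rate * 0.96 ** (global_step // 100000)
    lr_first_seen.setdefault(lr, global_step)  # first step with this lr
    global_step += 1  # what minimize(global_step=...) does implicitly

# The learning rate drops exactly at steps 100000 and 200000.
```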

2. tf.train.piecewise_constant(x, boundaries, values, name=None)

Splits the range into intervals according to boundaries; within each interval the value is the piecewise constant given by values.

Example: a learning rate of 0.1 for the first 100000 steps, 0.01 for steps 100000 to 110000, and 0.001 for the remaining steps:

global_step = tf.Variable(0, trainable=False)
boundaries = [100000, 110000]
values = [0.1, 0.01, 0.001]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

global_step <= 100000 gives values[0]; 100000 < global_step <= 110000 gives values[1]; global_step > 110000 gives values[2].
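This boundary logic can be mimicked in plain Python with the bisect module (a hypothetical helper, not TensorFlow's implementation):

```python
import bisect

def piecewise_constant(x, boundaries, values):
    """Return the values[] entry for the interval of boundaries that x falls in."""
    # values has one more element than boundaries; bisect_left maps
    # x <= boundaries[0] to index 0, boundaries[0] < x <= boundaries[1]
    # to index 1, and so on.
    assert len(values) == len(boundaries) + 1
    return values[bisect.bisect_left(boundaries, x)]

boundaries = [100000, 110000]
values = [0.1, 0.01, 0.001]
piecewise_constant(100000, boundaries, values)  # 0.1   (x <= 100000)
piecewise_constant(105000, boundaries, values)  # 0.01
piecewise_constant(200000, boundaries, values)  # 0.001
```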

Arguments: x: a scalar of type 'float32', 'float64', `uint8`, `int8`, `int16`, `int32`, or `int64`

        boundaries: a list, strictly increasing, with elements of the same type as x

        values: a list giving the value for each interval defined by boundaries; it has one more element than boundaries

        name: name of the operation; a String, defaulting to 'PiecewiseConstant'


Reposted from blog.csdn.net/ghy_111/article/details/80591314