TensorFlow API 3: tf.train.polynomial_decay

学习率的三种调整方式：固定（'fixed'）、指数衰减（'exponential'）、多项式衰减（'polynomial'）。

def _configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.

    The schedule is chosen by ``FLAGS.learning_rate_decay_type``:

    * ``'exponential'``: ``tf.train.exponential_decay`` with ``staircase=True``,
      i.e. ``FLAGS.learning_rate`` is multiplied by
      ``FLAGS.learning_rate_decay_factor`` once every ``decay_steps`` steps.
    * ``'fixed'``: a constant ``FLAGS.learning_rate``.
    * ``'polynomial'``: ``tf.train.polynomial_decay`` with ``power=1.0`` and
      ``cycle=False``, i.e. a linear ramp from ``FLAGS.learning_rate`` down to
      ``FLAGS.end_learning_rate`` over ``decay_steps`` steps.

    ``decay_steps`` is the number of batches of ``FLAGS.batch_size`` samples in
    ``FLAGS.num_epochs_per_decay`` epochs. It is further divided by
    ``FLAGS.replicas_to_aggregate`` when ``FLAGS.sync_replicas`` is set.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.

    Returns:
      A `Tensor` representing the learning rate.

    Raises:
      ValueError: if ``FLAGS.learning_rate_decay_type`` is not one of
        ``'exponential'``, ``'fixed'`` or ``'polynomial'``.
    """
    decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                      FLAGS.num_epochs_per_decay)
    if FLAGS.sync_replicas:
        # Presumably because each global step aggregates gradients from
        # `replicas_to_aggregate` batches under synchronous training -- confirm.
        # NOTE(review): true division makes decay_steps a float on Python 3;
        # the decay functions below accept it, so it is kept as-is here.
        decay_steps /= FLAGS.replicas_to_aggregate

    if FLAGS.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(FLAGS.learning_rate,
                                          global_step,
                                          decay_steps,
                                          FLAGS.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'fixed':
        return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(FLAGS.learning_rate,
                                         global_step,
                                         decay_steps,
                                         FLAGS.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
    else:
        # Interpolate the value into the message; passing it as a second
        # argument (logging-style) would leave '%s' unformatted in the error.
        raise ValueError('learning_rate_decay_type [%s] was not recognized'
                         % FLAGS.learning_rate_decay_type)

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=324736505&siteId=291194637