Xiaobai learns PyTorch series: torch.optim API, Scheduler (3)
torch.optim.lr_scheduler
torch.optim.lr_scheduler provides several methods to adjust the learning rate based on the number of epochs.
In addition, torch.optim.lr_scheduler.ReduceLROnPlateau allows dynamic learning rate reduction based on some validation measurement.
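To make the idea of plateau-based reduction concrete, here is a minimal plain-Python sketch. It does not use PyTorch: the class name PlateauSketch is hypothetical, and its factor and patience parameters are only loosely modeled on the options of ReduceLROnPlateau, which has further settings (threshold, cooldown, min_lr and so on) that are ignored here.

```python
class PlateauSketch:
    """Toy plateau-based learning rate reduction (not PyTorch's code)."""

    def __init__(self, lr, factor=0.5, patience=2):
        self.lr = lr                  # current learning rate
        self.factor = factor          # multiplier applied on a plateau
        self.patience = patience      # bad epochs tolerated before reducing
        self.best = float("inf")      # lowest validation loss seen so far
        self.num_bad_epochs = 0

    def step(self, val_loss):
        # An epoch is "good" only if it improves on the best loss so far.
        if val_loss < self.best:
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce the learning rate once patience is exceeded.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=0.1, factor=0.5, patience=2)
val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]
lrs = [sched.step(v) for v in val_losses]
print(lrs)  # [0.1, 0.1, 0.1, 0.1, 0.05, 0.05]
```

With PyTorch itself, the validation value is passed in the same way, as scheduler.step(val_loss) on a ReduceLROnPlateau instance after each validation pass.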
Learning rate scheduling should be applied after the optimizer's update; for example, you should write your code like this:
Demo:
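import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

# model, dataset and loss_fn are assumed to be defined elsewhere
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    # Adjust the learning rate once per epoch, after the optimizer updates
    scheduler.step()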
Most learning rate schedulers can be called back-to-back (also referred to as chaining schedulers). As a result, each scheduler is applied in turn to the learning rate produced by the one before it.
Demo:
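import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

# model, dataset and loss_fn are assumed to be defined elsewhere
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    # Both schedulers act on the same optimizer, one after the other
    scheduler1.step()
    scheduler2.step()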
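As a worked example of chaining, the learning rate that the demo above should reach after a given number of epochs can be computed by hand: ExponentialLR multiplies the rate by 0.9 every epoch, and MultiStepLR multiplies it by 0.1 once each milestone is reached. The plain-Python sketch below reproduces that arithmetic without PyTorch; the helper name chained_lr is hypothetical.

```python
import bisect


def chained_lr(base_lr, epoch, exp_gamma=0.9, milestones=(30, 80), ms_gamma=0.1):
    """Learning rate after `epoch` calls to step() on both schedulers."""
    # Number of milestones that are <= epoch.
    passed = bisect.bisect_right(sorted(milestones), epoch)
    return base_lr * (exp_gamma ** epoch) * (ms_gamma ** passed)


for epoch in (0, 1, 20, 29, 30, 80):
    print(epoch, chained_lr(0.01, epoch))
```

Note that the demo loop runs for only 20 epochs, so the milestones at 30 and 80 are never reached there; within that range only the exponential decay is visible.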
In many places in the documentation, we will use the following template to refer to scheduler algorithms.
scheduler = ...
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
Scheduler source code analysis
Reference: https://zhuanlan.zhihu.com/p/346205754?utm_medium=social&utm_oi=73844937195520
The main job of a learning rate adjustment class is to compute the learning rate of each parameter group at every epoch, write that value to lr in the corresponding parameter group of optimizer, and thereby apply it to the gradient update of the learnable parameters. The parent class of all learning rate adjustment strategy classes is torch.optim.lr_scheduler._LRScheduler, and the base class _LRScheduler defines the following methods:
- step(epoch=None): Common to subclasses
- get_lr(): Subclasses need to implement
- get_last_lr(): Common to subclasses
- print_lr(is_verbose, group, lr, epoch=None): display lr adjustment information
- state_dict(): Subclasses may override
- load_state_dict(state_dict): Subclasses may override
Initialization: __init__
class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1, verbose=False):
        # .......
        self.optimizer = optimizer
        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch
        # ........
        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose
        self.step()
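The initial_lr bookkeeping in the constructor above can be tried in isolation on plain dictionaries. The sketch below copies only that branch into a hypothetical helper, init_base_lrs, and uses a list of dicts as a stand-in for optimizer.param_groups; it is an illustration and not part of PyTorch.

```python
def init_base_lrs(param_groups, last_epoch=-1):
    """Mimic how the constructor derives base_lrs (illustrative helper)."""
    if last_epoch == -1:
        # Fresh run: remember the starting lr of every group.
        for group in param_groups:
            group.setdefault("initial_lr", group["lr"])
    else:
        # Resuming: every group must already carry its initial_lr.
        for i, group in enumerate(param_groups):
            if "initial_lr" not in group:
                raise KeyError(
                    "param 'initial_lr' is not specified "
                    "in param_groups[{}] when resuming an optimizer".format(i)
                )
    return [group["initial_lr"] for group in param_groups]


fresh = [{"lr": 0.01}, {"lr": 0.001}]
print(init_base_lrs(fresh))  # [0.01, 0.001]

try:
    init_base_lrs([{"lr": 0.01}], last_epoch=5)
except KeyError as err:
    print("resume failed:", err)
```

Because setdefault only writes a missing key, a group that already has initial_lr (for example one restored from a checkpoint) keeps its stored value.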
Initialization parameters:
optimizer is an instance of the optimizer.
last_epoch is the index of the last epoch. The default value is -1, which means the model is being trained from the start. In this case, the initial learning rate initial_lr is set for each parameter group in optimizer.
If last_epoch is greater than -1, training is being resumed from a certain epoch, and every parameter group of optimizer must already contain the initial learning rate information initial_lr.
The with_counter function inside the initialization function mainly ensures that lr_scheduler.step() is called after optimizer.step(). Note that self.step() is called as the last step of __init__, which means the step() method of _LRScheduler has already been called once during initialization.
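The excerpt of __init__ elides the definition of with_counter. The following is a simplified sketch, under the assumption that the wrapper only needs to count calls on the owning object and tag itself with a _with_counter attribute (the marker that step() looks for with hasattr). The FakeOptimizer class is hypothetical, and the actual PyTorch implementation may differ in its details.

```python
import functools


def with_counter(method):
    """Wrap a bound method so each call bumps `_step_count` on its owner."""
    if getattr(method, "_with_counter", False):
        return method  # already wrapped, nothing to do

    owner = method.__self__

    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        owner._step_count += 1
        return method(*args, **kwargs)

    wrapper._with_counter = True  # marker that a scheduler can check for
    return wrapper


class FakeOptimizer:
    """Hypothetical stand-in for a real optimizer."""

    def step(self):
        return "stepped"


opt = FakeOptimizer()
opt.step = with_counter(opt.step)
opt._step_count = 0

opt.step()
opt.step()
print(opt._step_count, hasattr(opt.step, "_with_counter"))  # 2 True
```

A scheduler can then compare its own _step_count with the optimizer's counter to detect that lr_scheduler.step() was called before any optimizer.step().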
step
When the model completes an epoch of training, the step() method needs to be called. This method first increments last_epoch, then calls the get_lr() method implemented by the subclass inside an internal context manager class to obtain the learning rate of each parameter group for the current epoch, and writes it to the lr entry of the optimizer's param_groups. Finally, it records the most recently adjusted learning rates in self._last_lr, the attribute returned by the get_last_lr() method. The method uses the internal context manager class _enable_get_lr_call, which adds the attribute _get_lr_called_within_step to the instance object; subclasses can use this attribute.
def step(self, epoch=None):
    # Raise a warning if old pattern is detected
    # https://github.com/pytorch/pytorch/issues/20124
    if self._step_count == 1:
        if not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                          "initialization. Please, make sure to call `optimizer.step()` before "
                          "`lr_scheduler.step()`. See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
        elif self.optimizer._step_count < 1:
            warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                          "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                          "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                          "will result in PyTorch skipping the first value of the learning rate schedule. "
                          "See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
    self._step_count += 1

    class _enable_get_lr_call:
        def __init__(self, o):
            self.o = o

        def __enter__(self):
            self.o._get_lr_called_within_step = True
            return self

        def __exit__(self, type, value, traceback):
            self.o._get_lr_called_within_step = False

    with _enable_get_lr_call(self):
        if epoch is None:
            self.last_epoch += 1
            values = self.get_lr()
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
            self.last_epoch = epoch
            if hasattr(self, "_get_closed_form_lr"):
                values = self._get_closed_form_lr()
            else:
                values = self.get_lr()

    for i, data in enumerate(zip(self.optimizer.param_groups, values)):
        param_group, lr = data
        param_group['lr'] = lr
        self.print_lr(self.verbose, i, lr, epoch)

    self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
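The effect of the _enable_get_lr_call context manager in step() can be observed on its own. In the sketch below the context manager is copied from the code above, while ToyScheduler is a hypothetical minimal object whose get_lr() simply reports the flag instead of computing learning rates.

```python
class ToyScheduler:
    """Hypothetical minimal object that owns the flag."""

    def __init__(self):
        self._get_lr_called_within_step = False

    def get_lr(self):
        # Report whether the call happened inside the `with` block.
        return self._get_lr_called_within_step


class _enable_get_lr_call:
    def __init__(self, o):
        self.o = o

    def __enter__(self):
        self.o._get_lr_called_within_step = True
        return self

    def __exit__(self, type, value, traceback):
        self.o._get_lr_called_within_step = False


sched = ToyScheduler()
outside_before = sched.get_lr()
with _enable_get_lr_call(sched):
    inside = sched.get_lr()
outside_after = sched.get_lr()
print(outside_before, inside, outside_after)  # False True False
```

This is what lets a get_lr() implementation warn when it is called directly by user code rather than from within step().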
get_lr
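get_lr() is an abstract method that defines the interface for the learning rate update strategy; each subclass provides its own implementation. Its return value is a list of the form [lr1, lr2, ...], with one learning rate per parameter group. The implementation shown here uses self.gamma, so it is that of a subclass with exponential decay (such as ExponentialLR) rather than of the base class itself.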
def get_lr(self):
    if not self._get_lr_called_within_step:
        warnings.warn("To get the last learning rate computed by the scheduler, "
                      "please use `get_last_lr()`.", UserWarning)
    if self.last_epoch == 0:
        return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma
            for group in self.optimizer.param_groups]
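The get_lr() above is recursive: each call multiplies the current learning rate by gamma. Assuming nothing else modifies the learning rate in between, this is equivalent to the closed form base_lr * gamma ** last_epoch, which is the kind of value a _get_closed_form_lr() method (the one step() looks for when an explicit epoch is passed) would return. A plain-Python check of that equivalence, with hypothetical helper names:

```python
import math


def recursive_lrs(base_lr, gamma, num_epochs):
    """Apply lr <- lr * gamma once per epoch, as in the get_lr() above."""
    lr = base_lr
    history = [lr]  # last_epoch == 0 returns the lr unchanged
    for _ in range(num_epochs):
        lr = lr * gamma
        history.append(lr)
    return history


def closed_form_lr(base_lr, gamma, last_epoch):
    """Closed-form value of the same exponential schedule."""
    return base_lr * gamma ** last_epoch


history = recursive_lrs(base_lr=0.01, gamma=0.9, num_epochs=5)
for epoch, lr in enumerate(history):
    assert math.isclose(lr, closed_form_lr(0.01, 0.9, epoch))
    print(epoch, lr)
```

The recursive form is what allows several schedulers to be chained on one optimizer, since each one only scales whatever learning rate it currently finds in the parameter group.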
get_last_lr
def get_last_lr(self):
    """ Return last computed learning rate by current scheduler.
    """
    return self._last_lr
print_lr
print_lr(is_verbose, group, lr, epoch=None): This method displays information about the learning rate adjustment.
def print_lr(self, is_verbose, group, lr, epoch=None):
    """Display the current learning rate.
    """
    if is_verbose:
        if epoch is None:
            print('Adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(group, lr))
        else:
            epoch_str = ("%.2f" if isinstance(epoch, float) else
                         "%.5d") % epoch
            print('Epoch {}: adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
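To see what these format strings produce, the sketch below builds the same messages in a standalone function. The name format_lr_message is hypothetical; it returns the string instead of printing it so the result is easy to inspect.

```python
def format_lr_message(group, lr, epoch=None):
    """Build the message that print_lr() above would print."""
    if epoch is None:
        return "Adjusting learning rate of group {} to {:.4e}.".format(group, lr)
    # Float epochs get two decimals, integer epochs are zero-padded to 5 digits.
    epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
    return "Epoch {}: adjusting learning rate of group {} to {:.4e}.".format(
        epoch_str, group, lr
    )


print(format_lr_message(0, 0.009))
# Adjusting learning rate of group 0 to 9.0000e-03.
print(format_lr_message(0, 0.009, epoch=3))
# Epoch 00003: adjusting learning rate of group 0 to 9.0000e-03.
print(format_lr_message(1, 0.5, epoch=1.5))
# Epoch 1.50: adjusting learning rate of group 1 to 5.0000e-01.
```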
Other interfaces
state_dict(): Returns all attributes of the current instance except self.optimizer, in the form of a dict.
load_state_dict(state_dict): Reloads previously saved state information.
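One simple way to obtain this behavior is sketched below, under the assumption that the scheduler state is just the instance's attribute dictionary minus the optimizer reference. ToyScheduler is a hypothetical class written for this illustration, not a PyTorch class.

```python
class ToyScheduler:
    """Hypothetical scheduler-like object, not a PyTorch class."""

    def __init__(self, optimizer, gamma):
        self.optimizer = optimizer
        self.gamma = gamma
        self.last_epoch = 0

    def state_dict(self):
        # Every attribute except the optimizer reference.
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        # Overwrite the current attributes with the saved ones.
        self.__dict__.update(state_dict)


optimizer = object()  # placeholder for a real optimizer
saved = ToyScheduler(optimizer, gamma=0.9)
saved.last_epoch = 7
state = saved.state_dict()

restored = ToyScheduler(optimizer, gamma=0.5)
restored.load_state_dict(state)
print(state)                                # {'gamma': 0.9, 'last_epoch': 7}
print(restored.gamma, restored.last_epoch)  # 0.9 7
```

Since the optimizer is left out of the scheduler state, its own state has to be saved and restored separately when checkpointing.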