Xiaobai learns Pytorch series--Torch.optim API Scheduler (3)

torch.optim.lr_scheduler provides several methods to adjust the learning rate based on the number of epochs.
torch.optim.lr_scheduler.ReduceLROnPlateau allows dynamic learning rate reduction based on some validation measurement.
The learning rate schedule should be applied after the optimizer's update; that is, you should write your code like this:
Demo:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

Most learning rate schedulers can be called back-to-back (also referred to as chaining schedulers). The result is that each scheduler is applied one after the other to the learning rate obtained by the one preceding it.

Demo:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()

In many places the documentation uses the following template to refer to the scheduler algorithms:

scheduler = ...
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
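
Since ReduceLROnPlateau reacts to a validation metric rather than the epoch count, its step() call takes that metric as an argument. Below is a rough sketch in the same pseudocode style as above; the factor and patience values are just illustrative:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

for epoch in range(100):
    train(...)
    val_loss = validate(...)
    scheduler.step(val_loss)  # pass the monitored metric to step()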

Scheduler source code analysis

Reference: https://zhuanlan.zhihu.com/p/346205754?utm_medium=social&utm_oi=73844937195520
The main logic of a learning rate adjustment class is to compute the learning rate of each parameter group for every epoch, update the lr value of the corresponding parameter group inside the optimizer, and thereby apply it to the optimizer's gradient updates of the learnable parameters. The parent class of all learning rate adjustment strategy classes is torch.optim.lr_scheduler._LRScheduler, and the base class _LRScheduler defines the following methods (a minimal subclass sketch follows the list):

  • step(epoch=None): Common to subclasses
  • get_lr(): Subclasses need to implement
  • get_last_lr(): Common to subclasses
  • print_lr(is_verbose, group, lr, epoch=None): display lr adjustment information
  • state_dict(): Subclasses may override
  • load_state_dict(state_dict): Subclasses may override
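
As a rough illustration of how these pieces fit together, a minimal custom scheduler only needs to subclass _LRScheduler and implement get_lr(). The class name and the halving interval below are made up for this example, and it assumes a PyTorch version matching the source shown here (where __init__ still takes verbose):

import torch
from torch.optim.lr_scheduler import _LRScheduler

class HalveEveryK(_LRScheduler):
    """Hypothetical scheduler: halve every base lr every `k` epochs."""
    def __init__(self, optimizer, k=10, last_epoch=-1, verbose=False):
        self.k = k
        # The parent __init__ stores base_lrs and already calls step() once.
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # base_lrs and last_epoch are maintained by _LRScheduler.
        return [base_lr * 0.5 ** (self.last_epoch // self.k)
                for base_lr in self.base_lrs]

model = torch.nn.Linear(4, 2)                       # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = HalveEveryK(optimizer, k=10)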

Initialization (__init__):

class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1, verbose=False):
    
        .......
        
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch
        
        ........
   		
        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        self.step()

Initialization parameters:

  • optimizer is an instance of the optimizer.
  • last_epoch is the index of the last epoch. The default value is -1, which means training the model from scratch; in this case the initial learning rate of each parameter group in the optimizer is saved as initial_lr.

If the value passed for last_epoch is greater than -1, it means training is resumed from a certain epoch, and each parameter group of the optimizer is then required to already contain the initial learning rate information initial_lr. The with_counter function used inside __init__ mainly ensures that lr_scheduler.step() is called after optimizer.step(). Note that self.step() is called in the last line of __init__, i.e. _LRScheduler has already called step() once during initialization.
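
A quick illustrative check of that last point (the tiny model is just a placeholder):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

print(scheduler.last_epoch)     # 0, not -1: step() already ran once in __init__
print(scheduler.get_last_lr())  # [0.1]: the base lr recorded by that first step()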

step

When the model finishes one epoch of training, the step() method needs to be called. Inside step(), last_epoch is incremented, and then, within the internal context manager class _enable_get_lr_call, the get_lr() method implemented by the subclass is called to obtain the learning rate of each parameter group at this point. These values are written back into the param_groups of the optimizer, and finally the last adjusted learning rates are recorded in the attribute self._last_lr, which is what get_last_lr() returns. The context manager sets the _get_lr_called_within_step attribute on the instance, which subclasses can check inside get_lr().

def step(self, epoch=None):
   # Raise a warning if old pattern is detected
   # https://github.com/pytorch/pytorch/issues/20124
   if self._step_count == 1:
       if not hasattr(self.optimizer.step, "_with_counter"):
           warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                         "initialization. Please, make sure to call `optimizer.step()` before "
                         "`lr_scheduler.step()`. See more details at "
                         "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

       # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
       elif self.optimizer._step_count < 1:
           warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                         "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                         "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                         "will result in PyTorch skipping the first value of the learning rate schedule. "
                         "See more details at "
                         "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
   self._step_count += 1

   class _enable_get_lr_call:

       def __init__(self, o):
           self.o = o

       def __enter__(self):
           self.o._get_lr_called_within_step = True
           return self

       def __exit__(self, type, value, traceback):
           self.o._get_lr_called_within_step = False

   with _enable_get_lr_call(self):
       if epoch is None:
           self.last_epoch += 1
           values = self.get_lr()
       else:
           warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
           self.last_epoch = epoch
           if hasattr(self, "_get_closed_form_lr"):
               values = self._get_closed_form_lr()
           else:
               values = self.get_lr()

   for i, data in enumerate(zip(self.optimizer.param_groups, values)):
       param_group, lr = data
       param_group['lr'] = lr
       self.print_lr(self.verbose, i, lr, epoch)

   self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

get_lr

The get_lr() method is an abstract method that defines the interface of the learning rate update strategy; different subclasses provide different implementations after inheriting. Its return value is a list [lr1, lr2, ...] with one learning rate per parameter group. The implementation shown below is the one from ExponentialLR:

def get_lr(self):
    if not self._get_lr_called_within_step:
        warnings.warn("To get the last learning rate computed by the scheduler, "
                      "please use `get_last_lr()`.", UserWarning)

    if self.last_epoch == 0:
        return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma
            for group in self.optimizer.param_groups]

get_last_lr

def get_last_lr(self):
    """ Return last computed learning rate by current scheduler.
    """
    return self._last_lr

print_lr

print_lr(is_verbose, group, lr, epoch=None): this method displays the lr adjustment information when verbose mode is enabled.

def print_lr(self, is_verbose, group, lr, epoch=None):
    """Display the current learning rate.
    """
    if is_verbose:
        if epoch is None:
            print('Adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(group, lr))
        else:
            epoch_str = ("%.2f" if isinstance(epoch, float) else
                         "%.5d") % epoch
            print('Epoch {}: adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
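
A small illustrative usage: constructing a scheduler with verbose=True (in the PyTorch version discussed here) makes every step(), including the one inside __init__, go through print_lr(); the model is again a placeholder.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=1, gamma=0.5, verbose=True)  # prints once at construction

optimizer.step()
scheduler.step()  # prints: Adjusting learning rate of group 0 to 5.0000e-02.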

Other interfaces

  • state_dict(): returns all attributes of the current scheduler instance as a dict, except self.optimizer
  • load_state_dict(state_dict): reloads previously saved state information (see the save/resume sketch below)
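
A minimal sketch of the usual save/resume pattern built on these two methods; the file name and the surrounding setup are assumptions for the example:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# ... train for a while, then save optimizer and scheduler state ...
torch.save({'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, 'checkpoint.pt')

# ... later, rebuild the objects and restore their state ...
checkpoint = torch.load('checkpoint.pt')
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])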
