PyTorch for Beginners Series -- Torch.optim API Scheduler (3)


torch.optim.lr_scheduler provides several methods for adjusting the learning rate based on the number of epochs.
torch.optim.lr_scheduler.ReduceLROnPlateau allows the learning rate to be reduced dynamically based on some validation measurement.
Learning rate scheduling should be applied after the optimizer's update; for example, you should write your code like this:

Demo:

import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

# model, dataset and loss_fn are assumed to be defined elsewhere
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()
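
ReduceLROnPlateau, mentioned above, is the one exception to this pattern: its step() expects a validation metric instead of being driven purely by the epoch count. A minimal sketch, assuming the same placeholder model, dataset and loss_fn as above plus a hypothetical validate() routine:

from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Reduce the lr by a factor of 10 once the validation loss has stopped improving for 10 epochs
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        loss = loss_fn(model(input), target)
        loss.backward()
        optimizer.step()
    val_loss = validate(model)        # hypothetical validation routine
    scheduler.step(val_loss)          # the metric is passed to step()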

Most learning rate schedulers can be called back-to-back (this is also referred to as chaining schedulers). The result is that each scheduler is applied, one after the other, to the learning rate obtained by the one preceding it.

Demo:

from torch.optim.lr_scheduler import MultiStepLR

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()
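
Newer PyTorch versions also ship torch.optim.lr_scheduler.ChainedScheduler, which wraps several schedulers over the same optimizer so that a single step() call applies all of them. A small sketch reusing scheduler1 and scheduler2 from the demo above, assuming a PyTorch version that provides this class:

from torch.optim.lr_scheduler import ChainedScheduler

scheduler = ChainedScheduler([scheduler1, scheduler2])

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        loss = loss_fn(model(input), target)
        loss.backward()
        optimizer.step()
    scheduler.step()   # applies scheduler1 and scheduler2 back to back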

In many places in the documentation, the following template will be used to refer to scheduler algorithms:

scheduler = ...
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()

scheduler source code analysis

Reference: https://zhuanlan.zhihu.com/p/346205754?utm_medium=social&utm_oi=73844937195520
The main job of a learning-rate scheduler class is to compute, at every epoch, the learning rate of each parameter group and write it back into the corresponding lr entry of the optimizer's param_groups, so that it takes effect when the optimizer updates its learnable parameters. All learning-rate scheduling classes inherit from torch.optim.lr_scheduler._LRScheduler, and the base class _LRScheduler defines the following methods (a minimal custom subclass is sketched after the list):

  • step(epoch=None): shared by all subclasses
  • get_lr(): must be implemented by each subclass
  • get_last_lr(): shared by all subclasses
  • print_lr(is_verbose, group, lr, epoch=None): displays lr adjustment information
  • state_dict(): may be overridden by subclasses
  • load_state_dict(state_dict): may be overridden by subclasses
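
To make the division of labor concrete, here is a minimal sketch of a custom scheduler (the class name MyStepLR is made up for illustration) that only implements get_lr() and inherits everything else from _LRScheduler; note that recent releases also expose the base class under the public name LRScheduler:

from torch.optim.lr_scheduler import _LRScheduler

class MyStepLR(_LRScheduler):
    """Hypothetical example: halve the lr every `step_size` epochs."""
    def __init__(self, optimizer, step_size=10, last_epoch=-1, verbose=False):
        self.step_size = step_size   # set before super().__init__, which calls self.step()
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # base_lrs and last_epoch are maintained by the parent class
        return [base_lr * 0.5 ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]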

Initialization (__init__):

class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1, verbose=False):
    
        .......
        
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch
        
        ........
   		
        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        self.step()

Initialization parameters:

  • optimizer is the optimizer instance
  • last_epoch is the index of the last epoch; the default value is -1, which means the model is being trained from scratch, in which case the initial learning rate initial_lr is set for each of the optimizer's parameter groups

If a last_epoch value greater than -1 is passed in, it means training is being resumed from some epoch, in which case the optimizer's parameter groups are required to contain the initial learning rate information initial_lr. The with_counter function used inside the initializer mainly ensures that lr_scheduler.step() is called after optimizer.step(). Also note that the last statement of __init__ calls self.step(), i.e. _LRScheduler has already called step() once by the time initialization finishes.
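
A small sketch of the consequences, assuming a model and the imports from the demos above: right after construction, last_epoch has already advanced from -1 to 0 because __init__ called step() once.

optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = ExponentialLR(optimizer, gamma=0.9)
print(scheduler.last_epoch)      # 0 -- step() was already called once inside __init__
print(scheduler.get_last_lr())   # [0.01] -- at epoch 0 the lr is still the initial value
# Resuming (last_epoch > -1) requires 'initial_lr' in every param group,
# otherwise __init__ raises the KeyError shown above.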

step

When the model finishes one epoch of training, step() needs to be called. The method increments last_epoch and then, inside an internal context-manager class, calls the subclass-implemented get_lr() to obtain each parameter group's learning rate for this epoch, writes it into the optimizer's param_groups attribute, and finally records the most recently applied learning rates in self._last_lr, which is what get_last_lr() returns. The context-manager functionality is provided by the internal class _enable_get_lr_call; it sets a _get_lr_called_within_step attribute on the instance, which subclasses can check inside get_lr().

def step(self, epoch=None):
   # Raise a warning if old pattern is detected
   # https://github.com/pytorch/pytorch/issues/20124
   if self._step_count == 1:
       if not hasattr(self.optimizer.step, "_with_counter"):
           warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                         "initialization. Please, make sure to call `optimizer.step()` before "
                         "`lr_scheduler.step()`. See more details at "
                         "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

       # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
       elif self.optimizer._step_count < 1:
           warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                         "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                         "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                         "will result in PyTorch skipping the first value of the learning rate schedule. "
                         "See more details at "
                         "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
   self._step_count += 1

   class _enable_get_lr_call:

       def __init__(self, o):
           self.o = o

       def __enter__(self):
           self.o._get_lr_called_within_step = True
           return self

       def __exit__(self, type, value, traceback):
           self.o._get_lr_called_within_step = False

   with _enable_get_lr_call(self):
       if epoch is None:
           self.last_epoch += 1
           values = self.get_lr()
       else:
           warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
           self.last_epoch = epoch
           if hasattr(self, "_get_closed_form_lr"):
               values = self._get_closed_form_lr()
           else:
               values = self.get_lr()

   for i, data in enumerate(zip(self.optimizer.param_groups, values)):
       param_group, lr = data
       param_group['lr'] = lr
       self.print_lr(self.verbose, i, lr, epoch)

   self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

get_lr

The get_lr() method is the abstract hook that defines the interface of the learning-rate update strategy; each subclass provides its own implementation after inheriting from the base class. Its return value is a list of the form [lr1, lr2, ...], one entry per parameter group. The snippet below is ExponentialLR's implementation, shown here as a concrete example.

def get_lr(self):
    if not self._get_lr_called_within_step:
        warnings.warn("To get the last learning rate computed by the scheduler, "
                      "please use `get_last_lr()`.", UserWarning)

    if self.last_epoch == 0:
        return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma
            for group in self.optimizer.param_groups]

get_last_lr

def get_last_lr(self):
    """ Return last computed learning rate by current scheduler.
    """
    return self._last_lr
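
In a training loop, get_last_lr() is the recommended way to log the learning rate the scheduler has just applied, rather than calling get_lr() directly; a small sketch with a hypothetical train_one_epoch() helper:

for epoch in range(20):
    train_one_epoch(model, dataset, optimizer)   # hypothetical: runs optimizer.step() internally
    scheduler.step()
    # one entry per parameter group, e.g. [0.009]
    print(f"epoch {epoch}: lr = {scheduler.get_last_lr()}")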

print_lr

print_lr(is_verbose, group, lr, epoch=None): this method displays the lr adjustment information.

def print_lr(self, is_verbose, group, lr, epoch=None):
    """Display the current learning rate.
    """
    if is_verbose:
        if epoch is None:
            print('Adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(group, lr))
        else:
            epoch_str = ("%.2f" if isinstance(epoch, float) else
                         "%.5d") % epoch
            print('Epoch {}: adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
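
print_lr() is only invoked from step() when the scheduler was constructed with verbose=True (a flag that has since been deprecated in recent PyTorch releases); a small sketch:

scheduler = ExponentialLR(optimizer, gamma=0.9, verbose=True)
# each step() now prints a message such as:
# "Adjusting learning rate of group 0 to 9.0000e-03."
scheduler.step()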

Other interfaces

  • state_dict(): returns all attributes of the current instance except self.optimizer, as a dict
  • load_state_dict(state_dict): reloads previously saved state information; a typical checkpoint round trip is sketched below
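
A typical use of these two methods is checkpointing the scheduler together with the model and optimizer; a minimal sketch, assuming torch is imported and the model, optimizer and scheduler from the demos above exist (the file name is a placeholder):

# save
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

# ... later, resume ...
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])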


Reposted from blog.csdn.net/weixin_42486623/article/details/129917712