pytorch torch.optim.lr_scheduler 调整学习率的六种策略
1. 为什么需要调整学习率
在深度学习训练过程中,最重要的参数就是学习率,通常来说,在整个训练过层中,学习率不会一直保持不变,为了让模型能够在训练初期快速收敛,学习率通常比较大,在训练末期,为了让模型收敛在更小的局部最优点,学习率通常要比较小。
2. 学习率的初始值设置
其实,不同的任务学习率的初始值是需要试验几次来获得的,使用的优化器不同,mini-batch 的 batch_size 大小不同,学习率的初始值也不太相同。
根据我的实验经验,同一个任务,如果使用Adam
优化器,学习率初始值 0.001 比较好,Adam优化器对于初始值实际上不太敏感,基本都能做到快速收敛;如果使用SGD
优化器的话,学习率需要在Adam的基础上乘以10倍或者100倍,也就是使用0.1或者0.01比较好。
需要注意的一点,通常 batch_size 扩大 n 倍,学习率也要相应扩大 n \sqrt n n,虽然这么说,我自己实验的时候,没发现太多作用,可能是因为使用Adam
的原因。
3. torch 学习率的调整策略
pytorch 提供了一些基础的学习率调整策略,在 torch.optim.lr_scheduler
模块里,下方的代码实现了5种可能用到(其实我自己用的都是最基础的阶梯下降)的学习率调整方法,图片可视化了学习率变小的过程,非常简单,对照图片和代码仔细看看就明白每个scheduler的具体参数含义了:
import torch
import matplotlib.pyplot as plt
lr = 0.001
# 20代表从lr从最大到最小的epoch数,0代表学习率的最小值
scheduler_cos = torch.optim.lr_scheduler.CosineAnnealingLR(torch.optim.SGD([torch.ones(1)], lr), 20, 0)
# 20 和 0.5 代表每走20个epoch,学习率衰减0.5倍,阶梯形式
scheduler_step = torch.optim.lr_scheduler.StepLR(torch.optim.SGD([torch.ones(1)], lr), 20, 0.5)
# 每走一个epoch,学习率衰减0.95倍
scheduler_exp = torch.optim.lr_scheduler.ExponentialLR(torch.optim.SGD([torch.ones(1)], lr), 0.95)
# 三角的形式,0.0001代表最小的学习率, 0.001代表最大的学习率, 20代表一个升降周期
scheduler_cyc = torch.optim.lr_scheduler.CyclicLR(torch.optim.SGD([torch.ones(1)], lr), 0.0001, 0.001, 20)
# 阶梯衰减,每次衰减的epoch数根据列表 [20, 30, 60, 80] 给出,0.8代表学习率衰减倍数
scheduler_mul = torch.optim.lr_scheduler.MultiStepLR(torch.optim.SGD([torch.ones(1)], lr), [20, 30, 60, 80], 0.8)
lr_cos = []
lr_step = []
lr_exp = []
lr_cyc = []
lr_mul = []
for i in range(100):
lr_cos += scheduler_cos.get_last_lr()
lr_step += scheduler_step.get_last_lr()
lr_exp += scheduler_exp.get_last_lr()
lr_cyc += scheduler_cyc.get_last_lr()
lr_mul += scheduler_mul.get_last_lr()
scheduler_cos.step()
scheduler_step.step()
scheduler_exp.step()
scheduler_cyc.step()
scheduler_mul.step()
plt.figure(figsize=(12,7))
plt.plot(list(range(len(lr_cos))), lr_cos,
list(range(len(lr_step))), lr_step,
list(range(len(lr_exp))), lr_exp,
list(range(len(lr_cyc))), lr_cyc,
list(range(len(lr_mul))), lr_mul,)
plt.legend(['cos','step','exp','cyc', 'mul'], fontsize=20)
plt.xlabel('epoch', size=15)
plt.ylabel('lr', size=15)
plt.show()
运行上面代码后的效果: