【机器学习实战】PyTorch L2 正则化

论文 Bag of Tricks for Image Classification with Convolutional Neural Networks. 中提到,加 L2 正则就相当于将该权重趋向 0,而对于 CNN 而言,一般只对卷积层和全联接层的 weights 进行 L2(weight decay),而不对 biases 进行。Batch Normalization 层也不进行 L2。

PyTorch,只对卷积层和全联接层的 weights 进行 L2(weight decay):

weight_decay_list = (param for name, param in model.named_parameters() if name[-4:] != 'bias' and "bn" not in name)
no_decay_list = (param for name, param in model.named_parameters() if name[-4:] == 'bias' or "bn" in name)
parameters = [{'params': weight_decay_list},
              {'params': no_decay_list, 'weight_decay': 0.}]

optimizer = torch.optim.SGD(parameters, lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)

References

[1] He, T., Zhang, Z., Zhang, H., Zhang, Z., Xie, J., Li, M. (2019). Bag of Tricks for Image Classification with Convolutional Neural Networks. (CVPR) https://dx.doi.org/10.1109/cvpr.2019.00065

猜你喜欢

转载自www.cnblogs.com/wuliytTaotao/p/13192798.html