从零学习PyTorch 第7课 模型Finetune与预训练模型


课程目录(在更新,喜欢加个关注点个赞呗):
从零学习pytorch 第1课 搭建一个超简单的网络
从零学习pytorch 第1.5课 训练集、验证集和测试集的作用
从零学习pytorch 第2课 Dataset类
从零学习pytorch 第3课 DataLoader类运行过程
从零学习pytorch 第4课 初见transforms
从零学习pytorch 第5课 PyTorch模型搭建三要素
从零学习pytorch 第5.5课 Resnet34为例学习nn.Sequential和模型定义
从零学习PyTorch 第6课 权值初始化
从零学习PyTorch 第7课 模型Finetune与预训练模型
从零学习PyTorch 第8课 PyTorch优化器基类Optimier

这一章比较有意思

上一课,介绍了模型的权值初始化,以及PyTorch自带的权值初始化方法函数。我们知道一个而良好的权值初始化,可以使收敛速度加快,甚至收获更好的精度。但是实际应用中,并不是如此,我们通常采用一个已经训练的模型的权值参数作为我们模型的初始化参数,这个就是Finetune,更宽泛的说,就是迁移学习!! 迁移学习中的Finetune技术,本质上就是让我们新构建的模型,拥有一个较好的权值初始值。

finetune权值初始化分三步:

  1. 保存模型,拥有一个预训练模型
  2. 加载模型,吧预训练模型中的权值中取出来
  3. 初始化,将权值对应的放在新模型中。

Finetune之权值初始化

在进行finetune之前,我们呢需要拥有一个模型或者模型参数,因此我们要学习如何保存模型。官方文档中介绍了两种保存模型的方法:

  1. 保存整个模型
  2. 保存模型参数(官方推荐这个)

保存模型参数

我们现有一个Net模型,就像前面几课讲得那样

net = Net()
torch.save(net.state_dict(),'net_params.pkl')

加载模型

这里只是加载模型的参数,就是上面那个玩意

pretrained_dict = torch.load('net_params.pkl')

初始化

放权值放到新的模型中:
首先我们创建新的模型,然后获取新模型的参数字典net_state_dict:

net = Net()
net_state_dict = net.state_dict()
# 接着将pretrain_dict中不属于net_state_dict的键剔除掉
pretrained_dict_1 = {k:v for k,v in pretrained_dict.items() if k in net_state_dict}
# 然后用与训练的参数字典,对新模型的参数字典net_state_dict进行更新
net_state_dict.update(pretrained_dict_1)
# 将更新的参数字典放回网络中
net.load_state_dict(net_state_dict)

这样,利用预训练模型参数对新模型的权值进行初始化的过程就算做完了

不同层不同学习率

在利用pre-trained model的参数做初始化之后,我们可能想让fc曾更新的相对快一点,而希望前面的权值更新速度慢一点,这就可以通过为不同的层设置不同的学习率来达到此目的。

为不同层设置不同的学习率,主要是通过优化器的多个函数设置不同的参数,所以,只要将原来的参数组,划分成两个甚至更多的参数组,然后分别设置学习率。

不多说,上案例,这里将原始参数划分成fc3层和其他参数,为fc3设置更大的学习率

ignore_params = list(map(id,net.fc3.parameters()))
base_params = filter(lambda p:id(p) not in ignored_params,net.parameters())
# 这里的ignore_params是fc3的参数,base_params是除了fc3层之外的参数

optimizer = optim.SGD([
	{'params':base_[arams},
	{'params':net.fc3.parameters(),'lr':0.001*1-}],0.001,momentum=0.9,weight_decay = le-4)
  • 第一行第二行的意思,就是把fc3层的参数net.fc3.parameters()从原始参数中net.Parameters()中剥离出来
  • optimizer = optim.SGD(…)这里的意思就是base_params中的曾用0.001,momentu=0.9.weight_decay=;e-4
  • 而fc3层设置的学习率为0.001*10

补充:这里面好像是根据内存地址是否重复,来排除掉fc3中的参数,得到base_params的。

发布了78 篇原创文章 · 获赞 14 · 访问量 9731

猜你喜欢

转载自blog.csdn.net/qq_34107425/article/details/104104287