【Pytorch】在修改后的网络结构上加载Pre-trained模型以及Fine-tuning

在实际工作或者学习当中,为了节省时间提高效率,我们在深度学习训练中,一般会使用已经训练好的开源模型(一般都是基于ImageNet数据集),但通常情况下我们自己涉及的模型和别人训练好的有很多地方不一样。 难道我们就没法用了吗?当然不是,我们可以有很多种方法去实现我们想要的。


其实并不是为了学习,只是在等湖人打快船比赛


Pre-trained

目前共有三种加载Pre-trained模型的方法:

  • 第一种是修改网络最后的全连接层输出;
  • 第二种是选择性的加载模型的某些网络层;
  • 第三种是移植方法,直接训练好的网络模型移植到我们自己的网络模型当中。
#导入头文件
from torch import nn
import torch
from torchvision import models
from torch.autograd import Variable
from torch import optim

方法一

#改变最后输出类别数
transfer_model = models.resnet18(pretrained=True)

dim_in = transfer_model.fc.in_features
transfer_model.fc = nn.Linear(dim_in,10) #img_class =10
#print(transfer_model)

方法二

for param in transfer_model.parameters():
    param.requires_grad = False

optimizer = optim.SGD(transfer_model.fc.parameters(),lr=1e-3)
#为了加快效率,我们只在优化器中更新全连接部分中的参数

方法三

resnet50 = models.resnet50(pretrained=True)#加载model
cnn = CNN(Bottleneck, [3, 4, 6, 3])#自定义网络

#读取参数
pretrained_dict = resnet50.state_dict()
model_dict = cnn.state_dict()

# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 更新现有的model_dict
model_dict.update(pretrained_dict)

# 加载我们真正需要的state_dict
cnn.load_state_dict(model_dict)

# print(resnet50)
print(cnn)

Fine-tuning

何时以及如何Fine-tune

决定如何使用迁移学习的因素有很多,这是最重要的只有两个:新数据集的大小、以及新数据和原数据集的相似程度。有一点一定记住:网络前几层学到的是通用特征,后面几层学到的是与类别相关的特征。这里有使用的四个场景:

1.新数据集比较小且和原数据集相似。因为新数据集比较小,如果fine-tune可能会过拟合;又因为新旧数据集类似,我们期望他们高层特征类似,可以使用预训练网络当做特征提取器,用提取的特征训练线性分类器。

2.新数据集大且和原数据集相似。因为新数据集足够大,可以fine-tune整个网络。

3.新数据集小且和原数据集不相似。新数据集小,最好不要fine-tune,和原数据集不类似,最好也不使用高层特征。这时可是使用前面层的特征来训练SVM分类器。

扫描二维码关注公众号,回复: 9699636 查看本文章

4.新数据集大且和原数据集不相似。因为新数据集足够大,可以重新训练。但是实践中fine-tune预训练模型还是有益的。新数据集足够大,可以fine-tine整个网络。

发布了44 篇原创文章 · 获赞 9 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/Jeremy_lf/article/details/104744809