小白编程用Pytorch导入预训练模型&&设置不同学习速率

前两天正好在做这个部分,参考了很多网友的做法,也去pytorch论坛查了一下,现在总结如下。建议还是自己单步调试一下看看每个参数里面的值是什么样的比较好。

1.导入预训练的模型,预训练模型是现有模型的一个或者几个部分

假设我有一个网络包含 pretrained和classify两个部分,每个部分分别有一些卷积层or回归层,pretrained部分有一个已经训练好的网络模型pretrained model,那么我需要把这个网络模型导入到现有的网络中,实现代码如下:

# load pretrained model
pretrained_net = PretrainedNet()
pretrained_net.load_state_dict(torch.load('epochs/epoch_3_100.pt'))
pretrained_dict = pretrained_net.state_dict() 

# prepare my model_dict
model = MyNet()
model_dict = model.state_dict()

# trained_part 是pretrained部分在现有网络的名称,个人喜好把网络标记的明确一些,所以pretrained的部分都会写成一个pretrained类,然后调用,这样这个子块的每个参数名称就变成pretrained.xx.weight这样,有的人喜欢直接把pretrained部分写成跟pretrained model一样的参数名称,都OK,去掉pretrained_part这个变量就好。
pretrained_part = 'pretrained.'
dict_temp = {pretrained_part + k: v for k, v in pretrained_dict.items() if pretrained_part + k in model_dict}
model_dict.update(dict_temp)
model.load_state_dict(model_dict)

这部分一般是在train的时候,实例化训练网络之后,导入预训练模型。导入模型后,有时候需要固定这部分的参数或者给他们一个很低的学习速率,这时候就要开始给设置不同的学习速率。

2.给不同的部分设置学习速率。

固定pretrained的参数,仅仅训练classify

# setting the pretrained part leaning rate as zero, Only train the classifier part

for param in list(model.pretrained.parameters()):
    param.requires_grad = False

params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(params, lr=1e-3)

分别给pretrained和classify部分设置不同的学习速率

# pretrained_params 是给filter的一个list,用来过滤,其中的值是int,所以在optimizer设置参数的时候,不能用pretrained_params,应当直接使用model.pretrained.parameters()

pretrained_params = list(map(id, model.pretrained.parameters()))
classify_params = filter(lambda p: id(p) not in pretrained_params, model.parameters())

optimizer = optim.Adam([{'params': classify_params},
            {'params': model.pretrained.parameters(), 'lr': 1e-5}], lr=1e-3)

#设置学习速率的step方式
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) 
model.children_model.paramters()可以直接遍历子网络里面的parameter,不用把每个卷积层的parameter都列出来。
pytorch还在摸索中,自己实现个代码对编程小白来说还是比较轻松的,现在使用到什么学什么,勉强够用吧,版本0.3.1。因为听说0.4有大改动,担心之前的代码不能用,暂时不升级了。

猜你喜欢

转载自blog.csdn.net/elysion122/article/details/80291975
今日推荐