Pytorch 不训练(frozen)一些神经网络层的方法

Pytorch 不训练(frozen)一些神经网络层的方法

我们在做深度学习的时候经常会使用预训练的模型。很多情况下,加载进来模型是为了完成其他任务,在这种情况下,加载模型的一部分是不需要再训练的。那么我们就需要forozen这些神经网络层。

固定某些层训练,就是将tensor的requires_grad设为False。
此外,一定要记住,我们还需要在optim优化器中再将这些参数过滤掉!
下面见代码:

device = torch.device("cuda" )

    #Try to load models

model = DGCNN(args)
print(str(model))
model = model.to(device)


    
save_model = torch.load('model.t7')
model_dict =  model.state_dict()

#更新模型的参数,因为自己的网络比pretrain的模型更复杂
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys())  # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

for name,p in model.named_parameters():
    if name.startswith('conv1'): p.requires_grad = False
    if name.startswith('conv2'): p.requires_grad = False
    if name.startswith('conv3'): p.requires_grad = False
    if name.startswith('conv4'): p.requires_grad = False
    if name.startswith('bn1'): p.requires_grad = False
    if name.startswith('bn2'): p.requires_grad = False
    if name.startswith('bn3'): p.requires_grad = False
    if name.startswith('bn4'): p.requires_grad = False
    
opt = optim.SGD(filter(lambda x: x.requires_grad is not False ,model.parameters()), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)

Reference

https://blog.csdn.net/qq_34914551/article/details/87699317

发布了131 篇原创文章 · 获赞 6 · 访问量 6919

猜你喜欢

转载自blog.csdn.net/Orientliu96/article/details/104705912