Transfer learning for classification networks in PyTorch

Following the terminology of [1], transferring a pretrained model takes one of two forms:

  1. Freeze and train: only the final fc layer is trained.
  2. Finetune: the pretrained model is used as the initialization, and all layers are trained.
This post only discusses the changes to the network structure.

finetune
import torch.nn as nn
from torchvision import models

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features              # input size of the final fc layer
model_ft.fc = nn.Linear(num_ftrs, NUM_CLASSES)  # NUM_CLASSES is the number of classes in your own dataset
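As a quick sanity check (a minimal sketch; NUM_CLASSES stands for your own class count, as above), a dummy forward pass confirms the new head's output shape:

import torch

x = torch.randn(1, 3, 224, 224)  # one dummy RGB image at ResNet's usual 224x224 input size
out = model_ft(x)
print(out.shape)                 # torch.Size([1, NUM_CLASSES])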

model_ft = models.vgg16(pretrained=True)
num_ftrs = model_ft.classifier[6].in_features        # input size of the last fc layer in the classifier
feature_model = list(model_ft.classifier.children())
feature_model.pop()                                  # drop the original 1000-way fc layer
feature_model.append(nn.Linear(num_ftrs, NUM_CLASSES))
model_ft.classifier = nn.Sequential(*feature_model)
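Since classifier is an nn.Sequential, the same replacement can also be written by indexing it directly (a small variation on the snippet above, not taken from [1]):

model_ft.classifier[6] = nn.Linear(num_ftrs, NUM_CLASSES)  # replace the last layer in place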
If you want to add further layers on top of the base network, you can use the approach from main.py in [2]:
num_ftrs = model_ft.fc.in_features
feature_model = list(model_ft.fc.children())  # empty for a plain nn.Linear, so the new head is built from scratch
feature_model.append(nn.Linear(num_ftrs, cf.feature_size))           # cf.feature_size: hidden size from the config in [2]
feature_model.append(nn.BatchNorm1d(cf.feature_size))
feature_model.append(nn.ReLU(inplace=True))
feature_model.append(nn.Linear(cf.feature_size, len(dset_classes)))  # dset_classes: the dataset's class list
model_ft.fc = nn.Sequential(*feature_model)
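Printing the new head is an easy way to verify the structure:

print(model_ft.fc)  # Sequential of Linear -> BatchNorm1d -> ReLU -> Linear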
[2] also mentions a method for feature extraction:
import copy

if args.net_type == 'alexnet' or args.net_type == 'vggnet':
    # Keep the classifier but drop its last fc layer
    feature_map = list(checkpoint['model'].module.classifier.children())
    feature_map.pop()
    new_classifier = nn.Sequential(*feature_map)
    extractor = copy.deepcopy(checkpoint['model'])
    extractor.module.classifier = new_classifier
elif args.net_type == 'resnet':
    # Drop the final fc layer; keep the convolutional trunk and average pooling
    feature_map = list(model.module.children())
    feature_map.pop()
    extractor = nn.Sequential(*feature_map)
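A sketch of how the extractor might then be used, assuming the resnet branch above with a resnet18 trunk (whose pooled features are 512-dimensional) and a preprocessed input batch images:

extractor.eval()
with torch.no_grad():
    feats = extractor(images)              # (N, 512, 1, 1) for a resnet18 trunk
    feats = feats.view(feats.size(0), -1)  # flatten to (N, 512)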

For freeze_train, the network structure adds the following on top of the code above:
for param in model_conv.parameters():  # parameters have requires_grad=True by default
    param.requires_grad = False
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, num_class)  # the freshly created fc layer is trainable again
This prevents backpropagation from changing the parameters of the earlier layers.
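A practical consequence, following the usual PyTorch transfer-learning recipe (the hyperparameters here are illustrative): only pass the still-trainable parameters to the optimizer.

import torch.optim as optim

# Only the new fc layer's parameters are updated; the frozen layers receive no gradients
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)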

[3] points out that the learning rate should be decayed as training progresses, so that gradient descent does not overshoot a local optimum as it approaches one.
def lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
    """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
    lr = init_lr * (0.1**(epoch // lr_decay_epoch))

    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer
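For comparison, PyTorch ships a built-in scheduler implementing the same schedule (a minimal sketch, reusing optimizer_conv from above):

from torch.optim import lr_scheduler

# Multiply the learning rate by 0.1 every 7 epochs, like lr_scheduler() above
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
# then call exp_lr_scheduler.step() once per epoch in the training loop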

[1]


Reposted from blog.csdn.net/lxx516/article/details/79019931