if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
model = models.__dict__[args.arch](pretrained=True)
num_fc_ftr = model.fc.in_features ##重写分类输出层
model.fc = torch.nn.Linear(num_fc_ftr, num_classes)
###############加载自己训练的模型
print ("加载res34自己训练的模型")
pretrained_dict = torch.load("day_label_smooth_resnet34_model_best.pth.tar")["state_dict"]
for k,v in model.state_dict().items():
print (k)
print (v.shape)
pretrained_dict.pop('fc.weight') # 加载的参数直接删除全连接层的参数,打印名字看看
pretrained_dict.pop('fc.bias')
model.load_state_dict(pretrained_dict, strict=False)
pytorch加载预训练模型或者自己的模型,并修改最后的分类输出层strict=False
猜你喜欢
转载自blog.csdn.net/m0_37192554/article/details/103185094
今日推荐
周排行