pytorch加载预训练模型或者自己的模型,并修改最后的分类输出层strict=False

        if args.pretrained:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
            num_fc_ftr = model.fc.in_features ##重写分类输出层
            model.fc = torch.nn.Linear(num_fc_ftr, num_classes)
            ###############加载自己训练的模型
            print ("加载res34自己训练的模型")
            pretrained_dict = torch.load("day_label_smooth_resnet34_model_best.pth.tar")["state_dict"]
            
            for k,v in model.state_dict().items():
                print (k)
                print (v.shape)
            pretrained_dict.pop('fc.weight') # 加载的参数直接删除全连接层的参数,打印名字看看
            pretrained_dict.pop('fc.bias')
            model.load_state_dict(pretrained_dict, strict=False)
发布了98 篇原创文章 · 获赞 141 · 访问量 26万+

猜你喜欢

转载自blog.csdn.net/m0_37192554/article/details/103185094