迁移学习-模型参数加载

1. 模型定义时name规则

  • 定义了变量名的,name=变量名;
  • 没有定义变量名的,使用Sequential()的,从0开始标号name。
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)  # conv1.*
        self.layer1=nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3),      # layer1.0.*
            nn.BatchNorm2d(64),                                            # layer1.1.*
        )
        self.layer2=nn.Sequential(
            nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3),     # layer2.0.*

            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3),      # layer2.1.*

        )
        self.lay=self._make_layer()                                         #lay.0.* + lay.1.*
    def forward(self, input):
        out=self.conv1(input)
        out=self.lay(out)
        return out
    def _make_layer(self):

        self.layers = []
        self.layers.append(
            nn.Conv2d(12,12,4)
        )
        self.layers.append(nn.Conv2d(12, 12, 4))
        return nn.Sequential(*self.layers)

pars=Net().state_dict()

2. 训练参数迁移

# 你的模型
net=model()
# 训练好的模型参数读取
pre_dict=torch.load('c3d.pickle')  
# 你的模型参数, 即初始化参数
model_dict=net.state_dict() 
# 将pretrained_dict里不属于model_dict的键剔除掉 
pre_dict =  {name: value for name, value in pre_dict.items() if name in model_dict} 
# 更新现有的model_dict 
model_dict.update(pre_dict) 
# 加载我们真正需要的state_dict 
net.load_state_dict(model_dict) 
发布了31 篇原创文章 · 获赞 8 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/weijie_home/article/details/104781728
今日推荐