PyTorch 保存和加载模型参数,从预训练模型中加载部分参数,包括预训练模型中某些参数不匹配的情况

0 前言

  这篇博客主要是对使用PyTorch保存和加载训练模型参数的一个学习记录。第1-5小节是比较常规的模型参数保存操作,第6小是用已经训练好的模型参数来初始化新的模型,包括从一层加载到另一层,某些参数名不匹配的情况,也给出了实验代码和结果完整实验项目见github如果对您有所帮助,欢迎关注点赞~

1 state_dict

  在PyTorch中,torch.nn.Module的可学习参数(i.e. weights and biases),保存在模型的parameters中,它可以通过model.parameters()进行访问。state_dict是一个从参数名称映射到参数Tensor的字典对象。注意,只有具有可学习参数的层(卷积层、线性层等)和已经注册的缓冲区(bachnorm’s running _mean)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。由于state_dic对象是Python字典,因此可以轻松地保存、更新、更改和还原它们,从而为PyTorch模型和优化增加了很多模块性。
训练分类器教程中使用的简单模型看一下state_dict。

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

从输出结果可以看出,每一层的模型参数的名称格式是:层名.参数;如果有它的一层是由另一个类定义的话,那么就把层名往后扩展:层名.层名…参数。下面对上述代码的模型进行重新整理,验证一下。

TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

模型重新整理的代码与结果:

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc_the_model_class = FC()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc_the_model_class(x)
        return x


class FC(nn.Module):
    def __init__(self):
        super(FC, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        return self.fc(x)


# Initialize model
model = TheModelClass()

print(model)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

  参数名称输出。整个模型的网络结构还是一样的,但是将全连接层重新使用FC类来定义了。从输出的网络结构可以看出,TheModelClass类中定义时,使用的类的结构层次,在网络的结构中会体现。与(conv2)并列的(fc_the_model_class)是在TheModelClass类定义时用的变量名。后接的FC是fc_the_model_class使用的类名,后面的是这个类中定义的层。输出模型时,就是按一种深度优先的方法遍历了整个模型。对于更深层次层的参数,类名是不会出现在参数名中的,然后将参数名按深度组织:fc_the_model_class.fc.0.weight,也就是在打印过程中,:后面的名称会忽略。

TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc_the_model_class): FC(
    (fc): Sequential(
      (0): Linear(in_features=400, out_features=120, bias=True)
      (1): ReLU()
      (2): Linear(in_features=120, out_features=84, bias=True)
      (3): ReLU()
      (4): Linear(in_features=84, out_features=10, bias=True)
    )
  )
)
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc_the_model_class.fc.0.weight 	 torch.Size([120, 400])
fc_the_model_class.fc.0.bias 	 torch.Size([120])
fc_the_model_class.fc.2.weight 	 torch.Size([84, 120])
fc_the_model_class.fc.2.bias 	 torch.Size([84])
fc_the_model_class.fc.4.weight 	 torch.Size([10, 84])
fc_the_model_class.fc.4.bias 	 torch.Size([10])

2 保存和加载用于推理的模型参数

保存使用:

torch.save(model.state_dict(), PATH)

加载使用:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

保存模型用于推理时,仅需要保存训练后的模型的参数,使用torch.save()函数直接保存模型的state_dict,通常文件的后缀名为.pt或.pth。请记住,在运行推理之前,必须先调用model.eval(),将dropout层和batch normalization层设为关闭状态。否则将会产生不一致的推断结果。
需要注意的是,load_state_dict()函数使用的是字典对象,而不是保存对象的路径,所以需要先进行torch.load(PATH)

3 保存和加载整个模型

保存使用:

torch.save(model, PATH)

加载使用:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

4 保存和加载用于推理或者继续训练的general checkpoing

保存使用:

torch.save({
    
    
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载使用:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

  在保存checkpoint用于继续训练时,保存优化器的state_dict是必要的,因为它包含着随着模型训练而更新的缓冲区和参数,可能也需要保存一些其他的项目,包括epoch和loss。常见的PyTorch约定是使用.tar文件扩展名保存这些检查点。

5 将多个模型参数保存在一个文件中

保存使用 本质上还是保存的是一个字典对象,PyTorch约定使用.tar保存这些检查点。

torch.save({
    
    
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

加载使用 加载还是加载的是字典对象,然后取字典对象。

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

6 使用来自不同模型的参数进行 Warmstarting Model ★ \bigstar

保存使用:

torch.save(modelA.state_dict(), PATH)

加载使用:对于不同的模型,设置strict=False是必要的。

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

  在迁移学习或者训练新的复杂模型时,部分加载模型或加载部分模型是常见的方案。利用训练过的参数,即使是只有一小部分可以使用,也会对warmstart训练过程有所帮助,而且有望比从头开始训练模型更快地收敛。所谓的warmstart,我理解的就是在参数初始化时,将待训练模型的参数使用已经训练好的模型的部分参数进行初始化,然后接着训练,这种参数初始化方案会大大提高收敛的速度。
  无论是从缺少某些键的部分state_dict加载,还是要加载比待加载模型更多的键的state_dic,都可以在lod_state_dict()中将strict参数设置为False,这样可以忽略不匹配的键。
   如果要将参数从一层加载到另一层,但是某些键不匹配,只需要加载的state_dict中参数键的名称,来匹配到要加载到的模型中的键。 实验代码如下。

targetModel = TheModelClass()

cifar_net = torch.load('./cifar_net.pth')

for item in cifar_net:
    print('cifar_net \t', item, '\t')

targetModel.load_state_dict(cifar_net, strict=False)

for item in targetModel.state_dict():
    print('targetModel \t', item, '\t')

print('cifar_net \t', cifar_net["fc3.bias"], '\t', cifar_net["fc3.bias"].data)
print('targetModel \t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"], '\t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"].data)

# 更新层的名称
cifar_net["fc_the_model_class.fc.0.weight"] = cifar_net.pop("fc1.weight")
cifar_net["fc_the_model_class.fc.0.bias"] = cifar_net.pop("fc1.bias")
cifar_net["fc_the_model_class.fc.2.weight"] = cifar_net.pop("fc2.weight")
cifar_net["fc_the_model_class.fc.2.bias"] = cifar_net.pop("fc2.bias")
cifar_net["fc_the_model_class.fc.4.weight"] = cifar_net.pop("fc3.weight")
cifar_net["fc_the_model_class.fc.4.bias"] = cifar_net.pop("fc3.bias")

targetModel.load_state_dict(cifar_net, strict=False)
print('cifar_net \t', cifar_net["fc_the_model_class.fc.4.bias"], '\t', cifar_net["fc_the_model_class.fc.4.bias"].data)
print('targetModel \t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"], '\t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"].data)

输出结果,可以看出fc_the_model_class.fc.4.bias的参数由随机初始化,变成从cifar_net模型中初始化。

cifar_net 	 tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04,  4.6838e-01,  1.1072e+00,
        -2.2960e-01,  1.9044e-01, -5.1352e-02,  1.8365e-01, -3.4669e-01],
       device='cuda:0') 	 tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04,  4.6838e-01,  1.1072e+00,
        -2.2960e-01,  1.9044e-01, -5.1352e-02,  1.8365e-01, -3.4669e-01],
       device='cuda:0')
targetModel 	 tensor([-0.0878, -0.1059, -0.0949,  0.0353,  0.0164, -0.1002, -0.0126, -0.1012,
        -0.0115, -0.1006]) 	 tensor([-0.0878, -0.1059, -0.0949,  0.0353,  0.0164, -0.1002, -0.0126, -0.1012,
        -0.0115, -0.1006])
cifar_net 	 tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04,  4.6838e-01,  1.1072e+00,
        -2.2960e-01,  1.9044e-01, -5.1352e-02,  1.8365e-01, -3.4669e-01],
       device='cuda:0') 	 tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04,  4.6838e-01,  1.1072e+00,
        -2.2960e-01,  1.9044e-01, -5.1352e-02,  1.8365e-01, -3.4669e-01],
       device='cuda:0')
targetModel 	 tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04,  4.6838e-01,  1.1072e+00,
        -2.2960e-01,  1.9044e-01, -5.1352e-02,  1.8365e-01, -3.4669e-01]) 	 tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04,  4.6838e-01,  1.1072e+00,
        -2.2960e-01,  1.9044e-01, -5.1352e-02,  1.8365e-01, -3.4669e-01])

参考资料

dict={
    
    'a':1, 'b':2}
dict["c"] = dict.pop("a")

猜你喜欢

转载自blog.csdn.net/PAN_Andy/article/details/103054958