【DeepLearning】【PyTorch (6)】Getting Started - Saving and Loading Models

Notes on part six of the official PyTorch Getting Started tutorials, Saving and Loading Models.

1. What is a state_dict?

In PyTorch, the learnable parameters of a model (a torch.nn.Module), i.e. its weights and biases, are accessed through model.parameters() or model.state_dict(). optimizer.state_dict() saves the state of the model's optimizer (a torch.optim object), including hyperparameters such as the learning rate and momentum.


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Print the model's and the optimizer's state_dict:

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print()
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

out:

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140097913969256, 140097913807280, 140097913807208, 140097913807352, 140097913807424, 140097913807496, 140097913807568, 140097913807640, 140097913807712, 140097913807784]}]

Print the types of the two state_dicts:

print(type(model.state_dict()))
print(type(optimizer.state_dict()))

out:

<class 'collections.OrderedDict'>
<class 'dict'>

Print the keys of model.state_dict():

print(model.state_dict().keys())

out:

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

Print the keys of optimizer.state_dict():

print(optimizer.state_dict().keys())

out:

dict_keys(['state', 'param_groups'])

Print the value of the 'state' key in optimizer.state_dict():

print(type(optimizer.state_dict()['state']))
print(optimizer.state_dict()['state'])

out:

<class 'dict'>
{}

Explore the value of the 'param_groups' key in optimizer.state_dict():

print(type(optimizer.state_dict()['param_groups']))
print(len(optimizer.state_dict()['param_groups']))
print(type(optimizer.state_dict()['param_groups'][0]))
print(optimizer.state_dict()['param_groups'][0].keys())

out:

<class 'list'>     # optimizer.state_dict()['param_groups'] is a list
1                  # the list has length 1
<class 'dict'>     # its single element is a dict
dict_keys(['lr', 'momentum', 'dampening', 'weight_decay', 'nesterov', 'params'])         # this dict stores hyperparameters such as the learning rate (lr)

model.parameters() is a generator:

print(model.parameters())
print(type(model.parameters()))

out:

<generator object Module.parameters at 0x7f4b5b31f0a0>
<class 'generator'>

Print the first 2 elements of model.parameters():

for index, param_tensor in enumerate(model.parameters()):
    print(type(param_tensor))
    print(param_tensor.size())
    if index == 1:
        break

out:

<class 'torch.nn.parameter.Parameter'>
torch.Size([6, 3, 5, 5])
<class 'torch.nn.parameter.Parameter'>
torch.Size([6])

Extract conv1.weight and conv1.bias from model.state_dict():

conv1_weight_d = model.state_dict()['conv1.weight']
conv1_bias_d = model.state_dict()['conv1.bias']
print(conv1_weight_d.size())
print(conv1_bias_d.size())

out:

torch.Size([6, 3, 5, 5])
torch.Size([6])

Extract conv1.weight and conv1.bias from model.parameters():

parameters = [x for x in model.parameters()]
conv1_weight_p = parameters[0]
conv1_bias_p = parameters[1]
print(conv1_weight_p.size())
print(conv1_bias_p.size())

out:

torch.Size([6, 3, 5, 5])
torch.Size([6])

The conv1.weight and conv1.bias extracted from model.state_dict() and from model.parameters() are identical:

print(torch.equal(conv1_weight_p, conv1_weight_d))
print(torch.equal(conv1_bias_p, conv1_bias_d))

out:

True
True

2. Saving & Loading Model for Inference

2.1 Save/Load state_dict (Recommended)

Save:

PATH = 'myModel.pt'
torch.save(model.state_dict(), PATH)

Load:

Loading fails with a NameError if the model has not been instantiated first:

model_ft.load_state_dict(torch.load(PATH))

out:

NameError                                 Traceback (most recent call last)
<ipython-input-45-9ecdb1109f74> in <module>
----> 1 model_ft.load_state_dict(torch.load(PATH))

NameError: name 'model_ft' is not defined

Instantiate the model class first, then load the state_dict:

model_ft = TheModelClass()
model_ft.load_state_dict(torch.load(PATH))
model_ft.eval()

out:

TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

2.2 Save/Load Entire Model
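The notes stop here, but the tutorial's alternative approach saves the whole module object via pickle, binding the saved file to the exact class definition and directory structure used when saving (which is why the state_dict approach in 2.1 is recommended). A minimal sketch following that pattern, with a stand-in nn.Linear model and a placeholder file name:

```python
import torch
import torch.nn as nn

# Stand-in model (hypothetical); any nn.Module works the same way
model = nn.Linear(4, 2)

# Save the entire module object (architecture + parameters) via pickle
PATH = 'entire_model.pt'
torch.save(model, PATH)

# Load it back. Recent PyTorch versions default torch.load to
# weights_only=True, which refuses pickled modules, so pass
# weights_only=False; fall back for versions without that argument.
try:
    loaded = torch.load(PATH, weights_only=False)
except TypeError:
    loaded = torch.load(PATH)
loaded.eval()
```

Note that the class definition must be importable at load time, because pickle stores a reference to the class, not its code.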

3. Saving & Loading a General Checkpoint
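To resume training you need more than the model weights: the optimizer's state_dict, the epoch, the last loss, and anything else relevant, bundled into one dict. A sketch in the spirit of the tutorial (the nn.Linear model, epoch, and loss values are placeholders; the tutorial uses a .tar extension for such checkpoints):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # stand-in model (hypothetical)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Save everything needed to resume training in a single dict
PATH = 'checkpoint.tar'
torch.save({
    'epoch': 5,                                    # placeholder value
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.42,                                  # placeholder value
}, PATH)

# Load: re-create the model and optimizer first, then restore their states
model2 = nn.Linear(4, 2)
optimizer2 = optim.SGD(model2.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model2.eval()  # or model2.train() to resume training
```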

4. Saving Multiple Models in One File
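When several models are trained together (for example a GAN's generator and discriminator, or an encoder/decoder pair), the tutorial's pattern is to store each model's state_dict under its own key in one dict. A sketch with two stand-in nn.Linear models (hypothetical):

```python
import torch
import torch.nn as nn

# Two stand-in models (hypothetical)
modelA = nn.Linear(4, 2)
modelB = nn.Linear(2, 1)

# Save both state_dicts under separate keys in one file
PATH = 'two_models.tar'
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
}, PATH)

# Load into freshly constructed instances of the same classes
newA, newB = nn.Linear(4, 2), nn.Linear(2, 1)
ckpt = torch.load(PATH)
newA.load_state_dict(ckpt['modelA_state_dict'])
newB.load_state_dict(ckpt['modelB_state_dict'])
```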

5. Warmstarting Model Using Parameters from a Different Model
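Warmstarting means initializing a new model from a partially matching checkpoint, e.g. from a smaller pretrained model. Passing strict=False to load_state_dict ignores keys that are missing from or unexpected in the checkpoint (it does not reconcile shape mismatches for keys that do match). A sketch with two hypothetical classes sharing one layer name:

```python
import torch
import torch.nn as nn

class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)

class Big(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)  # same name and shape: gets loaded
        self.fc2 = nn.Linear(8, 2)  # absent from the source: stays as initialized

src = Small()
torch.save(src.state_dict(), 'src.pt')

dst = Big()
# strict=False skips non-matching keys instead of raising an error;
# the return value reports which keys were missing or unexpected
result = dst.load_state_dict(torch.load('src.pt'), strict=False)
```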

6. Saving & Loading Model Across Devices
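When the saving and loading machines differ (GPU vs. CPU), the tutorial's pattern is the map_location argument of torch.load, which remaps the saved tensors onto the target device. A sketch with a stand-in nn.Linear model (hypothetical):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in model (hypothetical)
torch.save(model.state_dict(), 'model.pt')

# Load onto the CPU regardless of which device the tensors were saved from
device = torch.device('cpu')
model2 = nn.Linear(4, 2)
model2.load_state_dict(torch.load('model.pt', map_location=device))
model2.to(device)

# To load onto a GPU instead (only when CUDA is available):
if torch.cuda.is_available():
    gpu = torch.device('cuda')
    model2.load_state_dict(torch.load('model.pt', map_location='cuda:0'))
    model2.to(gpu)
```

Remember that inputs must then be moved to the same device as the model before calling it.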


Reposted from blog.csdn.net/RadiantJeral/article/details/86653041