PyTorch - 保存和加载模型

保存和加载模型

关于保存和加载模型,有三个核心功能需要熟悉:
torch.save:将序列化的对象保存到磁盘。此函数使用Python的 pickle 实用程序进行序列化。使用此功能可以保存各种对象,包括模型、张量和字典。
torch.load:使用pickle的反序列化(unpickle)功能,将磁盘上序列化的目标文件反序列化到内存中。此函数还可以通过 map_location 参数把数据加载到指定设备。
torch.nn.Module.load_state_dict:使用反序列化的state_dict加载模型的参数字典 。

state_dict

state_dict 是一个Python字典对象:具有可学习参数的层(卷积层、线性层等)和已注册的缓冲区(如 batchnorm 的 running_mean)在模型的 state_dict 中都有对应条目。优化器对象(torch.optim)也有自己的 state_dict,其中包含优化器状态以及所用超参数的信息。

保存和加载参数

# Save only the learned parameters (the recommended approach).
torch.save(model.state_dict(), PATH)

# Load: instantiate the model first, then restore its parameters.
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
# Put dropout / batch-norm layers into evaluation mode before inference.
model.eval()

保存模型以进行推理时,仅需要保存训练后的模型的学习参数。使用 torch.save() 函数保存模型的state_dict将为您提供最大的灵活性,以便以后恢复模型,这就是为什么推荐使用此方法来保存模型。

保存和加载模型

# Save the entire model object (pickles the full module, not just its weights).
torch.save(model, PATH)

# The model class must be defined (importable) wherever the model is loaded.
model = torch.load(PATH)
# Put dropout / batch-norm layers into evaluation mode before inference.
model.eval()

注意:这种方式通过 pickle 保存整个模型,但不会保存模型类本身的定义;加载之前必须先定义(或可导入)该模型类。

检查点(Checkpoint)保存

# Save a training checkpoint: bundle model/optimizer state plus any
# bookkeeping values (epoch, loss) into a single dict.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...  # placeholder: add any other items you want to checkpoint
    }, PATH)

# Restore: rebuild the model and optimizer, then load their states.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Pick the mode that matches what you do next: inference or more training.
model.eval()
# - or -
model.train()

保存多个模型

# Save several models (e.g. a GAN or an encoder/decoder pair) in one file.
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    ...  # placeholder: add any other items you want to checkpoint
    }, PATH)

# Restore: rebuild each model/optimizer, then load its state from the dict.
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

# Pick the mode that matches what you do next: inference or more training.
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

模型间参数迁移

# Warmstart model B from model A's trained parameters.
torch.save(modelA.state_dict(), PATH)

modelB = TheModelBClass(*args, **kwargs)
# strict=False silently skips keys that do not match modelB's own state_dict.
modelB.load_state_dict(torch.load(PATH), strict=False)

strict=False忽略不匹配,也可以更改键名进行匹配

GPU、CPU

# Save on GPU, load on CPU.
torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
# map_location remaps the stored (GPU) tensors onto the CPU while loading.
model.load_state_dict(torch.load(PATH, map_location=device))

# Save on GPU, load on GPU.
torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
# Move the model's parameters and buffers onto the GPU.
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

# Save on CPU, load on CPU: same as the normal procedure shown earlier.
# Save on CPU, load on GPU.
torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

分布式

DataParallel
# DataParallel wraps the real model; save model.module's state_dict so the
# checkpoint can also be loaded by a plain (non-parallel) model.
torch.save(model.module.state_dict(), PATH)

Pytorch 原理

pathlib

https://www.jianshu.com/p/ae194371cf7c

from pathlib import Path
BASE_DIR = Path(__file__).resolve().parent.parent#Path of this file -> absolute -> parent -> parent (two levels up)
TEMPLATES_DIR = BASE_DIR.joinpath('templates')
# Equivalent using os.path:
import os.path

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEMPLATES_DIR = os.path.join(BASE_DIR, 'templates')
import os
import os.path

os.makedirs(os.path.join('src', '__pypackages__'), exist_ok=True)
os.rename('.editorconfig', os.path.join('src', '.editorconfig'))
# Equivalent using pathlib:
from pathlib import Path

Path('src/__pypackages__').mkdir(parents=True, exist_ok=True)
Path('.editorconfig').rename('src/.editorconfig')
from glob import glob

top_level_csv_files = glob('*.csv')
all_csv_files = glob('**/*.csv', recursive=True)
# Equivalent using pathlib (rglob == recursive glob):
from pathlib import Path

top_level_csv_files = Path.cwd().glob('*.csv')
all_csv_files = Path.cwd().rglob('*.csv')

下载数据集

from pathlib import Path
import requests

# Download the MNIST pickle archive into data/mnist/ (created if missing).
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

# Download only once: skip if the archive is already on disk.
if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    # write_bytes opens, writes and closes the file in one call; the original
    # `.open("wb").write(content)` never closed its file handle.
    (PATH / FILENAME).write_bytes(content)

import pickle
import gzip

# NOTE(security): pickle.load can execute arbitrary code from the file;
# only use it on archives you trust (here: the downloaded MNIST data).
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

从头建立神经网络

以Xavier初始化方法(每个元素都除以1/sqrt(n))为例来对权重进行初始化
尾缀为_的方法在PyTorch中表示这个操作会被立即被执行,原地操作

# One manual gradient-descent step (fragment from inside a training loop).
# The original paste had broken indentation (`loss.backward()` at column 0
# with the `with` block indented eight spaces), which is not valid Python;
# the statements are normalized to one consistent level here.
loss.backward()
# Update the parameters without recording the updates in the autograd graph,
# then clear the accumulated gradients for the next iteration.
with torch.no_grad():
    weights -= weights.grad * lr
    bias -= bias.grad * lr
    weights.grad.zero_()
    bias.grad.zero_()

可以使用标准python调试器对PyTorch代码进行单步调试,从而在每一步检查不同的变量值。
from IPython.core.debugger import set_trace
进入循环时使用

各个模块

nn.functional:包含了torch.nn库中的所有函数(这个库的其它部分是各种类),激活函数,损失函数等
nn.Module:一个能够跟踪状态的类,用来被继承写网络
nn.Parameter:构建权重和偏置
nn.Linear:定义好的网络层
nn.Sequential:序列化运行它包含的模块。这是一个更简单的搭建神经网络的方式。
自定义层:增加一个view层

class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be used as an nn.Module layer.

    Useful for inserting stateless operations (e.g. a reshape) into an
    nn.Sequential pipeline.
    """

    def __init__(self, func):
        """Store *func*, the callable that forward() will apply."""
        super().__init__()
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to the input and return its result."""
        return self.func(x)


def preprocess(x):
    """Reshape a flat MNIST batch (N, 784) into image form (N, 1, 28, 28)."""
    image_shape = (-1, 1, 28, 28)
    return x.view(image_shape)

猜你喜欢

转载自blog.csdn.net/ljyljyok/article/details/107239400