动手学深度学习5.4 PyTorch教程 读写文件

参与11月更文挑战的第21天,活动详情查看:2021最后一次更文挑战

读写文件有什么必要呢?

读写文件其实不是读取数据集。

是当你的训练时要定期存储中间结果,以确保在服务器电源不小心被断掉,或者出现其他情况的时候,损失掉你前几天的计算结果。

这一节要做的就是如何存储权重向量和整个模型。

import torch
from torch import nn
from torch.nn import functional as F
复制代码

loadsave

对于单个张量,我们可以直接调用loadsave函数分别读写它们。

  • torch.saves

    torch.save(obj, f, pickle_module=<module 'pickle' from '.../pickle.py'>, pickle_protocol=2)
    复制代码

    参数:

    • obj – 保存对象
    • f - 字符串,文件名
    • pickle_module – 用于pickling元数据和对象的模块
    • pickle_protocol – 指定pickle protocal 可以覆盖默认参数
  • torch.load

    torch.load(f, map_location=None, pickle_module=<module 'pickle' from '.../pickle.py'>)
    复制代码

    从磁盘文件中读取一个通过torch.save()保存的对象。

    参数:

    • f – 字符串,文件名
    • map_location – 一个函数或字典规定如何remap存储位置
    • pickle_module – 用于unpickling元数据和对象的模块 (必须匹配序列化文件时的pickle_module )
x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
print(x2)
复制代码

初始化一个x

将x存储到当前文件夹下并命名为x-file,此时你会发现当前文件夹下边多出来一个同名的文件。 当然打开之后可能不是 0 1 2 3 ,因为编码方式不同,所以不用纠结打开以后看到的是什么。

image.png

声明一个x2再从文件中读回来,会发现结果就是tensor([0, 1, 2, 3]),结果没错就可以了。

y = torch.zeros(4)
torch.save([x, y],'x-file')
x2, y2 = torch.load('x-file')
print(x2, y2)
复制代码
>>
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
复制代码

存储一个张量列表,然后把它们读回内存。

y = torch.zeros(4)
torch.save(y[:2],'x-file')
x2, y2 = torch.load('x-file')
(x2, y2)
print(x2, y2)
复制代码

切片也是可以的。

mydict = {'x': x, 'y': y}
torch.save(mydict, 'x-file')
mydict2 = torch.load('x-file')
mydict2
复制代码
>>
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
复制代码

存储字典也可以。

加载和保存模型参数

深度学习框架提供了内置函数来保存和加载整个网络。

但是是保存模型的参数而不是保存整个模型。

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))
复制代码

还记得我们手写的多层感知机吗,用这个实现一下子。

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
复制代码

现在生成一个net,用它计算X,并将其赋值给Y。

torch.save(net.state_dict(), 'x-file')
复制代码

将net的参数保存起来。

net_ = MLP()
net_.load_state_dict(torch.load('x-file'))
net_.eval()
复制代码

生成一个net_也是多层感知机,net的参数直接加载文件中的参数。

net_.eval()是将模型的模式改为评价模式。

Y_clone = net_(X)
print(Y_clone == Y)
复制代码
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
复制代码

将新网络赋值给Y_clone,可以看到Y_clone和Y是相同的。

当然换成pytorch自己的层也是可以的

    
MLP = nn.Sequential(nn.Linear(20,256),nn.Linear(256,10),nn.ReLU())

def init(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight)
        nn.init.zeros_(m.bias)


net = MLP
X = torch.randn(size=(2, 20))
Y = net(X)

torch.save(net.state_dict(), 'x-file')

net_ = MLP
net_.load_state_dict(torch.load('x-file'))
net_.eval()

Y_clone = net_(X)
Y_clone == Y
复制代码
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
复制代码

  1. 《动手学深度学习》系列更多可以看这里:《动手学深度学习》专栏(juejin.cn)

  2. 笔记Github地址:DeepLearningNotes/d2l(github.com)

还在更新中…………

おすすめ

転載: juejin.im/post/7032693546125803533