参与11月更文挑战的第21天,活动详情查看:2021最后一次更文挑战
读写文件有什么必要呢?
读写文件其实不是读取数据集。
是当你的训练时要定期存储中间结果,以确保在服务器电源不小心被断掉,或者出现其他情况的时候,损失掉你前几天的计算结果。
这一节要做的就是如何存储权重向量和整个模型。
import torch
from torch import nn
from torch.nn import functional as F
复制代码
load
和save
对于单个张量,我们可以直接调用load
和save
函数分别读写它们。
-
torch.saves
torch.save(obj, f, pickle_module=<module 'pickle' from '.../pickle.py'>, pickle_protocol=2) 复制代码
参数:
- obj – 保存对象
- f - 字符串,文件名
- pickle_module – 用于pickling元数据和对象的模块
- pickle_protocol – 指定pickle protocal 可以覆盖默认参数
-
torch.load
torch.load(f, map_location=None, pickle_module=<module 'pickle' from '.../pickle.py'>) 复制代码
从磁盘文件中读取一个通过
torch.save()
保存的对象。参数:
- f – 字符串,文件名
- map_location – 一个函数或字典规定如何remap存储位置
- pickle_module – 用于unpickling元数据和对象的模块 (必须匹配序列化文件时的pickle_module )
x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
print(x2)
复制代码
初始化一个x
将x存储到当前文件夹下并命名为x-file
,此时你会发现当前文件夹下边多出来一个同名的文件。 当然打开之后可能不是 0 1 2 3 ,因为编码方式不同,所以不用纠结打开以后看到的是什么。
声明一个x2再从文件中读回来,会发现结果就是tensor([0, 1, 2, 3]),结果没错就可以了。
y = torch.zeros(4)
torch.save([x, y],'x-file')
x2, y2 = torch.load('x-file')
print(x2, y2)
复制代码
>>
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
复制代码
存储一个张量列表,然后把它们读回内存。
y = torch.zeros(4)
torch.save(y[:2],'x-file')
x2, y2 = torch.load('x-file')
(x2, y2)
print(x2, y2)
复制代码
切片也是可以的。
mydict = {'x': x, 'y': y}
torch.save(mydict, 'x-file')
mydict2 = torch.load('x-file')
mydict2
复制代码
>>
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
复制代码
存储字典也可以。
加载和保存模型参数
深度学习框架提供了内置函数来保存和加载整个网络。
但是是保存模型的参数而不是保存整个模型。
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256)
self.output = nn.Linear(256, 10)
def forward(self, x):
return self.output(F.relu(self.hidden(x)))
复制代码
还记得我们手写的多层感知机吗,用这个实现一下子。
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
复制代码
现在生成一个net,用它计算X,并将其赋值给Y。
torch.save(net.state_dict(), 'x-file')
复制代码
将net的参数保存起来。
net_ = MLP()
net_.load_state_dict(torch.load('x-file'))
net_.eval()
复制代码
生成一个net_也是多层感知机,net的参数直接加载文件中的参数。
net_.eval()
是将模型的模式改为评价模式。
Y_clone = net_(X)
print(Y_clone == Y)
复制代码
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
复制代码
将新网络赋值给Y_clone,可以看到Y_clone和Y是相同的。
当然换成pytorch自己的层也是可以的
MLP = nn.Sequential(nn.Linear(20,256),nn.Linear(256,10),nn.ReLU())
def init(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight)
nn.init.zeros_(m.bias)
net = MLP
X = torch.randn(size=(2, 20))
Y = net(X)
torch.save(net.state_dict(), 'x-file')
net_ = MLP
net_.load_state_dict(torch.load('x-file'))
net_.eval()
Y_clone = net_(X)
Y_clone == Y
复制代码
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
复制代码
-
《动手学深度学习》系列更多可以看这里:《动手学深度学习》专栏(juejin.cn)
-
笔记Github地址:DeepLearningNotes/d2l(github.com)
还在更新中…………