李沐老师-pytorch-读写文件

#文件读取

import torch
from torch import nn
from torch.nn import functional as F

x=torch.arange(4)
torch.save(x,'x-file')
#当前目录下新建文件
x2=torch.load("x-file")
print(x2)

输出:tensor([0, 1, 2, 3])

y=torch.zeros(4)
torch.save([x,y],"x-file")
x2,y2=torch.load("x-file")
print((x2,y2))

输出:(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

#写入或读取从字符串映射到张量的字典
mydict={'x':x,'y':y}
torch.save(mydict,"mydict")
mydict2=torch.load("mydict")
print(mydict2)

输出:{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

#加载和保存模型参数:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden=nn.Linear(20,256)
        self.output=nn.Linear(256,10)

    def forward(self,x):
        return self.output(F.relu(self.hidden(x)))

net=MLP()
X=torch.randn(size=(2,20))
Y=net(X)

torch.save(net.state_dict(),"mlp.params")
#实例化了原始多层感知机模型的一个备份.直接读取文件中存储的参数.
clone=MLP()
clone.load_state_dict(torch.load("mlp.params"))   #把存在磁盘上的参数写回网络
clone.eval()

Y_clone=clone(X)
print(Y_clone==Y)

输出:tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

猜你喜欢

转载自blog.csdn.net/qq_45828494/article/details/126611293
今日推荐