.pkl保存模型
有两种模型:一个是保存模型,一个是保存模型的所有参数
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1,1,100),dim = 1) #压缩为2维,因为torch 中 只会处理二维的数据
y = x.pow(2) + 0.2 * torch.rand(x.size())
x,y = Variable(x),Variable(y)# 神经网络中只用Variable的方法
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show() # 散点图
class Net(torch.nn.Module): # 继承 torch 的 Module
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__() # 继承 __init__ 功能
# 定义每层用什么样的形式
self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层线性输出
self.predict = torch.nn.Linear(n_hidden, n_output) # 输出层线性输出
def forward(self, x): # 这同时也是 Module 中的 forward 功能
# 正向传播输入值, 神经网络分析出输出值
x = F.relu(self.hidden(x)) # 激励函数(隐藏层的线性值)
x = self.predict(x) # 输出值
return x
# net = Net(n_feature=1, n_hidden=10, n_output=1)
def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(), # F 的relu 是一个function, torch.nn的ReLU是class,性质不同
torch.nn.Linear(10,1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.2) # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss() # 预测值和真实值的误差计算公式 (均方差)
for t in range(100):
prediction = net1(x) # 喂给 net 训练数据 x, 输出预测值
loss = loss_func(prediction, y) # 计算两者的误差
optimizer.zero_grad() # 清空上一步的残余更新参数值
loss.backward() # 误差反向传播, 计算参数更新值
optimizer.step() # 将参数更新值施加到 net 的 parameters 上
plt.subplot(1,3,1) # ??
plt.scatter(x.data.numpy(), y.data.numpy())
plt.title('Net1')
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', 'lw')
# plt.show()
torch.save(net1,'net.pkl')
torch.save(net1.state_dict(), 'net_params.pkl') # 保存的是所有的参数,更快速
def load_modul():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot(1,3,2) #??
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-','lw')
# plt.show() # 如果想同时显示,必须在最后一个图加show()
def load_modul2():
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(), # F 的relu 是一个function, torch.nn的ReLU是class,性质不同
torch.nn.Linear(10, 1)
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
plt.subplot(1,3,3) # ??
plt.scatter(x.data.numpy(), y.data.numpy())
plt.title('Net3')
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', 'lw')
plt.show()
save()
load_modul()
load_modul2()