Pytorch 保存和提取训练好的神经网络

在pytorch中,保存神经网络用方法:

torch.save(net, 'net.pkl')

提取神经网络用方法:

torch.load('net.pkl')

保存神经网络有两种方式:

1、保存整个网络

torch.save(net, 'net.pkl')

这种方法能最大程度的保留网络的所有信息,缺点是读取网络时速度稍慢

2、保存网络的状态信息

torch.save(net.state_dict(), 'net_params.pkl')

这种方法只保留网络当前的状态信息,保存和读取速度快,保存的pkl文件体积小,缺点是在读取网络时需要自行先构建网络,否则无法还原信息

示例:

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt

x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

x, y = Variable(x).cuda(), Variable(y).cuda()

# 保存网络
def save():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    for t in range(300):
        prediction = net(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    plt.figure(1, figsize=(10,3))
    plt.subplot(131)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
    # 保存整个网络
    torch.save(net, 'net.pkl')
    # 保存网络当前的状态
    torch.save(net.state_dict(), 'net_params.pkl')

# 提取整个网络
def restore_net():
    net = torch.load('net.pkl').cuda()
    prediction = net(x)

    plt.figure(1, figsize=(10, 3))
    plt.subplot(132)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)

# 提取网络状态
def restore_params():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).cuda()
    net.load_state_dict(torch.load('net_params.pkl'))
    prediction = net(x)

    plt.figure(1, figsize=(10, 3))
    plt.subplot(133)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)

save()
restore_net()
restore_params()

plt.show()

在这里插入图片描述图一为保存的神经网络,图二、三分别为用不同方法提取的神经网络,可以看到,两种提取方式的结果是一致的

发布了208 篇原创文章 · 获赞 841 · 访问量 121万+

猜你喜欢

转载自blog.csdn.net/baishuiniyaonulia/article/details/100039845