PyTorch: saving and restoring a trained neural network

In PyTorch, a neural network is saved with:

torch.save(net, 'net.pkl')

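It is loaded back with:

net = torch.load('net.pkl', weights_only=False)

(On PyTorch 2.6 and later, torch.load defaults to weights_only=True, which cannot restore a whole network object, so weights_only=False has to be passed explicitly. Only do this for files you trust, because it unpickles arbitrary Python objects.)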

A network can be saved in two ways:

1. Save the entire network

torch.save(net, 'net.pkl')

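This way keeps the most complete information about the network (its structure as well as its parameters). The drawback is that it is slower to load. Also, because the whole model object is pickled, the class definitions it relies on must be importable when the file is loaded.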
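A minimal sketch of the whole-network round trip, written against an in-memory buffer (io.BytesIO) instead of net.pkl so that it leaves no file behind. It assumes PyTorch 1.13 or later, where the weights_only argument of torch.load exists; weights_only=False is passed explicitly because newer versions refuse to unpickle a full module by default.

```python
import io

import torch

# Same architecture as in the example further down
net = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

# Save the whole module object (structure + parameters) into a buffer
buffer = io.BytesIO()
torch.save(net, buffer)
buffer.seek(0)

# Load it back without rebuilding the network first.
# weights_only=False is required to unpickle a full nn.Module on PyTorch >= 2.6
restored = torch.load(buffer, weights_only=False)

x = torch.linspace(-1, 1, 5).unsqueeze(1)
with torch.no_grad():
    same = torch.equal(net(x), restored(x))
print(type(restored).__name__, same)  # Sequential True
```

The restored object is a full torch.nn.Sequential, so it can be called directly; with a real file, torch.load('net.pkl', weights_only=False) works the same way.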

2. Save only the network's state

torch.save(net.state_dict(), 'net_params.pkl')

This way keeps only the network's current state, i.e. its parameters and buffers as returned by net.state_dict(). It is faster to save and load, and the resulting .pkl file is smaller. The drawback is that, before loading, you have to build a network with the same structure yourself; otherwise the saved parameters cannot be restored.
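The sketch below illustrates the state_dict approach, again using in-memory buffers in place of net_params.pkl. It shows what the saved state contains, that the network must be rebuilt before load_state_dict is called, that a network with a different structure is rejected, and how the serialized sizes of the two approaches compare. The helpers build_net and serialized_size are hypothetical names introduced only for this illustration.

```python
import io

import torch


def build_net(hidden=10):
    # Hypothetical helper: rebuilds the architecture used in this article
    return torch.nn.Sequential(
        torch.nn.Linear(1, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, 1),
    )


def serialized_size(obj):
    # Hypothetical helper: number of bytes torch.save writes for obj
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return len(buffer.getvalue())


net = build_net()
state = net.state_dict()
print(list(state.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Save only the parameters
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# The network has to be rebuilt with the same structure before loading
restored = build_net()
restored.load_state_dict(torch.load(buffer))

x = torch.linspace(-1, 1, 5).unsqueeze(1)
with torch.no_grad():
    print(torch.equal(net(x), restored(x)))  # True

# A network with a different structure cannot accept these parameters
try:
    build_net(hidden=20).load_state_dict(state)
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
print(mismatch_rejected)  # True

# The state-only file is smaller than the whole-network file
print(serialized_size(state) < serialized_size(net))  # True
```

The ReLU layer (index 1) has no parameters, which is why only keys for layers 0 and 2 appear in the state dict.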

Example:

import torch
import matplotlib.pyplot as plt

# Run on the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Toy data: y = x^2 plus uniform noise, 100 points in [-1, 1]
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# Variable wrappers are unnecessary since PyTorch 0.4; plain tensors are used
x, y = x.to(device), y.to(device)

# Train the network and save it
def save():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    for t in range(300):
        prediction = net(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.scatter(x.cpu().numpy(), y.cpu().numpy())
    plt.plot(x.cpu().numpy(), prediction.detach().cpu().numpy(), 'r-', lw=5)
    # Save the entire network (structure and parameters)
    torch.save(net, 'net.pkl')
    # Save only the network's current state (parameters)
    torch.save(net.state_dict(), 'net_params.pkl')

# Restore the entire network
def restore_net():
    # weights_only=False is required on PyTorch >= 2.6 to unpickle a whole
    # network; only do this with files you trust
    net = torch.load('net.pkl', map_location=device, weights_only=False)
    with torch.no_grad():
        prediction = net(x)

    plt.figure(1, figsize=(10, 3))
    plt.subplot(132)
    plt.scatter(x.cpu().numpy(), y.cpu().numpy())
    plt.plot(x.cpu().numpy(), prediction.cpu().numpy(), 'r-', lw=5)

# Restore only the network state
def restore_params():
    # The same architecture must be rebuilt before the parameters can be loaded
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).to(device)
    net.load_state_dict(torch.load('net_params.pkl', map_location=device))
    with torch.no_grad():
        prediction = net(x)

    plt.figure(1, figsize=(10, 3))
    plt.subplot(133)
    plt.scatter(x.cpu().numpy(), y.cpu().numpy())
    plt.plot(x.cpu().numpy(), prediction.cpu().numpy(), 'r-', lw=5)

save()
restore_net()
restore_params()

plt.show()

[Figure not reproduced: three panels. The first shows the network as it was saved; the second and third show the networks restored by the two different loading methods. As can be seen, both loading methods give the same result.]
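Instead of comparing the plots by eye, the two restored networks can be compared numerically. This is a minimal CPU-only sketch under a few assumptions: in-memory buffers stand in for net.pkl and net_params.pkl, training is shorter (50 steps) with a smaller learning rate (0.2) than in the article, and build_net is a hypothetical helper. torch.equal checks that the outputs are bit-for-bit identical.

```python
import io

import torch

torch.manual_seed(0)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())


def build_net():
    # Hypothetical helper: the article's 1-10-1 architecture
    return torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )


# Train briefly so the weights differ from their initial values
net = build_net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()
for _ in range(50):
    loss = loss_func(net(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Method 1: save and restore the whole network
whole = io.BytesIO()
torch.save(net, whole)
whole.seek(0)
net2 = torch.load(whole, weights_only=False)

# Method 2: save the state only, then load it into a freshly built network
params = io.BytesIO()
torch.save(net.state_dict(), params)
params.seek(0)
net3 = build_net()
net3.load_state_dict(torch.load(params))

with torch.no_grad():
    p1, p2, p3 = net(x), net2(x), net3(x)
print(torch.equal(p1, p2), torch.equal(p2, p3))  # True True
```

Both methods restore exactly the same parameter values, so on the same device the predictions match exactly, not just approximately.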

Source: blog.csdn.net/baishuiniyaonulia/article/details/100039845