PyTorch之 save_reload_model

 1 import torch
 2 import matplotlib.pyplot as plt
 3 
 4 # torch.manual_seed(1)    # reproducible
 5 
 6 # fake data
 7 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
 8 y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 9 
10 # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
11 # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
12 
13 
def save():
    """Train a small regression net on the module-level ``x``/``y`` and save it.

    ``net1`` is a 1 -> 10 -> ReLU -> 1 ``Sequential`` model trained for 100
    SGD steps (lr=0.5) on MSE loss. Its fit is drawn into the first of three
    subplots of figure 1, and the model is written to disk twice:

    * ``net.pkl``        -- the entire module object.
    * ``net_params.pkl`` -- only its ``state_dict`` (the parameter tensors).
    """
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for _ in range(100):
        loss = loss_func(net1(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate AFTER the last optimizer step so the plotted curve reflects the
    # exact weights saved below. A prediction taken inside the loop is computed
    # before that step's update and would be one step stale.
    with torch.no_grad():
        prediction = net1(x)

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)

    # Two ways to save the net:
    torch.save(net1, 'net.pkl')                      # the entire module
    torch.save(net1.state_dict(), 'net_params.pkl')  # parameters only


def restore_net():
    """Load the whole pickled model from ``net.pkl`` as ``net2`` and plot its fit.

    Draws into the second of three subplots (``plt.subplot(132)``) of the
    current figure.

    Security: this unpickles an entire ``nn.Module``, which can execute
    arbitrary code. Only load ``net.pkl`` files you produced yourself.
    """
    import inspect  # local: only needed for the torch.load feature check

    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
    # unpickle a full nn.Module. Opt out explicitly for this trusted file.
    # Older releases (< 1.13) have no such parameter and always unpickle fully.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    net2 = torch.load('net.pkl', **load_kwargs)

    with torch.no_grad():
        prediction = net2(x)

    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)


def restore_params():
    """Rebuild the net architecture as ``net3``, load saved parameters, plot and show.

    Only the ``state_dict`` stored in ``net_params.pkl`` is loaded. The layer
    stack built here therefore has to match the one the parameters were saved
    from. Draws the third subplot and then displays the figure.
    """
    layers = [
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ]
    net3 = torch.nn.Sequential(*layers)

    # Overwrite the freshly initialised weights with the saved ones.
    saved_state = torch.load('net_params.pkl')
    net3.load_state_dict(saved_state)

    fitted = net3(x)
    x_np = x.data.numpy()

    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x_np, y.data.numpy())
    plt.plot(x_np, fitted.data.numpy(), 'r-', lw=5)
    plt.show()

# Run the demo in order: train and save net1, then reload it two ways.
for demo_step in (
    save,            # train net1 and write net.pkl / net_params.pkl
    restore_net,     # reload the entire pickled module (may be slower)
    restore_params,  # rebuild the architecture and load only the parameters
):
    demo_step()

猜你喜欢

转载自www.cnblogs.com/dhName/p/11742959.html
今日推荐