神经网络的保存，以及提取神经网络的两种方式

"""
torch: 0.4
matplotlib
Saving a neural network
Two ways to restore a saved neural network
"""
import matplotlib.pyplot as plt
import torch


# Synthetic regression data: a noisy parabola sampled on [-1, 1].
# Since PyTorch 0.4 autograd works on tensors directly, so no Variable wrapper is needed.
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # inputs, shape (100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # targets: x^2 plus uniform noise in [0, 0.2)
18 
def save():
    """Train a small regression net, plot its fit, and save it in two ways.

    Builds a 1-10-1 fully connected network with a ReLU hidden layer and fits
    it to the module-level tensors ``x`` and ``y`` with 100 full-batch SGD
    steps (lr=0.5, MSE loss). The fit is drawn in the first panel of a 1x3
    grid on figure 1, and two files are written to the current working
    directory:

    * ``net.pkl``        -- the whole pickled module
    * ``net_params.pkl`` -- only the module's ``state_dict``
    """
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    # Training loop.
    for _ in range(100):
        loss = loss_func(net1(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate once more after the final step so the plot shows the output of
    # exactly the weights that get saved below (the in-loop forward pass
    # predates the last parameter update).
    with torch.no_grad():
        prediction = net1(x)

    # Plot the training data and the fitted curve.
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)

    # Two ways to save the net.
    torch.save(net1, 'net.pkl')  # entire module
    torch.save(net1.state_dict(), 'net_params.pkl')  # parameters only
47 
48 
def restore_net():
    """Load the whole pickled network from ``net.pkl`` and plot its fit.

    Reads ``net.pkl`` from the current working directory (written by
    ``save()``), evaluates the restored module on the module-level tensor
    ``x``, and draws the result against ``x``/``y`` in the middle panel of a
    1x3 subplot grid on the current matplotlib figure.
    """
    # net.pkl holds a pickled nn.Module, not just tensors. From PyTorch 2.6
    # torch.load defaults to weights_only=True and rejects such files, so full
    # unpickling has to be requested explicitly.
    # NOTE: full unpickling can execute arbitrary code -- only load files you trust.
    try:
        net2 = torch.load('net.pkl', weights_only=False)
    except TypeError:
        # Older PyTorch releases (including 0.4) have no ``weights_only`` argument.
        net2 = torch.load('net.pkl')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = net2(x)

    # Plot the data and the restored network's curve.
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
59 
60 
def restore_params():
    """Rebuild the network, load its parameters from ``net_params.pkl``, and plot.

    Constructs a fresh 1-10-1 network with the same layer layout used in
    ``save()``, copies in the ``state_dict`` stored in ``net_params.pkl``
    (current working directory), evaluates it on the module-level tensor
    ``x``, draws the result in the last panel of a 1x3 subplot grid on the
    current matplotlib figure, and finally calls ``plt.show()``.
    """
    # The architecture must match the one whose state_dict was saved, because
    # load_state_dict matches parameters by name and shape.
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )

    # Copy the saved parameters into net3.
    net3.load_state_dict(torch.load('net_params.pkl'))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = net3(x)

    # Plot the data and the restored network's curve, then display the figure.
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
    plt.show()
79 
# Train net1 and write both checkpoint files.
save()

# Restore the entire pickled network (may be slower).
restore_net()

# Restore only the parameters into a freshly built network, then show the figure.
restore_params()

猜你喜欢

转载自www.cnblogs.com/xuechengmeigui/p/12388514.html