1. Quick build method
In the previous two articles, we built the neural network in this way:
import torch
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))   # activation function applied in forward()
        x = self.predict(x)
        return x

net1 = Net(1, 10, 1)   # This is the net1 we built this way
Now there is another way:
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1)
)
Comparing the two:
print(net1)
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
print(net2)
"""
Sequential (
  (0): Linear (1 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 1)
)
"""
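To make the comparison concrete, here is a small sketch (assuming a recent PyTorch version) that copies net1's weights into net2 and checks that both networks then produce the same outputs. The Net class is repeated so the block runs on its own; the mapping of `hidden` to index 0 and `predict` to index 2 follows from the two printed structures.

```python
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    """Same architecture as net1: named layers, activation applied in forward()."""

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)


net1 = Net(1, 10, 1)
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

# Copy net1's parameters into net2: named layers map onto numbered ones.
# The ReLU at index 1 has no parameters, so it needs no entry.
with torch.no_grad():
    net2[0].weight.copy_(net1.hidden.weight)
    net2[0].bias.copy_(net1.hidden.bias)
    net2[2].weight.copy_(net1.predict.weight)
    net2[2].bias.copy_(net1.predict.bias)

x = torch.linspace(-1, 1, 5).unsqueeze(1)  # shape (5, 1)
same = torch.allclose(net1(x), net2(x))
print(same)  # True
```

With identical weights the two forms compute the same function; they differ only in how layers are named and where the activation lives.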
(1) net2's layers are numbered rather than named; net1's layer names come from the attributes defined in __init__.
(2) net2 includes the activation function as a layer of its own, whereas net1 calls the activation function inside forward(). Writing forward() yourself, as in net1, lets you customize the forward pass however you need (for example, for an RNN). If you don't need that kind of extra customization, the net2 form is probably the better fit.
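Numbered layers are only the default. A minimal sketch showing that `torch.nn.Sequential` also accepts an `OrderedDict`, which keeps the quick-build style while giving each layer a name; the names `hidden`, `relu` and `predict` are hypothetical choices for illustration.

```python
from collections import OrderedDict

import torch

# Passing an OrderedDict names each layer instead of numbering it.
net2_named = torch.nn.Sequential(OrderedDict([
    ("hidden", torch.nn.Linear(1, 10)),
    ("relu", torch.nn.ReLU()),
    ("predict", torch.nn.Linear(10, 1)),
]))

print(net2_named)
# Layers are reachable by attribute name as well as by index.
print(net2_named.hidden is net2_named[0])  # True

param_names = [name for name, _ in net2_named.named_parameters()]
print(param_names)  # ['hidden.weight', 'hidden.bias', 'predict.weight', 'predict.bias']
```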
2. Save and restore:
2.1 Build and save:
import torch

torch.manual_seed(1)    # reproducible

# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())                 # noisy y data (tensor), shape=(100, 1)
# Older PyTorch wrapped x and y in torch.autograd.Variable here; since
# PyTorch 0.4, plain tensors can be passed to the network directly.

def save():
    # build network
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    # Training
    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(net1, 'net.pkl')                      # save the entire network
    torch.save(net1.state_dict(), 'net_params.pkl')  # save only the parameters (faster, takes less space)
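To see what the second `torch.save` call actually stores, here is a quick sketch that prints the state dictionary of the same three-layer Sequential network. It is an ordered mapping from parameter names to tensors and does not record the network's class, which is why a matching network has to be built before the parameters can be loaded back.

```python
import torch

net1 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

state = net1.state_dict()
# Keys are "<layer index>.<parameter name>"; the ReLU (index 1) has no parameters.
for key, tensor in state.items():
    print(key, tuple(tensor.shape))
# 0.weight (10, 1)
# 0.bias (10,)
# 2.weight (1, 10)
# 2.bias (1,)
```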
2.2 Restore the entire network
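def restore_net():
    # restore the entire net1 into net2
    # weights_only=False is needed on PyTorch 2.6+, where torch.load defaults to
    # weights_only=True and refuses to unpickle a whole network object
    net2 = torch.load('net.pkl', weights_only=False)
    prediction = net2(x)   # run x through the restored net2
    print(net2)
    print(prediction)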
...
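A minimal sketch of the same whole-network round trip, assuming a recent PyTorch release (2.6 changed the `torch.load` default to `weights_only=True`, which will not unpickle a full module). It writes to an `io.BytesIO` buffer instead of `net.pkl` so it runs without touching the disk.

```python
import io

import torch

net1 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

# Save the whole module into an in-memory buffer instead of 'net.pkl'.
buffer = io.BytesIO()
torch.save(net1, buffer)
buffer.seek(0)

# Unpickling a whole module needs weights_only=False on recent releases.
# Only do this with files you trust: full unpickling can run arbitrary code.
net2 = torch.load(buffer, weights_only=False)

x = torch.linspace(-1, 1, 100).unsqueeze(1)
print(type(net2).__name__)            # Sequential
print(torch.equal(net1(x), net2(x)))  # True
```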
2.3 Restore only the network parameters
This approach requires building a new network with the same architecture first and then copying the saved parameters into it. It is faster than restoring the entire network as above.
def restore_params():
    # build a new net3 with the same architecture as net1
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    # copy the saved parameters into net3
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction1 = net3(x)
    print(net3)
    print(prediction1)
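Because `load_state_dict` copies tensors by name and shape, the new network has to match the saved one layer for layer. A small sketch of what happens otherwise; `build` is a hypothetical helper introduced here, not part of the original code.

```python
import torch


def build(n_hidden):
    """Hypothetical helper: the same three-layer shape with a configurable hidden size."""
    return torch.nn.Sequential(
        torch.nn.Linear(1, n_hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(n_hidden, 1),
    )


saved = build(10).state_dict()

# Matching architecture: the parameters copy over cleanly.
net3 = build(10)
net3.load_state_dict(saved)

# Different hidden size: the tensor shapes no longer match, so loading fails.
try:
    build(20).load_state_dict(saved)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
print(mismatch_error is not None)  # True
```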
2.4 Displaying the results
# save net1 (1. the whole network, 2. only its parameters)
save()

# restore the entire network
restore_net()

# restore only the parameters, copying them into a new network
restore_params()
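To check that both restore paths reproduce the trained network's predictions, here is a sketch that repeats the three steps in one script and compares the outputs. It writes to a temporary directory instead of the working directory, and `make_net` is a hypothetical helper; `weights_only=False` for the whole-network file assumes a recent PyTorch release.

```python
import os
import tempfile

import torch

torch.manual_seed(1)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())


def make_net():
    """Hypothetical helper building the same three-layer network."""
    return torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )


with tempfile.TemporaryDirectory() as tmp:
    net_path = os.path.join(tmp, "net.pkl")
    params_path = os.path.join(tmp, "net_params.pkl")

    # Train briefly, then save both ways (mirrors save()).
    net1 = make_net()
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    for _ in range(100):
        loss = loss_func(net1(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    torch.save(net1, net_path)
    torch.save(net1.state_dict(), params_path)

    # Restore both ways (mirrors restore_net() and restore_params()).
    net2 = torch.load(net_path, weights_only=False)
    net3 = make_net()
    net3.load_state_dict(torch.load(params_path))

    with torch.no_grad():
        p1, p2, p3 = net1(x), net2(x), net3(x)

print(torch.equal(p1, p2), torch.equal(p1, p3))  # True True
```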
3. Results:
Note that Python is whitespace-sensitive, so keep the indentation of the code exactly as shown when copying it.
Reference links:
https://morvanzhou.github.io/tutorials/machine-learning/torch/3-04-save-reload/
https://morvanzhou.github.io/tutorials/machine-learning/torch/3-03-fast-nn/