PyTorch basics - building a network

The steps for building a network are essentially the following:

1. Prepare the data
2. Define the network structure (the model)
3. Define the loss function
4. Define the optimizer
5. Train
  5.1 Get the input data and labels ready in tensor form (optional)
  5.2 Forward propagation: compute the network output, then use the loss function to compute the loss
  5.3 Backpropagation and parameter update; none of the following three steps can be omitted (they are distilled into the skeleton after this list):
    5.3.1 optimizer.zero_grad(): clear the gradient values left over from the previous iteration to 0
    5.3.2 loss.backward(): backpropagate and compute the gradient values
    5.3.3 optimizer.step(): update the weight parameters
  5.4 Record the loss and accuracy on the training and validation sets and print training information (optional)
6. Visualize how the loss and accuracy change during training (optional)
7. Evaluate on the test set
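Boiled down to code, steps 5.1-5.3 are the following loop. This is only a minimal skeleton: net, loss_func, optimizer, x, and y are set up in the full example below, and num_steps is a placeholder.

for t in range(num_steps):             # 5. the training loop
    prediction = net(x)                # 5.2 forward pass: network output
    loss = loss_func(prediction, y)    # 5.2 compute the loss
    optimizer.zero_grad()              # 5.3.1 clear the previous gradients
    loss.backward()                    # 5.3.2 backpropagate, compute gradients
    optimizer.step()                   # 5.3.3 update the weight parameters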

 

The code below is commented in detail:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# 1. Prepare the data: generate the dataset
# torch.linspace(-1, 1, 100) has shape [100]; unsqueeze(..., dim=1) turns it
# into shape [100, 1], since nn.Linear expects inputs of shape [N, n_feature]
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
print(x.shape)
# y = x^2 plus some uniform noise
y = x * x + 0.2 * torch.rand(x.size())
# show the data as a scatter plot
plt.scatter(x.data.numpy(), y.data.numpy())
# 2. Define the network structure: build the net
class Net(torch.nn.Module):
    # n_feature: number of input features
    # n_hidden: number of hidden-layer units
    # n_output: number of output units
    def __init__(self, n_feature, n_hidden, n_output):
        # super(Net, self).__init__() runs the parent class's initialization
        super(Net, self).__init__()
        # nn.Linear is a linear layer computing y = w^T * x + b, where w has
        # shape [n_hidden, n_feature] and b has shape [n_hidden]
        # (w is stored transposed, which is why its dimensions look reversed)
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
        print(self.hidden.weight)
        print(self.predict.weight)

    # define the forward propagation
    def forward(self, x):
        # e.g. with (n_feature, n_hidden, n_output) = (2, 5, 1):
        # input layer (2) -> hidden layer (5) -> output layer (1)
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

# instantiate a Net
net = Net(n_feature=1, n_hidden=10, n_output=1)
print(net)
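As an aside, the same two-layer structure can also be written more compactly with torch.nn.Sequential. This is an equivalent sketch, not part of the original example; nn.ReLU() plays the role of the F.relu call in Net.forward().

# equivalent structure built with torch.nn.Sequential (sketch, not in the
# original post); layers are named by index when printed
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)
print(net2)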
# 3. Define the loss function; here mean square error (MSE),
# i.e. loss = mean((prediction - target)^2)
loss_func = torch.nn.MSELoss()
# 4. Define the optimizer; here stochastic gradient descent (SGD)
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
# interactive plotting: the display is refreshed every 10 of the 100 steps below
plt.ion()
# 5. Train
for t in range(100):
    prediction = net(x)                 # forward pass: predict y based on input x
    loss = loss_func(prediction, y)     # arguments must be (1. nn output, 2. target)

    # 5.3 the three backpropagation steps, none of which can be omitted
    optimizer.zero_grad()   # clear gradients from the previous step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients

    if t % 10 == 0:
        # plot and show the learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.show()
        plt.pause(0.1)

plt.ioff()
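Steps 6 and 7 from the list at the top (visualizing the loss curve and evaluating on a test set) are not covered by the code above. Below is a minimal sketch of both; the names losses, x_test, and y_test, and the choice of 50 test points, are hypothetical additions, not from the original post.

# 6. to plot the loss curve, append loss.item() to a list (here called
# `losses`, a hypothetical addition) at each step of the training loop,
# then after training:
#     plt.plot(losses)
#     plt.xlabel('step'); plt.ylabel('loss')
#     plt.show()

# 7. evaluate on a held-out test set (hypothetical data, generated the
# same way as the training data)
x_test = torch.unsqueeze(torch.linspace(-1, 1, 50), dim=1)
y_test = x_test * x_test + 0.2 * torch.rand(x_test.size())
with torch.no_grad():                   # no gradients needed at test time
    test_loss = loss_func(net(x_test), y_test)
print('test loss: %.4f' % test_loss.item())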

Reference: MorvanZhou (莫烦Python)

Original post: www.cnblogs.com/bob-jianfeng/p/11407955.html