[PyTorch practice] Getting started with PyTorch from scratch: linear regression [line-by-line code explanation]

Foreword

This section walks through a small example for getting started with PyTorch.

A line-by-line video explanation of the code is available at: https://www.bilibili.com/video/BV1nS4y1u76S?spm_id_from=333.999.0.0

It focuses on a linear regression example:

  • Create your own data

  • Build a linear regression model

  • Complete the training process

  • Plot and display the results

Linear model

y = kx + b

where k is the weight and b is the bias term.

In general, fitting a linear model means fitting k and b, which correspond to the model's weight w and bias.

For example:

[Figure: example of a linear fit]

The code as a whole

Simulation data

Generate 512 points for x (i.e., 512 samples) and add Gaussian white noise (random numbers drawn from a normal distribution with mean 0 and variance 1).

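
The original post shows this step as a screenshot; below is a minimal sketch of the data generation, where the true slope of 10 and intercept of 5 are illustrative assumptions rather than values taken from the original:

import torch

# 512 sample points in [0, 1); each sample has a single feature
x = torch.rand(512)
# Gaussian white noise: normal distribution with mean 0 and variance 1
noise = torch.randn(512)
# assumed ground-truth line (slope and intercept are placeholders)
y = 10 * x + 5 + noise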

Linear model

Since each input value and each output value has dimension 1 (feature_num = 1), the linear model is:

import torch.nn as nn

class LinearModel(nn.Module):
    def __init__(self, in_fea, out_fea):
        super(LinearModel, self).__init__()
        # a single fully connected layer: y = w * x + b
        self.out = nn.Linear(in_fea, out_fea)

    def forward(self, x):
        x = self.out(x)
        return x
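
For reference, instantiating the model with one input feature and one output feature, as described above, might look like this:

model = LinearModel(1, 1)   # in_fea = 1, out_fea = 1
print(model)                # prints the single nn.Linear(1, 1) layer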

Define loss function and optimizer

optimizer = torch.optim.SGD(model.parameters(), lr=0.02)

loss_func = nn.MSELoss()
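
As a quick illustration of what nn.MSELoss computes (the mean of the squared differences between prediction and target), here is a tiny hand-checkable example, assuming torch and nn are imported as above:

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])
print(nn.MSELoss()(pred, target))   # (0**2 + 0**2 + 2**2) / 3 = 1.3333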

Reshape the data to match the model's input dimensions

Add one dimension for the feature axis, so each tensor goes from shape (512,) to (512, 1).
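
A minimal sketch of that reshaping step, assuming the x and y tensors generated earlier:

# nn.Linear expects input of shape (batch_size, in_features),
# so add a feature dimension: (512,) -> (512, 1)
x = x.unsqueeze(1)
y = y.unsqueeze(1)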

Training and visualization

The routine is:

  1. Forward pass

  2. Compute the loss

  3. Clear the gradients

  4. Backpropagation

  5. Update the weights

import matplotlib.pyplot as plt

plt.ion()  # turn on interactive mode so the plot can be updated in place
for step in range(200):
    prediction = model(x)            # 1. forward pass
    loss = loss_func(prediction, y)  # 2. compute the loss
    optimizer.zero_grad()            # 3. clear the gradients
    loss.backward()                  # 4. backpropagation
    optimizer.step()                 # 5. update the weights

    if step % 10 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.xlim(0, 1.1)
        plt.ylim(0, 20)
        [w, b] = model.parameters()
        plt.text(0, 0.5, 'loss=%.4f, k=%.2f, b=%.2f' % (loss.item(), w.item(), b.item()),
                 fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.5)
plt.ioff()
plt.show()
plt.show()

Here, plt.ion() turns on interactive mode and plt.ioff() turns it off, so the plot can be redrawn dynamically to watch the fit change.

Watch how the fitted line changes dynamically:

[Figures: snapshots of the fitted line at different training steps]

