莫烦pytorch(4)——regression

1.构造出散点图

import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

#x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)## x data (tensor), shape=(100, 1),这里的是把一维变成2维
x=torch.linspace(-1,1,100)[:,np.newaxis]#这句话与上句作用一致

y=x.pow(2)+0.3*torch.rand(x.shape)  #这里用x.size()也一样.torch.rand类似于np.random.rand()


plt.scatter(x,y)
plt.show()

在这里插入图片描述

2.构造regression类

class Net(torch.nn.Module):
    def __init__(self,n_features,n_hidden,n_output):
        super(Net,self).__init__()#继承Net
        #这几行都是固定的模式
        self.hidden=torch.nn.Linear(n_features,n_hidden)  #(n_features,n_hidden)
        self.predict=torch.nn.Linear(n_hidden,n_output)   #(n_hidden,n_output)

    def forward(self,x):
        x=F.relu(self.hidden(x))
        x=self.predict(x)
        return x

3.计算loss,画图

net=Net(1,10,1)
print(net)
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
loss_func=torch.nn.MSELoss()

fig=plt.figure()
plt.ion()
for t in range(200):
    prediction=net(x)
    loss=loss_func(prediction,y)        # 计算两者的误差
    optimizer.zero_grad()               # 清空上一步的残余更新参数值
    loss.backward()                     #BP
    optimizer.step()                    #更新参数
    if t%5==0:
        plt.cla()
        plt.scatter(x.data.numpy(),y.data.numpy())
        plt.plot(x.data.numpy(),prediction.data.numpy(),'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)

plt.ioff()  # 画图
plt.show()

这里的optimizer=torch.optim.SGD(net.parameters(),lr=0.2)是optim库中的stochanstic GradientDescent(随机梯度下降),后面会提到。loss_func=torch.nn.MSELoss()这是Mean Square Error(均方差,用最小二乘法求得)。

这里还需要解释一下plt.ion(),因为matplotlib库在block模式的时候只显示一张静态图,要显示动态的图必须要使用这句话转换到交互模式。
在交互模式下:

1、plt.plot(x)或plt.imshow(x)是直接显示图像

2、使用plt.ion()最好要用plt.ioff()关闭,如果在交互模式中,不使用plt.ioff()会让图片一闪而过

在阻塞模式下:

1、一张图必须关闭才能出现下一张。

2、plt.plot(x)plt.imshow(x)是不能直接出图像,需要plt.show()后才能显示图像

猜你喜欢

转载自blog.csdn.net/qq_42738654/article/details/87955060
今日推荐