Regression 关系拟合 (回归)

1.神经网络主要分为回归跟分类,回归也就是说输出值是连续的,而分类的输出值是离散的。接下来就看看如何用神经网络的代码来实现回归。

2.代码

# 导入一些相应的包
#解决python2 和 python 3之间一些输出格式的不同,一律使用python 3
from __future__ import print_function 
import torch
from torch.autograd import Variable
#一些相应的激励函数
import torch.nn.functional as F
#用来画图
import matplotlib.pyplot as plt

#fake data
#创建数据集
#unsqueeze 的作用是给数据加上维度,因为linspace产生的数据是一维的
#然而神经网络之中需要二维的,所以使用unsqueese增加维度。[]-->[[]]
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
#为了使得数据更加真实,使用rand随机函数增加噪点
y=x.pow(2)+0.2*torch.rand(x.size())
#千万注意数据创建完后要放入Variable中才能进行反向传播
x=Variable(x)
y=Variable(y)

#搭建神经网络
#神经网络的两个基本模块,基本是固定的套路
#使用类的方式创建网络,继承torch.nn.Module
#__init__进行初始化,定义网络的一些属性
#forward 定义各个层次之间的关系
class Net(torch.nn.Module):
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden=torch.nn.Linear(n_feature,n_hidden)
        self.predict=torch.nn.Linear(n_hidden,n_output)

    def forward(self,x):
        #结果使用激活函数进行激活
        x=F.relu(self.hidden(x))
        #最后的结果不再需要使用激活函数激活,因为y的结果可能是整个实数集,
        #如果使用激活函数则结果可能被限制在某个区域。
        x=self.predict(x)
        return x
net=Net(1,10,1)
# net=Net(n_feature=1,n_hidden=10,n_output=1)
# print(net)两种方式都可以输出网络的结构

#优化函数,这里我们使用随机梯度下降
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
#使用均方差作为损失函数
loss_func=torch.nn.MSELoss()


#训练神经网络并进行可视化
#ion() 用来画动态图
plt.ion()

for t in range(100):
    prediction=net(x)
    loss=loss_func(prediction,y)

    #对梯度进行清零
    optimizer.zero_grad()
    #误差反向传播,进行参数的更新
    loss.backward()
    #将参数更新值添加到net中的parameters上
    optimizer.step()
    if t % 5 == 0:
        #清楚轴线
        plt.cla()
        #画出散点图
        plt.scatter(x.data.numpy(), y.data.numpy())
        #画出生成的拟合曲线
        plt.plot(x.data.numpy(), prediction.data.numpy(), c='red', lw=5)
        # plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
        #暂停0.1s
        plt.pause(0.1)
#调用show()之前,需要使用ioff
#如果在脚本中使用ion()命令开启了交互模式,没有使用ioff()关闭的话,则图像会一闪而过,并不会常留。#要想防止这种情况,需要在plt.show()之前加上ioff()命令。
plt.ioff()
plt.show()

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 

猜你喜欢

转载自blog.csdn.net/xs_211314/article/details/82413767