学习pytorch(二)pytorch自带的工具实现简单的线性模型

在上一篇博客中,我是用手动的方式管理和更新权重的。在pytorch中,这些其实可以自动完成。下面分享下用pytorch构建简单模型并训练的学习收获。
有4个步骤。
1.获得数据集
2.构建模型(这里用pytorch自带的单元链接拼凑成一个模型。
3.构建损失计算器和权值优化器。损失计算器用来计算模型得到的预测值相对于真实值的损失。优化器用来调整权值,通过权值调整,使得模型能够逐渐实现我们的目的。损失计算器和权值优化器都是pytorch自带的。
4.开始训练。

import torch
x_data=torch.tensor([[1.0],[2.0],[3.0]])#3x1的矩阵。2阶张量。每一行就是一个样本,每个样本包含1个特征。
y_data=torch.tensor([[2.0],[4.0],[6.0]])

class LinearModel(torch.nn.Module):#建立自己的模型,这个模型继承自torch.nn中的模型类。这个类已经有了神经网络模型
    #的很多基本方法,我们只需要根据我们的需求修改这些方法。
    def __init__(self):
        super().__init__()#构造函数直接调用父类的构造函数。
        self.linear=torch.nn.Linear(1,1)#定义计算单元,这里用pytorch自带的线性运算单元Linear。创建这个对象
        #需要提供两个参数,第一个1表示输入中每个样本包含的特征值数量。这里是1个。
        # 第二个1表示输出中每个样本包含的特征值数量。这里是1个。
    def forward(self,x):#定义前馈如何进行。
        y_hat=self.linear(x)#直接进行线性前馈。linear这个对象由于定义了__call__方法。
        #让他变成了可调用对象。直接调用 这个对象,就会执行self.__call__()方法。在linear中,
        #self.__call__(x)的核心就是调用self.forward()。
        return y_hat
    #这就创建了一个非常简单的模型了。
model=LinearModel() #实例化刚刚创建的模型。
criterion = torch.nn.MSELoss(size_average=False)#这个是自带的MSE损失计算器。实例化它。size_average这个参数
#指示要不要对多个样本的损失和除以样本数。(也是返回的是和还是均值。false返回的是和)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)#这个是自带的优化器。能够对第一个参数中的tensor进行梯度调整。lr为学习
#率。model.parameters返回模型中所有需要计算梯度的tensor。

for cycle in range(1000):#开始训练,训练次数为1000轮。
    y_hat = model(x_data) #model这个对象也是个可调用对象,调用它的核心是调用forward方法。
    loss=criterion(y_hat,y_data) #计算损失。
    print(cycle,loss.item())# 打印损失,items方法可以且仅可以把单元素tensor的值。loss就是只有一个值的tensor。

    optimizer.zero_grad()# 把优化器管理的所有tensor的梯度清0。
    loss.backward() # 反馈计算各个要求得到梯度的节点的梯度。
    optimizer.step() #优化器开始工作,根据梯度调整其管理的各个tensor的值。

print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())

x_test= torch.Tensor([[4.0],[5.0]])#这是测试集。
y_test= model(x_test)#得到预测值。
print('y_hat=',y_test.detach().numpy().tolist())#以列表的方式打印预测值。y_test.datach()作用于.data相同,返回一个不计算
#的tensor。返回的tensor于原来的tensor同地址。.numpy()是不计算梯度tensor的一个方法,返回一个numpy对象。这个numpy于调用它的tensor
#的数据同地址。.tolist()把numpy转为python的列表。列表地址不相同。

运行结果
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/ld_long/article/details/113638887