Pytorch学习(五)--逻辑回归

教程视频地址:https://www.bilibili.com/video/BV15E411W7zY

背景
之前我们学习了如何建立、训练一个线性模型,通过数据集(1,2)(2,4)(3,6)让模型知道我们要构建的是y=2x+0这样一个函数。于是我们最后我们输入x=4, 预测y=8,训练成功。
但往往问题并没有这么简单。比如我们要在MNIST数据集上识别数字,如下图。我们发现7和9的形状十分接近,那么如何来进行区分,它到底是7还是9呢?我们就需要判断这个数字它是7的可能性更大,还是它是9的可能性更大。这个时候你也发现了,我们这是要输出的,并不是某一个确定的数字(比如x=4,y=8),而是他是某个数字的概率(p(y=7)=0.1,p(y=9)=0.9),然后选择一个概率最大的数字为它进行归类。
在这里插入图片描述
如何实现逻辑回归?
逻辑回归(Logistic Regression)是一种用于解决二分类(0 or 1)问题的机器学习方法,用于估计某种事物的可能性。那么他如何实现呢?
还用我们之前的y=wx+b的模型来举例,之前的输出y属于实数集合R,现在我们要输出一个一个概率,也就是在区间[0,1]之间。自然我们就想到需要找出一个映射,把我们之前的输出集合R映射到区间[0,1],他就是函数Sigma,图像和函数在下图,这样我们就轻松的实现了实数集合到0~1之间的映射。
在这里插入图片描述
同时损失函数也要做修改,如下图
在这里插入图片描述
pytorch代码实现:

import  torch
import  torch.nn.functional as F
import  numpy as np
import matplotlib.pyplot as plt

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0],[0],[1]])

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
    def forward(self, x):
        y_pred = F.sigmoid(self.linear(x))#这里需要把原来的输出y传给sigmoid,即实现的区间的映射
        return  y_pred

model = LinearModel()

criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

x = np.linspace(0,10,200)
x_t = torch.Tensor(x).view(200,1)#将数据变成一个二百行一列的矩阵
y_t = model(x_t)
y = y_t.data.numpy()

plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r')
plt.ylabel('probablility of pass')
plt.xlabel('hours')
plt.grid()#画出网格
plt.show()

结果:
0 4.10068941116333
1 4.011950969696045
2 3.9253454208374023
3 3.8409411907196045
4 3.7588043212890625
5 3.6789965629577637
6 3.601571559906006
7 3.526580810546875
8 3.4540674686431885
9 3.384068489074707
10 3.3166117668151855
……
990 1.0986469984054565
991 1.098111629486084
992 1.09757661819458
993 1.0970423221588135
994 1.0965087413787842
995 1.0959757566452026
996 1.0954433679580688
997 1.0949115753173828
998 1.0943803787231445
999 1.0938501358032227

训练1000次后的预测图:
在这里插入图片描述

发布了10 篇原创文章 · 获赞 0 · 访问量 129

猜你喜欢

转载自blog.csdn.net/weixin_44841652/article/details/105088757