干货|10分钟入门PyTorch(2)~附源码

10分钟入门PyTorch(2)

上一节介绍了简单的线性回归10分钟快速入门PyTorch(1),如何在pytorch里面用最小二乘来拟合一些离散的点,这一节我们将开始简单的logistic回归,介绍图像分类问题,使用的数据是手写字体数据集MNIST。

1

logistic回归

logistic回归简单来说和线性回归是一样的,要做的运算同样是 y = w * x + b。
logistic回归简单的是做二分类问题,使用sigmoid函数将所有的正数和负数都变成0-1之间的数,这样就可以用这个数来确定到底属于哪一类,可以简单的认为概率大于0.5即为第二类,小于0.5为第一类。
干货|10分钟入门PyTorch(2)~附源码

这就是sigmoid的图形
干货|10分钟入门PyTorch(2)~附源码

而我们这里要做的是多分类问题,对于每一个数据,我们输出的维数是分类的总数,比如10分类,我们输出的就是一个10维的向量,然后我们使用另外一个激活函数,softmax
干货|10分钟入门PyTorch(2)~附源码
这就是softmax函数作用的机制,其实简单的理解就是确定这10个数每个数对应的概率有多大,因为这10个数有正有负,所以通过指数函数将他们全部变成正数,然后求和,然后这10个数每个数都除以这个和,这样就得到了每个类别的概率。

data

首先导入torch里面专门做图形处理的一个库,torchvision,根据官方安装指南,你在安装pytorch的时候torchvision也会安装。
我们需要使用的是torchvision.transforms和torchvision.datasets以及torch.utils.data.DataLoader

首先DataLoader是导入图片的操作,里面有一些参数,比如batch_size和shuffle等,默认load进去的图片类型是PIL.Image.open的类型,如果你不知道PIL,简单来说就是一种读取图片的库

torchvision.transforms里面的操作是对导入的图片做处理,比如可以随机取(50, 50)这样的窗框大小,或者随机翻转,或者去中间的(50, 50)的窗框大小部分等等,但是里面必须要用的是transforms.ToTensor(),这可以将PIL的图片类型转换成tensor,这样pytorch才可以对其做处理

torchvision.datasets里面有很多数据类型,里面有官网处理好的数据,比如我们要使用的MNIST数据集,可以通过torchvision.datasets.MNIST()来得到,还有一个常使用的是torchvision.datasets.ImageFolder(),这个可以让我们按文件夹来取图片,和keras里面的flow_from_directory()类似,具体的可以去看看官方文档的介绍。
干货|10分钟入门PyTorch(2)~附源码

以上就是我们对图片数据的读取操作

model

之前讲过模型定义的框架,废话不多说,直接上代码
干货|10分钟入门PyTorch(2)~附源码
我们需要向这个模型传入参数,第一个参数定义为数据的维度,第二维数是我们分类的数目。

接着我们可以在gpu上跑模型,怎么做呢?
首先可以判断一下你是否能在gpu上跑

干货|10分钟入门PyTorch(2)~附源码
如果返回True就说明有gpu支持
接着你只需要一个简单的命令就可以了

干货|10分钟入门PyTorch(2)~附源码

或者

干货|10分钟入门PyTorch(2)~附源码

都可以
然后需要定义loss和optimizer

干货|10分钟入门PyTorch(2)~附源码

这里我们使用的loss是交叉熵,是一种处理分类问题的loss,optimizer我们还是使用随机梯度下降

train

接着就可以开始训练了

干货|10分钟入门PyTorch(2)~附源码
干货|10分钟入门PyTorch(2)~附源码

注意我们如果将模型放到了gpu上,相应的我们的Variable也要放到gpu上,也很简单

干货|10分钟入门PyTorch(2)~附源码

然后可以测试模型,过程与训练类似,只是注意要将模型改成测试模式

干货|10分钟入门PyTorch(2)~附源码

这是跑完100 epoch的结果

干货|10分钟入门PyTorch(2)~附源码

具体的结果多久打印一次,如何打印可以自己在for循环里面去设计

这一部分我们就讲解了如何用logistic回归去做一个简单的图片分类问题,知道了如何在gpu上跑模型,下一节我们将介绍如何写简单的卷积神经网络,不了解卷积网络的同学可以先去我的专栏看看之前卷积网络的介绍。

本文代码已经上传到了github上
欢迎查看我的知乎专栏,深度炼丹
欢迎访问我的博客

推荐阅读文章:

10分钟快速入门PyTorch(1)
10分钟入门pytorch(0)
隐马尔科夫模型-前向算法

全是通俗易懂的硬货!只需置顶~欢迎关注交流~

干货|10分钟入门PyTorch(2)~附源码

猜你喜欢

转载自blog.51cto.com/15009309/2553589