《nlp入门+实战:第七章:pytorch中数据集加载和自带数据集的使用》


上一篇: 《nlp入门+实战:第六章:常见优化器算法的介绍》

本章代码链接:

本章数据集地址:

1.模型中使用数据加载器的目的

在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。

但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。

所以,接下来我们来学习pytorch中的数据加载的方法

2.数据集类

2.1 Dataset基类介绍

在torch中提供了数据集的基类torch.utils.data. Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。

torch.utils.data. Dataset的源码如下:
在这里插入图片描述

可知:我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:

  • 1._len_方法,能够实现通过全局的len()方法获取其中的元素个数
  • 2._getitem_方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据

2.2 数据加载案例

下面通过—个例子来看看如何使用Dataset来加载数据

数据来源:https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection/
不过这个数据集好像已经不能用了,但是类似的Kaggle上还有很多,大家也可以尝试下载这一个:
https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset

数据介绍: SMS Spam Collection是用于强扰短信识别的经典数据集,完全来自真实短信内容,包括4831条正常短信和747条骚扰短信。正常短信和骚扰短信保存在一个文本文件中。每行完整记—条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信

数据实例:
在这里插入图片描述

实现如下:

import torch
from torch.utils.data import Dataset

data_path = r"C:\Users\NineSun\Desktop\archive\spam.txt"  # r表示后面是一个字符串,无需转义


# 完成数据集类
class MyDdataSet(Dataset):
    def __init__(self):
        self.lines = open(data_path, 'r', encoding='UTF-8').readlines()

    def __getitem__(self, index):
        # 获取索引对应位置的数据
        return self.lines[index]

    def __len__(self):
        # 返回数据的总数量
        return len(self.lines)


if __name__ == '__main__':
    my_dataset = MyDdataSet()
    print(my_dataset[0])
    print(len(my_dataset))

在这里插入图片描述

之后对Dataset进行实例化。可以迭代获取其中的数据:

my_dataset = MyDdataSet()
for i in range(len(my_dataset)):
    print(i,my_dataset[i])

3.迭代数据集

使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:

  • 批处理数据(Batching the data)
  • 打乱数据(Shuffling the data)
  • 使用多线程multiprocessing并行加载数据。

在pytorch中torch.utils.data. DataLoader提供了上述的所用方法DataLoader的使用方法示例:

from torch.utils.data import DataLoader

my_dataset = MyDdataSet()
# dataset:实例化之后的数据集;batch_size:batch的大小,一个batch包含10个样本数据;
# shuffle:表示是否打乱数据的顺序;num_workers:表示加载数据时启用线程的数量
data_loader = DataLoader(dataset=my_dataset, batch_size=10, shuffle=True, num_workers=2)
# 遍历,获取每个batch的结果
if __name__ == '__main__':
    for i in data_loader:
        print(i)

其中参数含义:

  • 1.dataset:提前定义的dataset的实例
  • 2.batch_size:传入数据的batch的大小,常用128,256等等
  • 3.shuffle: bool类型,表示是否在每次获取数据的时候提前打乱数据
  • 4.num_workers :加载数据的线程数

注意:

  • 1.len(dataset)=数据集的样本数
  • 2.1en(dataloader) = math.cei1(样本数/batch_size即向上取整)
    print(len(data_loader))
    print(len(my_dataset))

在这里插入图片描述

4 pytorch自带的数据集

pytorch中自带的数据集由两个上层api提供,分别是torchvision和torchtext;
其中:

  • 1.torchvision提供了对图片数据处理相关的api和数据
    • 数据位置:torchvision.datasets,例如: torchvision.datasets.MNIST(手写数字图片数据)
  • 2.torchtext提供了对文本数据处理相关的API和数据
    • 数据位置: torchtext.datas ets ,例如: torchtext.datasets.IMDB(电影评论文本数据)

下面我们以Mnist手写数字为例,来看看pytorch如何加载其中自带的数据集

使用方法和之前一样:

  • 1.准备好Dataset实例
  • 2.把dataset交给dataloder 打乱顺序,组成batch

在进行下面的内容之前,请大家把torchvision和torchtext安装一下,安装办法也很简单,我使用的是anaconda安装的,安装命令如下:

pip install torchvision
pip install torchtext

记得切换到你所对应的环境
在这里插入图片描述

如果上面这种方式下载过慢,可以尝试下面这条指令

pip install --upgrade torch torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple 

pip install --upgrade torch torchtext -i https://pypi.tuna.tsinghua.edu.cn/simple 

如果没安装成功,请大家自己百度一下吧。

4.1 torchversion.datasets

torchversoin. datasets 中的数据集类(比如torchvision.datasets .MNIST),都是继承自Dataset
意味着:直接对torchvision.datasets.MNIST进行实例化就可以得到Dataset的实例但是MNIST API中的参数需要注意一下:

torchvision.datasets.MNIST(root=’ /files/’ ,train=True,download=True,transform=)

  • 1.root参数表示数据存放的位置
  • 2.train: bool类型,表示是使用训练集的数据还是测试集的数据
  • 3.download:bool类型,表示是否需要下载数据到root目录
  • 4.transform:实现的对图片的处理函数

4.2 MNIST数据集的介绍

数据集的原始地址: http://yann.lecun.com/exdb/mnist/

MNIST是由Yann Lecun等人提供的免费的图像识别的数据集,其中包括60000个训练样本和10000个测试样本,其中图片的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28×28

执行代码,下载数据,观察数据类型:

from torchvision.datasets import MNIST

mninst = MNIST(root='./data', train=True, download=True)
print(mninst[0])

运行以后,会在data目录下生成以下数据集
在这里插入图片描述

在这里插入图片描述

可以看出其中数据集返回了两条数据,可以猜测为图片的数据和目标值

返回值的第0个为Image类型,可以调用show()方法打开,发现为手写数字5

from torchvision.datasets import MNIST

mninst = MNIST(root='./data', train=True, download=True)
print(mninst[0])
img=mninst[0][0]
img.show()

在这里插入图片描述

由上可知:返回值为(图片,目标值),这个结果也可以通过观察源码得到

下一篇:《nlp入门+实战:第八章:使用Pytorch实现手写数字识别》

猜你喜欢

转载自blog.csdn.net/zhiyikeji/article/details/126062699