pytorch读取数据

 Pytorch的torchvision包中提供了一些数据集,比如MNIST,CIFAR-10…今天分享一下怎样去读取torchvision中提供的数据集。

import torchvision 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


'''
参数说明
root:保存数据的目录(自定义)
train:是否作为训练集(train_set为True, test_set为False)
transform:数据进行怎样的处理,transforms.ToTensor()表示由 PIL -> tensor
download:是否进行下载,若没有数据集就会进行下载,若检测到有的话,直接使用(建议设置为True,始终不会出错)
test_set中存放的是一堆images和对应targets的集合
'''
# train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=transforms.ToTensor(),download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=transforms.ToTensor(),download=True)

'''
参数说明
dataset:要读取的数据集
batch_size:每次读多少张图片
shuffle:读取完成一次后是否进行打乱处理(一般设置为True)
num_workers:多少个子进程用来读取数据,0代表只是用main process来读取数据
drop_last:当读取到最后的时候,图片的数量不足batch_size大小,怎么处理,True表示舍弃,False表示读取
'''
test_dataloader = DataLoader(dataset=test_set, batch_size=128, shuffle=True, num_workers=0, drop_last=False)

img, target = test_set[0]

print(img)
print(target)
print(test_set.classes)
print(type(test_set))

'''
对下面epochs和step的说明
epochs:读取数据集的遍数
step:读一遍数据集要用多少次
举个栗子:
老师新接手了一个班级,肯定需要熟悉班里的同学
为了更快地熟悉,他每次看batch_size大小的同学,全部看完一遍,他需要step次
但是只看一边肯定记不住,肯定需要看epochs遍,
每一遍看的时候,他都会以一个batch_size为小组,看一个batch_size的同学
'''
epochs = 2
for epoch in range(epochs):
    step = 0
    for data in test_dataloader:
        imgs, targets = data
        print(imgs.shape)
        print(targets)
        print("Epoch: {}/{} ,Step: {}".format(epoch+1, epochs, step))
        step = step + 1

看一下下面的运行结果:
在这里插入图片描述
我这里batch_size设置为128,drop_last=False,即每次读取的时候,读取128张,并且最后不足128张时,仍然读取。可以看出,在第二轮读取数据的时候,最后一步,即step=78时,读取了16张图片,而倒数第二步,即step=77时,读取了128张
附:一个讲的很好的PyTorch教程
https://www.bilibili.com/video/BV1hE411t7RN?spm_id_from=333.337.search-card.all.click

猜你喜欢

转载自blog.csdn.net/Dartao/article/details/124173708