读取和归一化CIFAR10:
torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
参数说明:
-
root:cifar-10-batches-py 的根目录
-
train:True = 训练集 ,False = 测试集
-
download : True = 从互联网上下载数据,并将其放在root目录下;False=数据已经下载,则什么都不干。
-
transform:接受PIL图像的函数/变换,并返回转换后的版本。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
参数说明:
- dataset:加载数据的数据集
- batch_size:每个batch加载多少个样本(默认: 1)。
- shuffle:设置为True时会在每个epoch重新打乱数据(默认: False).
- sampler:定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
- num_workers:用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
- drop_last (bool, optional) :如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
transforms.ToTensor:不明
transforms.Normalize:不明
程序:
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
#测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
总结:使用torchvision可以非常容易地加载CIFAR10,CIFAR有两个数据集,一个是训练集,另外一个是测试集。首先先下载训练集,torchvision的输出是[0,1]的PILImage图像,将她转换为归一化范围为[-1,1]的张量。加载下载好的训练集,每个块加载4个通道,将每个epoch重新打乱数据,用2个子进程加载数据。
接着下载测试集,torchvision的输出是[0,1]的PILImage图像,将她转换为归一化范围为[-1,1]的张量。加载下载好的测试集,每个块加载4个通道,每个epoch默认原始数据,用2个子进程加载数据。
则完成读取和归一化CIFAR10数据。