读取和归一化CIFAR10

读取和归一化CIFAR10:

torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

参数说明:

  1. root:cifar-10-batches-py 的根目录

  2. train:True = 训练集 ,False = 测试集

  3. download : True = 从互联网上下载数据,并将其放在root目录下;False=数据已经下载,则什么都不干。

  4. transform:接受PIL图像的函数/变换,并返回转换后的版本。

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
    

参数说明:

  1. dataset:加载数据的数据集
  2. batch_size:每个batch加载多少个样本(默认: 1)。
  3. shuffle:设置为True时会在每个epoch重新打乱数据(默认: False).
  4. sampler:定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  5. num_workers:用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  6. drop_last (bool, optional) :如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
    transforms.ToTensor:不明
    transforms.Normalize:不明
    程序:

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

#测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

总结:使用torchvision可以非常容易地加载CIFAR10,CIFAR有两个数据集,一个是训练集,另外一个是测试集。首先先下载训练集,torchvision的输出是[0,1]的PILImage图像,将她转换为归一化范围为[-1,1]的张量。加载下载好的训练集,每个块加载4个通道,将每个epoch重新打乱数据,用2个子进程加载数据。
接着下载测试集,torchvision的输出是[0,1]的PILImage图像,将她转换为归一化范围为[-1,1]的张量。加载下载好的测试集,每个块加载4个通道,每个epoch默认原始数据,用2个子进程加载数据。
则完成读取和归一化CIFAR10数据。

猜你喜欢

转载自blog.csdn.net/qq_31244453/article/details/104701064
今日推荐