【PyTorch】：数据预处理

 
 
# PyTorch 0.4.0 CIFAR-10 数据集显示
# @Time: 2018/6/15
# @Author: xfLi

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


def getData(root='/data/', batch_size=4, download=True):
    """Build the CIFAR-10 training loader and the CIFAR-10 test dataset.

    Every image is passed through ``ToTensor`` and then
    ``Normalize(mean=0.5, std=0.5)`` on each of the three channels, i.e.
    ``(x - 0.5) / 0.5``. For ``ToTensor`` output in [0, 1], this gives
    values in [-1, 1]. Undo it with ``img / 2 + 0.5``.

    Args:
        root: Directory holding (or receiving) the CIFAR-10 files. Defaults
            to the original hard-coded ``'/data/'``. Because that is an
            absolute path, downloading there may need write permission.
        batch_size: Mini-batch size of the returned training loader.
        download: Forwarded to ``torchvision.datasets.CIFAR10``; when True,
            the dataset is fetched if it is not already under ``root``.

    Returns:
        A tuple ``(train_loader, test_set, classes)``:
        ``train_loader`` is a shuffled ``DataLoader`` over the training
        split; ``test_set`` is the un-batched test ``Dataset``, which yields
        ``(image, label)`` pairs; ``classes`` is a tuple of the ten class
        names, where index ``i`` is the name for integer label ``i``.
    """
    # Preprocessing shared by both splits: tensor conversion followed by
    # per-channel normalization with mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Training split, wrapped in a shuffled mini-batch loader.
    train_set = tv.datasets.CIFAR10(root=root, train=True,
                                    transform=transform, download=download)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # Test split, returned as a plain Dataset (no loader) so that callers
    # can iterate over it one sample at a time.
    test_set = tv.datasets.CIFAR10(root=root, train=False,
                                   transform=transform, download=download)

    # Integer label -> human-readable class name.
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')
    return train_loader, test_set, classes

if __name__ == '__main__':
    # Walk the CIFAR-10 test split sample by sample. For each sample, print
    # its class name and display the image with matplotlib.
    _, test_dataset, class_names = getData()
    for tensor_img, target in test_dataset:
        print(class_names[target])
        # Invert the Normalize(mean=0.5, std=0.5) step so pixel values are
        # back on the scale produced by ToTensor.
        chw = (tensor_img / 2 + 0.5).numpy()
        # The array is channels-first; imshow wants (height, width, channel),
        # so move axis 0 to the end.
        hwc = chw.transpose(1, 2, 0)
        plt.imshow(hwc)
        plt.show()

猜你喜欢

转载自 https://blog.csdn.net/qq_30159015/article/details/80756470