CIFAR10数据集取一张可视化保存

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zw__chen/article/details/82841449
transform = transforms.Compose([
        transforms.Resize(96),
        transforms.ToTensor()
        # transforms.Normalize((.5, .5, .5), (.5, .5, .5))
    ])

    test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=1,
                             shuffle=False)
    # 得到一个随机的训练图片
    dataiter = iter(test_loader)
    images, labels = dataiter.next()


    images = torch.squeeze(images) # (1,3,96,96) --> (3,96,96)
    images = torch.transpose(images, 0, -1) # (3,96,96) --> (96,96,3)
    img = images.numpy() # 将tensor转换为numpy
    img = img_as_ubyte(img) # 这点很重要!!这些数值都不是在0-255,所以要转换为unit8
    cv2.imwrite("./test.jpg", img) # 保存为test.jpg
    cv2.imshow(img) # 可视化
    cv2.waitKey(0)
    print "over."

猜你喜欢

转载自blog.csdn.net/zw__chen/article/details/82841449
今日推荐