版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zw__chen/article/details/82841449
transform = transforms.Compose([
transforms.Resize(96),
transforms.ToTensor()
# transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True)
test_loader = DataLoader(dataset=test_dataset,
batch_size=1,
shuffle=False)
# 得到一个随机的训练图片
dataiter = iter(test_loader)
images, labels = dataiter.next()
images = torch.squeeze(images) # (1,3,96,96) --> (3,96,96)
images = torch.transpose(images, 0, -1) # (3,96,96) --> (96,96,3)
img = images.numpy() # 将tensor转换为numpy
img = img_as_ubyte(img) # 这点很重要!!这些数值都不是在0-255,所以要转换为unit8
cv2.imwrite("./test.jpg", img) # 保存为test.jpg
cv2.imshow(img) # 可视化
cv2.waitKey(0)
print "over."