Pytorch: torchvision.utils.make_grid函数的说明

Pytorch: torchvision.utils.make_grid函数的说明

网格化显示数据

# 环境准备
import numpy as np  # numpy数组库
import matplotlib.pyplot as plt  # 画图库
import torchvision.datasets as dataset  # 公开数据集的下载和管理
import torchvision.transforms as transforms  # 公开数据集的预处理库,格式转换
import torchvision
import torch.utils.data as data_utils  # 对数据集进行分批加载的工具集

# 2-1 准备数据集
train_data = dataset.MNIST(root="data",
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)

# 2-1 准备数据集
test_data = dataset.MNIST(root="data",
                          train=False,
                          transform=transforms.ToTensor(),
                          download=True)
# 批量数据读取
train_loader = data_utils.DataLoader(dataset=train_data,
                                     batch_size=64,
                                     shuffle=True)

test_loader = data_utils.DataLoader(dataset=test_data,
                                    batch_size=64,
                                    shuffle=True)


def imshow(img):
    # img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 将【c,h,w】-->【h,w,c】
    plt.show()


print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)

print("\n合并成一张三通道灰度图片")
images = torchvision.utils.make_grid(imgs, nrow=8, padding=0)
#保存图片
from torchvision.utils import save_image
save_image(images,'image.png')
#显示图片
print(images.shape)
imshow(images)

猜你喜欢

转载自blog.csdn.net/qq_40107571/article/details/130454734