# Import packages
import torchvision
import torch
from matplotlib import pyplot as plt
from IPython import display
# Data loading
def load_data_fashion_mnist(batch_size, resize=None, root="/home/kesci/input/FashionMNIST2065/", num_workers=0):
    """Download the Fashion-MNIST dataset and wrap it in mini-batch loaders.

    Args:
        batch_size: Number of samples per mini-batch, used for both loaders.
        resize: Optional size handed to ``torchvision.transforms.Resize``
            (e.g. ``224`` to enlarge the image height and width). The resize
            step is skipped when this is falsy.
        root: Directory the dataset is stored in / downloaded to.
        num_workers: Number of worker processes each ``DataLoader`` uses to
            read data. ``0`` (the ``DataLoader`` default) reads in the main
            process, which is what this function did before the parameter
            existed.

    Returns:
        A ``(train_iter, test_iter)`` tuple of ``torch.utils.data.DataLoader``
        objects. Only the training loader shuffles its samples.
    """
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    # ToTensor converts every sample to a Tensor; without it the dataset
    # would hand back PIL images instead.
    trans.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)

    # The datasets are torch.utils.data.Dataset subclasses, so they can be
    # passed straight to DataLoader to read mini-batches. DataLoader can use
    # several processes to speed up reading; that is what num_workers controls.
    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter
# Plot multiple images
def show_fashion_mnist(images, labels):
    """Draw the given images side by side in one row, each titled by its label.

    Plots are rendered in SVG format for display in Jupyter.

    Args:
        images: Sequence (or batched tensor) of image tensors. Each image is
            drawn as a 2-D picture taken from its last two dimensions, so any
            height/width works (not only 224x224). A flattened 1-D image is
            assumed to be square.
        labels: Titles for the images, paired with ``images`` in order.

    Returns:
        None. The figure is created through ``matplotlib.pyplot``; nothing is
        drawn when ``images`` is empty.
    """
    # NOTE(review): newer IPython releases reportedly deprecate this helper in
    # favour of the matplotlib_inline package - confirm before upgrading.
    display.set_matplotlib_formats('svg')

    # plt.subplots cannot build a grid with zero columns. len() is used
    # because the truth value of a multi-element tensor is ambiguous.
    if len(images) == 0:
        return

    # squeeze=False keeps the result a 2-D array of Axes even for one image;
    # otherwise a single Axes object is returned and cannot be zipped over.
    # The underscore is the (unused) Figure object.
    _, figs = plt.subplots(1, len(images), figsize=(12, 12), squeeze=False)
    for ax, img, lbl in zip(figs[0], images, labels):
        if img.dim() >= 2:
            # Drop leading singleton dims such as the channel axis: (1, H, W) -> (H, W).
            pixels = img.reshape(img.shape[-2:])
        else:
            # Flattened image: keep accepting it by assuming a square picture.
            side = int(round(img.numel() ** 0.5))
            pixels = img.reshape(side, side)
        ax.imshow(pixels.numpy())
        ax.set_title(lbl)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
# Load data: map labels to names and display a sample batch
def get_fashion_mnist_labels(labels):
    """Translate numeric Fashion-MNIST class indices into their text names.

    Args:
        labels: Iterable of class indices. Each element is converted with
            ``int()`` before the lookup, so plain ints and other values that
            ``int()`` accepts can be used.

    Returns:
        A list with the text label for each entry of ``labels``, in order.

    Raises:
        IndexError: If an index falls outside the ten known classes.
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
# Build the data loaders first: nothing earlier in this file defines
# `train_iter`, so iterating over it would raise a NameError.
# resize=224 enlarges the images to 224x224 for display.
# NOTE(review): batch_size=128 is an assumed value (only the first 10 samples
# of the batch are displayed below) - confirm the intended batch size.
batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)

# Take one mini-batch from the training loader and show its first ten
# samples with their text labels.
train_data = iter(train_iter)
images, labels = next(train_data)
labels = get_fashion_mnist_labels(labels)
show_fashion_mnist(images[:10], labels[:10])
plt.show()