沐神-MNist-Pycharm实现

cbjknm导入包并下载数据集

from matplotlib import pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms #对数据进行操作
from d2l import torch as d2l

d2l.use_svg_display() #用svg来显示图片更清晰。
#print(transforms)   #<module 'torchvision.transforms' from 'D:\\anaconda3\\envs\\pythonProject6\\lib\\site-packages\\torchvision\\transforms\\__init__.py'>
trans=transforms.ToTensor()
#print(trans)        #ToTensor()
#从torchvision.datasets里得到FashionMNist数据集,下载地址,对应下载的是训练数据集,需要得到的是pytorch的tensor,而不是一堆图片,默认从网上下载。
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
#测试数据集
#print(len(mnist_train))
#print(len(mnist_test))
#print(mnist_train[0][0].shape)  #torch.Size([1,28,28])

print方法后面注释是输出内容

结果:

D:\anaconda3\envs\pythonProject6\python.exe C:/Users/Dell/PycharmProjects/pythonProject6/mnist.py
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz
100.0%
Extracting ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz
100.0%
Extracting ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz
100.0%
Extracting ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz
100.0%
Extracting ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw

60000
10000

进程已结束,退出代码0
定义两个函数来画出数据集
数据集有十个类别,关于衣服的数据集。
def get_fashion_mnist_labels(labels):
    """返回Fashion-MNist数据集的文本标签"""
    text_labels=[
        't-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot'
    ]
    return [text_labels[int (i)] for i in labels]

 python 循环:基础循环进阶循环

#把图片一个一个画出来
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images."""
    #使用matplotlib来画出图片
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

此处引用这篇文章的代码,因为B站和书上的代码都有点问题。

B站的代码不显示标签。书上的代码会报错 ValueError:only one element tensors can be converted。所以上图代码增加了if判断语句,事关tensor和numpy之间的复杂的转换。

元组:元组和列表

构造了pytorch的数据集之后,放进一个DataLoader()里面,指定一个batch-size,得到一个大小为固定数据的数字。
python的iteration,用iter构造iteration,next就是拿到第一个小批量。即为一个X,y。
数据发生了变化,但是和之前的人工数据集的读数据基本上是一样的。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) #next拿到第一批数据量
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))#shape是(18,28,28) 分成2行 9列
plt.show()   #或者d2l.plt.show()  pycharm必用

这样就可以输出了。

 而且这样就有标签了。

如果不乘scale,会得到下图。

 

 输出flatten前的axes和flatten后的axes:

前:

[[<AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
  <AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
  <AxesSubplot:>]
 [<AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
  <AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
  <AxesSubplot:>]]

后:

[<AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
 <AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
 <AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
 <AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>
 <AxesSubplot:> <AxesSubplot:>]
axes.flatten():在用plt.subplots画多个子图中,ax = ax.flatten()将ax由n*m的Axes组展平成1*nm的Axes组。.flatten()

zip()函数用法:zip

一般来说数据放在硬盘上,可能需要多个进程来进行数据的读取、操作、预读取。根据CPU大小进行选择。
batch_size=256
def get_dataloader_wokers():
    """使用四个进程来读取数据"""
    return 4
train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_wokers())
timer=d2l.Timer()
for X,y in train_iter:
    continue
print(f'{timer.stop():.2f}')

.DataLoader()的使用:.DataLoader

print:5.63   (最快的一次了hehehehehe)

mnist.py文件中的所有代码:

from matplotlib import pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms #对数据进行操作
from d2l import torch as d2l

d2l.use_svg_display() #用svg来显示图片更清晰。
#print(transforms)   #<module 'torchvision.transforms' from 'D:\\anaconda3\\envs\\pythonProject6\\lib\\site-packages\\torchvision\\transforms\\__init__.py'>
trans=transforms.ToTensor()
#print(trans)        #ToTensor()
#从torchvision.datasets里得到FashionMNist数据集,下载地址,对应下载的是训练数据集,需要得到的是pytorch的tensor,而不是一堆图片,默认从网上下载。
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
#测试数据集
#print(len(mnist_train))
#print(len(mnist_test))
#print(mnist_train[0][0].shape)  #torch.Size([1,28,28])

#定义两个函数来画出数据集
#数据集有十个类别,关于衣服的数据集。
def get_fashion_mnist_labels(labels):
    """返回Fashion-MNist数据集的文本标签"""
    text_labels=[
        't-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot'
    ]
    return [text_labels[int (i)] for i in labels]

#把图片一个一个画出来
#show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
 #imgs:shape是(18,28,28) 分成2行 9列
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images."""
    #使用matplotlib来画出图片
    figsize = (num_cols * scale, num_rows * scale)
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

#构造了pytorch的数据集之后,放进一个DataLoader()里面,指定一个batch-size,得到一个大小为固定数据的数字。
#python的iteration,用iter构造iteration,next就是拿到第一个小批量。即为一个X,y。
#数据发生了变化,但是和之前的人工数据集的读数据基本上是一样的。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) #next拿到第一批数据量
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))#shape是(18,28,28) 分成2行 9列
#plt.show()   #或者d2l.plt.show()  pycharm必用

batch_size=256
#一般来说数据放在硬盘上,可能需要多个进程来进行数据的读取、操作、预读取。根据CPU大小进行选择。
def get_dataloader_wokers():
    """使用四个进程来读取数据"""
    return 3

train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_wokers())
#之前定义过Timer这个函数用来测试速度。
timer=d2l.Timer()
#用train_iter来一个一个访问所有的batch,整个扫一遍数据。
#读一次数据是5.63s
for X,y in train_iter:
    continue
print(f'{timer.stop():.2f}')

猜你喜欢

转载自blog.csdn.net/qq_45828494/article/details/126356547