2. PyTorch learning: loading your own dataset and using the built-in public datasets

Loading your own dataset

Ant and bee classification dataset download link:

https://download.pytorch.org/tutorial/hymenoptera_data.zip

from torch.utils.data import Dataset
import os
from PIL import Image


class MyData(Dataset):
    """Image-folder dataset whose folder name doubles as the class label.

    The images of one class live in ``<root_dir>/<label_dir>``; every sample
    returned by this dataset carries ``label_dir`` as its label.

    Args:
        root_dir: Root folder of the dataset split (e.g. the ``train`` folder).
        label_dir: Name of the sub-folder holding the images of one class.
            Because this is a classification task, it is also the label.
    """

    def __init__(self, root_dir: str, label_dir: str):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # os.path.join() inserts the separator of the current platform
        # ("\\" on Windows, "/" on Linux/macOS), so the path stays portable.
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir() returns names in arbitrary order; sorting makes the
        # idx -> file mapping reproducible across runs and machines.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        The image is opened with ``PIL.Image.open``; the label is the class
        folder name (``label_dir``).
        """
        img_name = self.img_path[idx]
        # self.path is already root_dir/label_dir, so only the file name is appended.
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self) -> int:
        """Return the number of entries found in the class folder."""
        return len(self.img_path)


# Root folder of the training split.
dataset_root_dir = "hymenoptera_data/hymenoptera_data/train"
# One sub-folder per class: this is an ants-vs-bees classification problem.
ants_label_dir = "ants"
bees_label_dir = "bees"
# MyData.__init__(self, root_dir, label_dir) has two parameters, so both
# must be supplied when creating an instance.
ants_dataset = MyData(dataset_root_dir, ants_label_dir)
# The bees dataset must be built from bees_label_dir; passing ants_label_dir
# here would produce a second copy of the ants data labelled "ants".
bees_dataset = MyData(dataset_root_dir, bees_label_dir)
# Adding two datasets combines them into a single training set.
train_dataset = ants_dataset + bees_dataset

Note 1:

The __init__ method takes two parameters, and both must be supplied when creating an instance of the class. They are joined with os.path.join() to form the path of the folder that holds the images. Because this is a classification problem, the class name is the folder name, so the second parameter also serves as the label.

def __init__(self, root_dir, label_dir):

Note 2:

Thanks to this method, the samples in the dataset can be accessed by index, just like a list (for example dataset[0]). This is the basis for all later operations.

def __getitem__(self, idx):

Note 3:

Thanks to this method, the length of the data set list, that is, the number of pictures, can be queried through the len() function.

def __len__(self):

Using the built-in public datasets

import torchvision
from matplotlib import pyplot as plt
from torchvision import transforms
import torch
from torch.utils.data import DataLoader

# Preprocessing pipeline. ToTensor() turns each image into a tensor, which is
# what a neural network accepts; Normalize() then subtracts the given mean and
# divides by the given standard deviation, channel by channel, so that each
# channel ends up with mean ~0 and standard deviation ~1.
# Many tutorials plug in arbitrary numbers here; these values are matched to
# this dataset. Compose() chains the listed transforms in order.
channel_mean = (0.4915, 0.4823, 0.4468)
channel_std = (0.2470, 0.2435, 0.2616)
dataset_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Load the ready-made dataset.
#   root      - download folder ("./" is the current directory, "../" its parent)
#   train     - True selects the training split, False the held-out split
#               (used here as the validation set)
#   transform - preprocessing applied to every image
cifar_root = "./dataset"
train_dataset = torchvision.datasets.CIFAR10(
    root=cifar_root, train=True, download=True, transform=dataset_transforms
)
val_dataset = torchvision.datasets.CIFAR10(
    root=cifar_root, train=False, download=True, transform=dataset_transforms
)

# The dataset object is an instance of a class.
print(type(train_dataset))
# The class implements __len__(), so len() reports the number of samples.
print("训练集长度:{}".format(len(train_dataset)))
# The class implements __getitem__(), so a single sample can be fetched with
# ordinary index syntax; each sample is an (image, label) pair.
img, label = train_dataset[99]
sample_report = "图片:{}\n图片类型:{}\n图片形状:{}\n图片数据类型:{}\n值的范围:{}到{}\n标注:{}\n图片种类名:{}"
print(sample_report.format(
    img,
    type(img),
    img.shape,
    img.dtype,
    img.min(),
    img.max(),
    label,
    train_dataset.classes[label],
))

# Once ToTensor() is part of the transform, the sample is a Tensor and can no
# longer be shown with PIL:
# AttributeError: 'Tensor' object has no attribute 'show'
# img.show()

# matplotlib expects H x W x C, so reorder the C x H x W tensor first.
# NOTE: img is normalized, so values outside [0, 1] are clipped by imshow.
img_hwc = img.permute(1, 2, 0)
plt.imshow(img_hwc)
plt.show()

# Per-channel mean and standard deviation over the whole training set.
# All images are stacked along a new last dimension, then flattened to
# (3, -1) so each row holds every value of one channel.
# Because train_dataset already applies Normalize(), the results come out
# close to 0 and 1; to obtain the raw statistics that Normalize() needs,
# run this on a dataset loaded without Normalize().
imgs = torch.stack([sample for sample, _ in train_dataset], dim=3)
print("拼接后的形状:{}".format(imgs.shape))
per_channel = imgs.view(3, -1)
print("每个信道平均值:{}".format(per_channel.mean(dim=1)))
print("每个信道标准差:{}".format(per_channel.std(dim=1)))

# DataLoader
# Arguments: the dataset, the batch size, whether to shuffle (sampling without
# replacement), the number of worker processes, and whether to drop a final
# batch that is smaller than batch_size.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)
# Each iteration yields one batch as an (images, labels) pair, built from the
# samples returned by CIFAR10.__getitem__().
for data in train_loader:
    imgss, labels = data
    print(imgss.shape, labels, sep="\n")

Note 1:

ToTensor is essential. The normalization parameters should not be chosen arbitrarily: when reusing someone else's code to train on your own dataset, recompute the per-channel mean and standard deviation for that dataset.

# Preprocessing pipeline. ToTensor() turns each image into a tensor, which is
# what a neural network accepts; Normalize() then subtracts the given mean and
# divides by the given standard deviation, channel by channel, so that each
# channel ends up with mean ~0 and standard deviation ~1.
# Many tutorials plug in arbitrary numbers here; these values are matched to
# this dataset. Compose() chains the listed transforms in order.
channel_mean = (0.4915, 0.4823, 0.4468)
channel_std = (0.2470, 0.2435, 0.2616)
dataset_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

Note 2:

To display an image with matplotlib, the tensor's dimensions must be reordered from C×H×W to H×W×C.

# Reorder C x H x W into H x W x C, the layout matplotlib expects.
img_hwc = img.permute(1, 2, 0)
plt.imshow(img_hwc)
plt.show()

Note 3:

If you don't know the size of one dimension when reshaping with view(), pass -1 for it and it will be inferred automatically from the remaining dimensions.

# Per-channel mean and standard deviation over the whole training set.
# All images are stacked along a new last dimension, then flattened to
# (3, -1) so each row holds every value of one channel.
# To obtain the raw statistics that Normalize() needs, run this on a dataset
# loaded without Normalize(); with it applied the results are close to 0 and 1.
imgs = torch.stack([sample for sample, _ in train_dataset], dim=3)
print("拼接后的形状:{}".format(imgs.shape))
per_channel = imgs.view(3, -1)
print("每个信道平均值:{}".format(per_channel.mean(dim=1)))
print("每个信道标准差:{}".format(per_channel.std(dim=1)))

Output

Because there are 50,000 images, only part of the output of the for loop is shown:

Files already downloaded and verified
Files already downloaded and verified
<class 'torchvision.datasets.cifar.CIFAR10'>
训练集长度:50000
图片:tensor([[[-1.0055, -1.1960, -1.2595,  ...,  0.6615,  0.9156,  0.1852],
         [-0.9896, -1.1167, -1.1643,  ...,  0.5980,  0.7251,  0.3123],
         [-1.0690, -0.9738, -1.1008,  ...,  0.4393,  0.3916, -0.0370],
         ...,
         [ 0.7409,  0.2805,  0.0741,  ..., -0.4975,  0.2487,  0.2170],
         [ 0.9156,  0.3916, -0.7197,  ..., -0.7039,  0.1535,  0.2805],
         [ 1.3284,  0.8997,  0.2170,  ..., -1.0531,  0.0741,  0.6933]],

        [[-0.9500, -1.1754, -1.2721,  ...,  0.7894,  0.9826,  0.2096],
         [-0.9339, -1.1271, -1.1754,  ...,  0.7410,  0.8216,  0.3706],
         [-0.9822, -0.9178, -1.0144,  ...,  0.5156,  0.4995,  0.0807],
         ...,
         [ 0.1935, -0.2091, -1.0788,  ..., -0.7728, -0.2414, -0.2897],
         [ 0.3706, -0.0803, -0.9500,  ..., -0.8211, -0.0803,  0.0324],
         [ 0.8216,  0.4512, -0.2253,  ..., -1.1110, -0.0642,  0.5317]],

        [[-1.0484, -1.3182, -1.4231,  ..., -0.6736, -0.5687, -0.6286],
         [-1.1533, -1.3182, -1.3032,  ..., -0.7935, -0.5836, -0.5537],
         [-1.1683, -1.1533, -1.1533,  ..., -0.7785, -0.7485, -0.8535],
         ...,
         [-0.2239, -0.4487, -1.0783,  ..., -0.8685, -0.4188, -0.4937],
         [ 0.0460, -0.2838, -1.0484,  ..., -0.8085, -0.2389, -0.0590],
         [ 0.4507,  0.1359, -0.4637,  ..., -1.0034, -0.0440,  0.6906]]])
图片类型:<class 'torch.Tensor'>
图片形状:torch.Size([3, 32, 32])
图片数据类型:torch.float32
值的范围:-1.9806982278823853到2.126077890396118
标注:1
图片种类名:automobile
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
拼接后的形状:torch.Size([3, 32, 32, 50000])
每个信道平均值:tensor([-0.0004, -0.0006, -0.0010])
每个信道标准差:tensor([1.0001, 0.9999, 1.0000])
torch.Size([64, 3, 32, 32])
tensor([5, 7, 6, 0, 5, 2, 5, 8, 0, 3, 2, 8, 1, 0, 4, 3, 1, 5, 3, 2, 4, 2, 5, 4,
        8, 9, 2, 6, 9, 4, 6, 9, 7, 6, 2, 6, 3, 4, 7, 0, 8, 7, 0, 1, 4, 7, 0, 5,
        9, 3, 7, 7, 8, 5, 3, 3, 0, 4, 7, 7, 4, 5, 5, 1])
torch.Size([64, 3, 32, 32])
tensor([4, 6, 6, 9, 2, 8, 8, 5, 3, 0, 6, 5, 0, 4, 6, 1, 2, 7, 3, 2, 0, 8, 3, 1,
        5, 4, 3, 9, 6, 6, 7, 3, 7, 3, 3, 2, 4, 6, 3, 8, 6, 8, 2, 4, 0, 9, 4, 8,
        8, 8, 8, 8, 4, 8, 6, 3, 2, 0, 5, 3, 9, 3, 4, 4])
torch.Size([64, 3, 32, 32])
tensor([3, 4, 1, 8, 6, 0, 1, 3, 5, 0, 3, 9, 3, 5, 0, 6, 0, 7, 2, 8, 6, 9, 0, 4,
        5, 3, 7, 1, 9, 2, 4, 2, 9, 3, 5, 6, 0, 6, 0, 9, 1, 1, 8, 4, 1, 2, 9, 3,
        3, 9, 1, 1, 3, 8, 6, 0, 4, 0, 8, 0, 3, 7, 2, 5])
torch.Size([64, 3, 32, 32])

————————————————————————————————————此处省略————————————————————————————————————————————

torch.Size([16, 3, 32, 32])
tensor([3, 8, 1, 4, 5, 1, 5, 2, 4, 1, 3, 2, 3, 3, 9, 6])

Process finished with exit code 0

Note: the last batch contains fewer than 64 images (16 here) because DataLoader's drop_last is False, so the incomplete final batch is kept.

torch.Size([16, 3, 32, 32])
tensor([3, 8, 1, 4, 5, 1, 5, 2, 4, 1, 3, 2, 3, 3, 9, 6])

Process finished with exit code 0

Guess you like

Origin blog.csdn.net/wzfafabga/article/details/127694889
Recommended