PyTorch study notes - common functions and data loading

1. PyTorch common functions

(1) Path-related functions
Suppose the directory structure of our data set is as follows:

(image: dataset directory tree, roughly:)

dataset/hymenoptera_data/
└── train/
    ├── ants_image/
    │   ├── 0013035.jpg
    │   ├── 1030023514_aad5c608f9.jpg
    │   └── ...
    └── ...

First, import os. The commonly used path-related functions in os are:

  • os.listdir(path): lists the contents of the directory path and returns them as a list.
  • os.path.join(path1, path2): joins the two paths into one, e.g. path1\path2 on Windows (the separator is OS-dependent).

For example:

import os

dir_path = 'dataset/hymenoptera_data/train/ants_image'
img_path_list = os.listdir(dir_path)
img_full_path = os.path.join(dir_path, img_path_list[0])
print(img_path_list)  # ['0013035.jpg', '1030023514_aad5c608f9.jpg', ...]
print(img_full_path)  # dataset/hymenoptera_data/train/ants_image\0013035.jpg

(2) Auxiliary functions

  • dir(): with no argument, returns the list of names (variables, methods, and defined types) in the current scope; with an argument, returns the list of attributes and methods of that object.
  • help(func): shows the documentation for func.

For example:

import torch

print(dir(torch))  # ['AVG', 'AggregationType', ..., 'cuda', ...]
help(torch.cuda.is_available)  # Help on function is_available in module torch.cuda: is_available() -> bool...

2. Data loading

Reading and preprocessing dataset samples is the first step of any machine learning workflow, and PyTorch provides many utilities to handle it.

(1) Dataset: torch.utils.data.Dataset is an abstract class representing a dataset. You can define your own dataset class by inheriting from this abstract class; it is very simple, you only need to override the two methods __len__ and __getitem__, for example:

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir + '_image')
        self.img_path_list = os.listdir(self.path)

    def __getitem__(self, idx):
        img_path = self.img_path_list[idx]
        img_full_path = os.path.join(self.root_dir, self.label_dir + '_image', img_path)
        img = Image.open(img_full_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path_list)

root_dir = 'dataset/hymenoptera_data/train'
ants_label_dir = 'ants'

ants_data = MyData(root_dir, ants_label_dir)
img, label = ants_data[0]
print(img, label)
img.show()

With the above approach we can define the dataset class we need and fetch each sample through indexing or iteration, but batching, shuffling, and multi-process reading are hard to achieve this way, as the sketch below shows.
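
As an illustration only (not how DataLoader is implemented), here is a minimal sketch of the manual bookkeeping that batching and shuffling would require, reusing the ants_data object defined above:

import random

# manually shuffle the sample indices (what shuffle=True automates)
indices = list(range(len(ants_data)))
random.shuffle(indices)

# manually slice the indices into batches of 4 and fetch each sample one by one
batch_size = 4
for start in range(0, len(indices), batch_size):
    batch = [ants_data[i] for i in indices[start:start + batch_size]]  # list of (img, label) pairs
    # each image here is still a PIL Image; stacking a batch into one tensor (collation) is yet more work

DataLoader handles all of this, plus multi-process loading, for us.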

(2) DataLoader: torch.utils.data.DataLoader builds an iterable data loader. During training, each iteration of the for loop fetches one batch of batch_size samples. By analogy, if the Dataset is a complete deck of playing cards, the DataLoader draws a hand made up of several of those cards.

Before learning DataLoader, you should first learn about Transform: PyTorch Study Notes - Transform.

DataLoader has many parameters, but the main ones we commonly use are as follows:

  • dataset: the Dataset instance, which determines where and how to read the data.
  • batch_size: the batch size.
  • num_workers: the number of worker processes used to load data (0 means loading in the main process).
  • shuffle: whether the data is reshuffled at the start of each epoch.
  • drop_last: whether to drop the last, incomplete batch when the number of samples is not divisible by batch_size.

To understand drop_last, you first need the concepts of Epoch, Iteration, and Batch_size (a toy example follows the list):

  • Epoch: one full pass of all training samples through the model is called an Epoch.
  • Iteration: one batch of samples passed through the model is called an Iteration.
  • Batch_size: the number of samples in one batch; it determines how many Iterations an Epoch contains.
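
To make the arithmetic concrete, here is a hypothetical toy dataset of 10 samples built with TensorDataset, used purely for illustration: with batch_size=3, one epoch has 4 iterations if the incomplete last batch is kept and 3 if it is dropped:

import torch
from torch.utils.data import TensorDataset, DataLoader

toy_set = TensorDataset(torch.arange(10))  # 10 samples
print(len(DataLoader(toy_set, batch_size=3, drop_last=False)))  # 4: the last batch has only 1 sample
print(len(DataLoader(toy_set, batch_size=3, drop_last=True)))   # 3: the incomplete batch is discarded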

The role of DataLoader is to build the data loader: according to the batch_size we provide, it splits the data samples into batches to train the model. To assemble each batch it has to fetch the individual samples, and this is done through Dataset's __getitem__ method.

For example:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir, transform):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir + '_image')
        self.img_path_list = os.listdir(self.path)
        self.transform = transform  # the transform to apply to each image

    def __getitem__(self, idx):
        img_path = self.img_path_list[idx]
        img_full_path = os.path.join(self.root_dir, self.label_dir + '_image', img_path)
        img = Image.open(img_full_path).convert('RGB')  # convert the image to three channels (RGB) first
        if self.transform is not None:
            img = self.transform(img)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path_list)

root_dir = 'dataset/hymenoptera_data/train'
ants_label_dir = 'ants'

trans_dataset = transforms.Compose([
    transforms.Resize((83, 100)),  # tensors in a batch must all have the same size
    transforms.ToTensor()
])

ants_data = MyData(root_dir, ants_label_dir, trans_dataset)

train_loader = DataLoader(dataset=ants_data, batch_size=10, shuffle=True, num_workers=0, drop_last=False)

for i, data in enumerate(train_loader):
    img, label = data
    print(type(img))
    print(img[0])
    print(label)
    print(label[0])

The following is the information of one batch: img contains batch_size images, and label contains the batch_size labels corresponding to those images:

(image: console output of one batch — img is a torch.Tensor of shape [batch_size, 3, 83, 100], and label is a tuple of batch_size label strings such as 'ants')

Next, we use the CIFAR10 dataset to demonstrate DataLoader once more. How to use this dataset is described in: PyTorch study notes - Torchvision data set usage method:

from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_set = datasets.CIFAR10(root='dataset/CIFAR10', train=False, transform=transforms.ToTensor(), download=True)  # download=True fetches the data on the first run

test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

writer = SummaryWriter('logs')

for epoch in range(2):  # loop over two epochs
    for step, data in enumerate(test_loader):  # step is the index of the current batch
        imgs, targets = data
        writer.add_images('Epoch_{}'.format(epoch), imgs, step)  # note: add_images (plural), default image format is NCHW

writer.close()

The code above loops over two epochs, that is, it reads the whole dataset twice. In each epoch, a for loop reads every batch from the DataLoader, and each batch is added to TensorBoard as one step.

Note that the function used to add a collection of images here is add_images, not add_image: add_images expects an image batch in NCHW format, i.e. [batch_size, channel, height, width], while add_image expects a single image in CHW format (a minimal contrast of the two calls is sketched after the figure). Open the TensorBoard page and you will see the following results:

(image: TensorBoard page showing one grid of 64 CIFAR10 images per step, under the tags Epoch_0 and Epoch_1)
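
For reference, a minimal sketch contrasting the two calls (the tag names are arbitrary; imgs is reused from the loop above, and the shapes assume the CIFAR10 loader with batch_size=64):

writer.add_image('one_image', imgs[0], 0)  # a single image in CHW format: [3, 32, 32]
writer.add_images('one_batch', imgs, 0)    # a batch of images in NCHW format: [64, 3, 32, 32]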

Origin blog.csdn.net/m0_51755720/article/details/128051831