PyTorch Basics (4): Data Loading and Preprocessing

Preface

In the previous articles I briefly covered PyTorch's Tensor, Autograd, torch.nn and torch.optim packages. With these we can easily build a network model, but that is not enough: we also need plenty of data. As everyone knows, data is the soul of deep learning, and deep learning models are "fed" with data. In this article we will talk about data loading and preprocessing.

  • First, we need to import the torch package
import torch
torch.__version__

1. Data loading

PyTorch encapsulates the common data-loading workflow in torch.utils.data, which makes it easy to prefetch data with multiple worker processes and load it in batches.

1.1 Dataset

Dataset is an abstract class. To make your data easy to read, it needs to be wrapped in a Dataset subclass. A custom Dataset class must inherit from it and implement two member methods:

  • 1. __getitem__(): defines how to fetch one piece of data (one sample) by index, where the index ranges from 0 to len(self) - 1
  • 2. __len__(): returns the total length of the dataset

Below we use data from the Kaggle competition Blue Book for Bulldozers to build a custom dataset. For the convenience of this introduction, we use the competition's median benchmark file (median_benchmark.csv) to illustrate.

  • First, we need to import the related packages
from torch.utils.data import Dataset
import pandas as pd
  • Define a custom dataset
# Define a dataset class
class BulldozerDataset(Dataset):
    """ Dataset demo """
    def __init__(self, csv_file):
        """Initialization: read the CSV file into memory when the dataset is created"""
        self.df = pd.read_csv(csv_file)
    def __len__(self):
        '''
        Return the length of the DataFrame
        '''
        return len(self.df)
    def __getitem__(self, idx):
        '''
        Return the SalePrice of the row at index idx
        '''
        return self.df.iloc[idx].SalePrice
  • At this point our dataset has been defined, and we can instantiate an object to access it
ds_demo= BulldozerDataset('median_benchmark.csv')
  • We can directly use the following command to inspect the dataset
# We have already implemented __len__, so len() can be used directly
len(ds_demo)
  • Use the index to directly access the corresponding data
ds_demo[0]

The custom dataset has been created. Next we use the official DataLoader to read data from it.

1.2 DataLoader

DataLoader provides the read operation on a Dataset. Its common parameters are batch_size (the size of each batch), shuffle (whether to shuffle the data), and num_workers (how many subprocesses to use for loading data). Here is a simple demonstration:

dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)

DataLoader returns an iterable object, so we can use an iterator to fetch data batch by batch:

idata=iter(dl)
print(next(idata))
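
Because each call to __getitem__ returns a single SalePrice value, DataLoader's default collate function stacks every 10 samples into a one-dimensional tensor. The following is a minimal sanity-check sketch (it assumes median_benchmark.csv is available, as above):

batch = next(iter(dl))
print(batch.shape)   # expected: torch.Size([10]), one value per sample in the batch
print(batch.dtype)   # expected: torch.float64, since the CSV values are read as floats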

The more common usage is to traverse it with a for loop:

for i, data in enumerate(dl):
    print(i,data)
    # To save space, we only loop once here
    break

At this point, we can define a dataset with Dataset and load and traverse it with DataLoader.

2. The torchvision package

torchvision is the PyTorch library dedicated to image processing. The final "pip install torchvision" step in the installation tutorial on the official PyTorch website installs this package.
torchvision implements common image datasets in advance, including the previously mentioned CIFAR-10, ImageNet, COCO, MNIST, LSUN and other datasets, which can be conveniently accessed through torchvision.datasets.

  • Here is a summary of the datasets that torchvision provides out of the box:
MNIST, COCO (Captions and Detection), CIFAR-10, ImageNet, LSUN, ImageFolder, Imagenet-12, STL10, SVHN, PhotoTour

The datasets that come with PyTorch are provided by two higher-level packages, torchvision and torchtext:

  • torchvision provides data and APIs related to image processing
    • Data location: torchvision.datasets; for example: torchvision.datasets.MNIST
  • torchtext provides data and APIs related to text processing
    • Data location: torchtext.datasets; for example: torchtext.datasets.IMDB

Let's do a simple demonstration.

  • First, we need to import torchvision's datasets package
import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data',   # directory where the MNIST data is stored/downloaded
                          train=True,      # whether to load the training split; False loads the test split
                          download=True,   # whether to download MNIST automatically if it is not present
                          transform=None)  # preprocessing applied to the data; None means no preprocessing
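
In practice the transform argument is usually set, so that the images come back as tensors a DataLoader can batch. Below is a minimal sketch (assuming the download above succeeded) that combines the MNIST dataset with the DataLoader from section 1.2:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Re-create the training set, this time converting every PIL image into a tensor
trainset = datasets.MNIST(root='./data', train=True, download=True,
                          transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0)
images, labels = next(iter(trainloader))
print(images.shape)   # expected: torch.Size([32, 1, 28, 28])
print(labels.shape)   # expected: torch.Size([32])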

2.1 torchvision.models

torchvision not only provides commonly used image datasets, it also provides pretrained network models that can be used directly after loading or taken as the starting point for transfer learning. The torchvision.models module contains, among others, the following model families:

Network models: AlexNet, VGG, ResNet, SqueezeNet, DenseNet

We can use a pretrained model directly. Of course, just as with the datasets, the weights need to be downloaded from a server first.

  • First, we need to import torchvision.models
import torchvision.models as models
  • Use it directly
resnet18 = models.resnet18(pretrained=True)
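
The text above mentions transfer learning; here is a minimal sketch of that idea, assuming a hypothetical new task with 10 classes: freeze the pretrained weights and replace the final fully connected layer.

import torch.nn as nn

# Freeze all pretrained parameters so that only the new classifier head is trained
for param in resnet18.parameters():
    param.requires_grad = False

# Replace the final fully connected layer; 10 output classes is just an illustrative choice
resnet18.fc = nn.Linear(resnet18.fc.in_features, 10)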

2.2 torchvision.transforms

The transforms module provides general image transformation classes for data preprocessing and data augmentation.

  • First, we need to import torchvision.transforms, and then do a simple demonstration
from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # pad 4 pixels of zeros on each side, then randomly crop to 32x32
    transforms.RandomHorizontalFlip(),      # flip the image horizontally with probability 0.5
    transforms.RandomRotation((-45, 45)),   # random rotation within (-45, 45) degrees
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)),  # per-channel (R, G, B) mean and std used for normalization
])
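
The Compose object is callable, so it can be applied to a single image or passed as the transform argument of a dataset. Here is a minimal usage sketch (example.jpg is a hypothetical RGB image of at least 32x32 pixels, and Pillow is assumed to be installed):

from PIL import Image

img = Image.open('example.jpg').convert('RGB')   # hypothetical input image
out = transform(img)   # apply the whole pipeline defined above
print(out.shape)       # expected: torch.Size([3, 32, 32]) after the 32x32 random crop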

Someone is bound to ask: what do the mean and std values passed to Normalize, such as (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225), actually mean?
This official forum thread explains it in detail: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21 They are the per-channel mean and standard deviation computed over a reference dataset (those particular values come from ImageNet) and can be used directly; in practice they are treated as fixed constants.
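
To see where numbers like these come from, the per-channel mean and standard deviation can be recomputed from a training set directly. Below is a rough sketch, assuming the CIFAR-10 training set is downloaded into ./data (stacking every image into memory at once is only practical for small datasets):

import torch
import torchvision
import torchvision.transforms as transforms

# Load CIFAR-10 as tensors scaled to [0, 1] and stack all training images
cifar = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                     transform=transforms.ToTensor())
data = torch.stack([img for img, _ in cifar])    # shape: [50000, 3, 32, 32]
print(data.mean(dim=[0, 2, 3]))   # per-channel means, roughly (0.4914, 0.4822, 0.4465)
print(data.std(dim=[0, 2, 3]))    # per-channel standard deviations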
At this point, we have finished introducing the basic content of PyTorch.

References

https://github.com/zergtant/pytorch-handbook/blob/master/chapter2

Origin: https://blog.csdn.net/dongjinkun/article/details/113869697