[PyTorch framework] 2.1.4 Data loading and preprocessing

PyTorch basics: loading and preprocessing of data

PyTorch wraps common data-loading functionality in torch.utils.data, which makes it easy to prefetch data with multiple worker processes and to load it in batches.
torchvision also ships ready-made implementations of commonly used image datasets, including CIFAR-10 (used earlier), ImageNet, COCO, MNIST, and LSUN, all of which are available through torchvision.datasets.

# First, import the required package
import torch
# Print the version
torch.__version__
'1.0.1.post2'

Dataset

Dataset is an abstract class. To make data easy to read, wrap it in a Dataset subclass.
A custom Dataset must inherit from it and implement two methods:

  1. __getitem__(): returns one sample for a given index (from 0 to len(self) - 1)
  2. __len__(): returns the total number of samples in the dataset

Below we use the Kaggle competition Blue Book for Bulldozers to build a custom dataset. To keep the walkthrough simple, we illustrate with one of the smaller files from the competition data (because it has few entries).

# Imports
from torch.utils.data import Dataset
import pandas as pd
# Define a dataset
class BulldozerDataset(Dataset):
    """Demonstration dataset."""
    def __init__(self, csv_file):
        """Load the CSV data when the dataset is initialized."""
        self.df = pd.read_csv(csv_file)
    def __len__(self):
        """Return the number of rows in df."""
        return len(self.df)
    def __getitem__(self, idx):
        """Return the SalePrice value of the row at position idx."""
        return self.df.iloc[idx].SalePrice

At this point our dataset is defined, and we can create an instance to access it:

ds_demo= BulldozerDataset('median_benchmark.csv')

We can inspect the dataset directly with the following commands:

# __len__ is implemented, so len() returns the total number of samples
len(ds_demo)
11573
# Indexing calls __getitem__ and returns the corresponding sample
ds_demo[0]
24000.0
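
The same two-method pattern can be tried without the Kaggle file. The following is a minimal sketch, assuming an in-memory DataFrame; the class name FrameDataset, the column names, and the values are all made up for illustration and are not part of the competition data. It returns a (features, target) pair per sample, which is the shape most training loops expect.

```python
import pandas as pd
import torch
from torch.utils.data import Dataset


class FrameDataset(Dataset):
    """Hypothetical dataset that serves (features, target) pairs from a DataFrame."""

    def __init__(self, df, feature_cols, target_col):
        # Convert once up front so __getitem__ only has to index tensors.
        self.features = torch.tensor(df[feature_cols].values, dtype=torch.float32)
        self.targets = torch.tensor(df[target_col].values, dtype=torch.float32)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]


# Made-up rows, used only to exercise the class.
toy_df = pd.DataFrame({
    "year_made": [1996, 2004, 2001],
    "hours": [68.0, 4640.0, 2838.0],
    "sale_price": [66000.0, 57000.0, 10000.0],
})
toy_ds = FrameDataset(toy_df, ["year_made", "hours"], "sale_price")

print(len(toy_ds))   # number of rows
print(toy_ds[1])     # (feature tensor, target tensor) for the second row
```

Converting to tensors in __init__ is one possible design choice; for data that does not fit in memory, reading each sample lazily inside __getitem__ is the usual alternative.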

The custom dataset is ready. Next we use the official data loader to read it.

DataLoader

DataLoader handles reading from a Dataset. Common parameters are batch_size (the size of each batch), shuffle (whether to shuffle the data), and num_workers (how many subprocesses to use for loading). A simple example:

dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)

DataLoader returns an iterable, so we can use an iterator to fetch the data batch by batch:

idata=iter(dl)
print(next(idata))
tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)

The more common usage is to traverse it with a for loop:

for i, data in enumerate(dl):
    print(i,data)
    # To save space, run only one iteration
    break
0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)
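
To make the effect of batch_size more concrete, here is a small sketch, assuming a toy TensorDataset of 25 made-up samples rather than the bulldozer data. It counts the batches produced with and without drop_last, a DataLoader option not used above that discards the final incomplete batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 25 made-up samples: one feature column and one target per sample.
features = torch.arange(25, dtype=torch.float32).unsqueeze(1)
targets = torch.arange(25, dtype=torch.float32)
toy_tensor_ds = TensorDataset(features, targets)

# Default behaviour: the last batch holds whatever samples are left over.
loader = DataLoader(toy_tensor_ds, batch_size=10, shuffle=False, num_workers=0)
batch_sizes = [len(y) for _, y in loader]
print(len(loader), batch_sizes)  # 3 [10, 10, 5]

# drop_last=True discards the incomplete final batch.
drop_loader = DataLoader(toy_tensor_ds, batch_size=10, shuffle=False, drop_last=True)
drop_sizes = [len(y) for _, y in drop_loader]
print(len(drop_loader), drop_sizes)  # 2 [10, 10]
```

len(loader) reports the number of batches, not the number of samples, so it is ceil(25 / 10) = 3 in the first case.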

We can now define datasets with Dataset and load and traverse them with DataLoader. Beyond these, PyTorch also offers torchvision, a computer vision extension package that wraps the functionality described next.

torchvision package

torchvision is PyTorch's library dedicated to image processing. The final pip install torchvision step in the installation instructions on the official PyTorch website installs this package.

torchvision.datasets

torchvision.datasets can be thought of as a set of datasets prepared by the PyTorch team. They take care of much of the preprocessing for many image datasets in advance:

  • MNIST
  • COCO (Captions and Detection)
  • LSUN
  • ImageFolder
  • Imagenet-12
  • CIFAR
  • STL10
  • SVHN
  • PhotoTour

These can be used directly. An example is as follows:
import torchvision.datasets as datasets
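trainset = datasets.MNIST(root='./data', # directory where the MNIST data is stored
                                      train=True,  # load the training set; False loads the test set
                                      download=True, # whether to download the MNIST dataset automatically
                                      transform=None) # preprocessing to apply to the data; None means no preprocessing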

torchvision.models

torchvision provides not only commonly used image datasets but also pretrained models, which can be used directly after loading or as a starting point for transfer learning.
The torchvision.models module contains the following model architectures in its submodules:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
# We can use a pretrained model directly; as with datasets, the weights must be downloaded from a server
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

torchvision.transforms

The transforms module provides common image transformation classes for data preprocessing and data augmentation.

from torchvision import transforms
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # pad each side with zeros, then randomly crop to 32x32
    transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
    transforms.RandomRotation((-45,45)), # rotate by a random angle between -45 and 45 degrees
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), # per-channel (R, G, B) mean and standard deviation used for normalization
])

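Someone will definitely ask: what do the numbers (0.4914, 0.4822, 0.4465) and (0.229, 0.224, 0.225) mean?

A thread on the official forum explains this in detail:
https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21
The first tuple holds the per-channel means and the second the per-channel standard deviations. They are statistics computed in advance over a dataset such as ImageNet (the standard deviations above are the commonly used ImageNet values) rather than values learned during training, so they can be used directly and treated as fixed constants.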

We have now covered the basics of PyTorch. Next we introduce the theoretical foundations of neural networks, using PyTorch to implement the formulas along the way.

Origin blog.csdn.net/yegeli/article/details/113686201