PyTorch: A Detailed Explanation of Dataset and DataLoader for Data Loading

I had some free time and really didn't feel like doing research, so I wrote this article.

I suspect many readers are in the same position I was in when I started with PyTorch. I had more or less learned the basic training workflow (reading data, building a network, training it, and so on) from tutorials on Bilibili ("B station"). For anyone with no background in this area, I recommend the course I learned from: Mr. Liu Er's introductory PyTorch video series on Bilibili.

After watching it, you should be an entry-level PyTorch user.

I finished the course a while ago and wrote up the code along the way, which I'd like to share:

http://t.csdn.cn/xZ8Gx

But here was the problem: after watching the videos, when I tried to write a network model in PyTorch on my own without any reference material, I was a bit lost and didn't know where to start. So in this post I want to walk through the PyTorch training workflow and share a short summary with everyone.

I'll use a binary classification problem, cats vs. dogs, as the running example.

Note: I assume you have already set up your PyTorch environment.

The first step is to collect samples. You can look for a similar dataset online or use an open-source one.

Here are a few places to start:

http://academictorrents.com

https://github.com/awesomedata/awesome-public-datasets

https://blog.csdn.net/u012735708/article/details/82682673

https://www.cnblogs.com/ansang/p/8137413.html

http://vision.stanford.edu/resources_links.html

http://slazebni.cs.illinois.edu

Once the data is collected, sort the images into folders, one folder per class (Cat and Dog), separately for the training and validation sets. Be careful not to mix them up: mislabeled images will seriously hurt the training of the network.
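If your images start out as a single folder per class, a small script can split them into training and validation sets. The sketch below is hypothetical: the function name `split_into_train_valid`, the 20% validation ratio, and the `train`/`valid` folder names are my assumptions, not part of the original code. It produces a layout such as `trains/data/train/Cat/*.jpg` and `trains/data/valid/Dog/*.jpg`, i.e. each split directory holds one subfolder per class, which is what the Dataset class below expects.

```python
import os
import random
import shutil


def split_into_train_valid(src_dir, dst_dir, valid_ratio=0.2, seed=0):
    """Copy src_dir/<class>/*.jpg into dst_dir/train/<class>/ and dst_dir/valid/<class>/.

    Hypothetical helper: assumes src_dir holds one subfolder per class (e.g. Cat, Dog)
    and that the images use the .jpg extension, like the Dataset class in this article.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    for class_name in sorted(os.listdir(src_dir)):
        class_dir = os.path.join(src_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class folders
        images = sorted(f for f in os.listdir(class_dir) if f.endswith('.jpg'))
        rng.shuffle(images)
        n_valid = int(len(images) * valid_ratio)
        splits = {'valid': images[:n_valid], 'train': images[n_valid:]}
        for split_name, names in splits.items():
            out_dir = os.path.join(dst_dir, split_name, class_name)
            os.makedirs(out_dir, exist_ok=True)
            for name in names:
                # copy (not move) so the original folders stay intact
                shutil.copy2(os.path.join(class_dir, name), os.path.join(out_dir, name))
```

After running it, point the training directory at `dst_dir/train` and the validation directory at `dst_dir/valid`.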

Next comes the data-import process in PyTorch. First, set up the directory paths:

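import os

# Assumed layout: trains/data/train/{Cat,Dog}/*.jpg and trains/data/valid/{Cat,Dog}/*.jpg
split_dir = os.path.join('trains', 'data')
train_dir = os.path.join(split_dir, 'train')  # training images, one subfolder per class
valid_dir = os.path.join(split_dir, 'valid')  # validation images, same layout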

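This uses os.path.join from the os module, which joins path components with the separator appropriate for your operating system. Each directory passed to the Dataset must itself contain one subfolder per class.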

Next, define the Dataset class. A custom PyTorch dataset subclasses torch.utils.data.Dataset and implements __getitem__ and __len__:

import os
from PIL import Image
from torch.utils.data import Dataset

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {"Cat": 0, "Dog": 1}          # folder name -> class index
        self.data_info = self.get_image_info(data_dir)  # list of (image path, label)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # images are read lazily, one per call
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here: convert to tensor, etc.
        return img, label

    def __len__(self):
        return len(self.data_info)

    def get_image_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            for sub_dirs in dirs:
                img_names = os.listdir(os.path.join(root, sub_dirs))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                for img_name in img_names:
                    path_img = os.path.join(root, sub_dirs, img_name)
                    label = self.label_name[sub_dirs]
                    data_info.append((path_img, int(label)))
        return data_info
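One caveat about get_image_info: os.walk descends into every level of the directory tree, so any folder nested inside Cat/ or Dog/ would be looked up in label_name and raise a KeyError, and the '.jpg' filter is case-sensitive. Below is a sketch of a non-recursive alternative written as a standalone function (the function signature and the `extensions` parameter are my own, for illustration). It only scans the direct subfolders of `data_dir` and skips folders it does not recognize.

```python
import os


def get_image_info(data_dir, label_name, extensions=('.jpg',)):
    """Non-recursive variant: only looks at the direct subfolders of data_dir.

    label_name maps folder name -> integer label, e.g. {"Cat": 0, "Dog": 1}.
    extensions must be lowercase; file names are lowercased before matching,
    so 'x.JPG' is accepted. Unknown folders are skipped instead of raising KeyError.
    """
    data_info = []
    for sub_dir in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, sub_dir)
        if not os.path.isdir(class_dir) or sub_dir not in label_name:
            continue
        for img_name in sorted(os.listdir(class_dir)):
            if img_name.lower().endswith(extensions):
                data_info.append((os.path.join(class_dir, img_name), label_name[sub_dir]))
    return data_info
```

Sorting the listings also makes the sample order deterministic, which os.listdir alone does not guarantee.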

Looking more closely, the class consists of three main methods, plus a dictionary that maps each class name (Cat, Dog) to an integer label:

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {"Cat": 0, "Dog": 1}
        self.data_info = self.get_image_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        # Return the image and label at the given index
        ...

    def __len__(self):
        # Return the number of samples
        ...

    def get_image_info(self, data_dir):
        # Collect each image's path and label into a list and return it
        ...

The comments give a rough idea of what each method does. (Of course, two days after writing a class like this, I usually forget what it was for, hahaha.)
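To see how PyTorch uses these methods, here is a toy in-memory dataset with the same structure (the class `ToyDataset` and its fake data are purely illustrative). `len(ds)` calls `__len__` and `ds[i]` calls `__getitem__`; a DataLoader over a map-style dataset relies on exactly these two calls.

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset with the same three-part structure."""

    def __init__(self, n_samples, transform=None):
        # Build the index up front, like get_image_info does: (sample, label) pairs
        self.data_info = [(float(i), i % 2) for i in range(n_samples)]
        self.transform = transform

    def __getitem__(self, index):
        # Called by ds[index]; returns one (sample, label) pair
        x, label = self.data_info[index]
        x = torch.tensor([x])
        if self.transform is not None:
            x = self.transform(x)
        return x, label

    def __len__(self):
        # Called by len(ds); tells DataLoader how many indices exist
        return len(self.data_info)


ds = ToyDataset(5, transform=lambda t: t * 2)
print(len(ds))  # 5
print(ds[3])    # (tensor([6.]), 1)
```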

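Note that this Dataset class only works when:

1. All images of one class are in the same folder.

2. Each folder is named after its class, and that name is a key in label_name.

3. The images have a lowercase .jpg extension (other files are ignored).

4. The class folders contain no further subfolders (os.walk would look those names up in label_name and raise a KeyError).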

That completes the Dataset class.

The rest is simple. The data augmentation (transform) pipeline is usually defined before the Dataset is instantiated, because it is passed to the constructor:

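from torchvision import transforms

# Per-channel mean and std used by Normalize
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: resize, random crop (augmentation), convert to tensor, normalize
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# Validation: same preprocessing, but no random augmentation
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])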
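The only random step above is RandomCrop(32, padding=4). On a 32x32 image it pads four pixels of zeros on every side (giving 40x40), then cuts out a random 32x32 window, so each epoch sees the image shifted by up to four pixels. The NumPy sketch below re-implements that idea for illustration only; it is not torchvision's code, and it works on an H x W x C array rather than a PIL image (the function name is hypothetical).

```python
import numpy as np


def random_crop_with_padding(img, size=32, padding=4, rng=None):
    """Approximate RandomCrop(size, padding=padding) on an H x W x C array.

    Pads every side with zeros (constant fill), then cuts out a random
    size x size window.
    """
    rng = np.random.default_rng() if rng is None else rng
    padded = np.pad(img, ((padding, padding), (padding, padding), (0, 0)), mode="constant")
    h, w = padded.shape[:2]
    top = rng.integers(0, h - size + 1)   # upper bound is exclusive: top in [0, h - size]
    left = rng.integers(0, w - size + 1)
    return padded[top:top + size, left:left + size]


img = np.full((32, 32, 3), 200, dtype=np.uint8)  # a flat gray 32x32 RGB image
out = random_crop_with_padding(img, rng=np.random.default_rng(0))
print(out.shape)  # (32, 32, 3)
```

The output keeps the input size, but part of it may be black border, which is the "shift" the network learns to tolerate.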

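That is the data augmentation part.

The training and validation pipelines differ: the training pipeline adds RandomCrop(32, padding=4), a random augmentation, while the validation pipeline only resizes, converts to a tensor, and normalizes, so evaluation is deterministic.

torchvision provides many other augmentation transforms (flips, rotations, color jitter, and so on); see the official documentation.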

By the way, what are these two lines for?

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

They are the arguments to:

transforms.Normalize(norm_mean, norm_std)

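Normalize standardizes each color channel: it subtracts the channel mean and divides by the channel standard deviation. These particular numbers are the per-channel mean and standard deviation of the ImageNet training set, and they are widely reused for natural images. Putting the inputs on a similar, roughly zero-centered scale makes the network easier to optimize.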
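Here is the arithmetic for a single pixel, written in NumPy as an illustration of what ToTensor followed by Normalize computes (the pixel value is made up). ToTensor first scales 8-bit values from 0..255 to 0..1; Normalize then applies (x - mean) / std channel by channel.

```python
import numpy as np

norm_mean = np.array([0.485, 0.456, 0.406])
norm_std = np.array([0.229, 0.224, 0.225])

# One RGB pixel as stored in an 8-bit image (0..255); the values are arbitrary
pixel = np.array([255, 128, 0], dtype=np.float64)

scaled = pixel / 255.0                        # ToTensor: uint8 -> [0, 1]
normalized = (scaled - norm_mean) / norm_std  # Normalize: per-channel standardization
print(np.round(normalized, 3))
```

A fully saturated red channel (255) becomes about 2.25, while a zero blue channel becomes about -1.80, so values end up spread around zero instead of sitting in 0..255.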

Back to the topic. Once the transforms are defined, we can instantiate the datasets:

train_data = RMBDataset(data_dir=train_dir, transform = train_transform)
val_data = RMBDataset(data_dir=valid_dir, transform = valid_transform)

Note the two arguments: data_dir, the directory that contains the class folders, and transform, the augmentation pipeline.

That finishes the Dataset part.

Now that the data has been imported, why do we need a DataLoader?

As the name suggests, it handles loading. The Dataset only describes how to fetch a single sample (here, each image file is read only when __getitem__ is called), and we don't feed the whole dataset to the network at once.

Instead, one epoch (one full pass over the data) is split into batches. How many batches?

That depends on the batch size. With 100 samples and a batch size of 10, one epoch consists of 10 batches.

So the DataLoader's job is to group samples into batches and feed them to the network for training.
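The batch arithmetic can be written down directly. This small helper is hypothetical (not part of PyTorch) and mirrors how many batches a DataLoader with the default sampler yields per epoch; the `drop_last` flag corresponds to DataLoader's parameter of the same name.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches per epoch for a map-style dataset with the default sampler."""
    if drop_last:
        return n_samples // batch_size        # the last incomplete batch is discarded
    return math.ceil(n_samples / batch_size)  # the last batch may be smaller


print(batches_per_epoch(100, 10))                  # 10
print(batches_per_epoch(105, 10))                  # 11 (the last batch has 5 samples)
print(batches_per_epoch(105, 10, drop_last=True))  # 10
```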

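from torch.utils.data import DataLoader

Batch_size = 16  # any positive integer; 16 is an arbitrary choice
train_loader = DataLoader(dataset=train_data, batch_size=Batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=Batch_size, shuffle=False)

What is the shuffle parameter for? It controls whether the sample order is reshuffled at the start of every epoch. It is usually enabled for the training set, so the batches differ from epoch to epoch, which helps generalization.

The validation set is not shuffled.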
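To see shuffle in action without any image files, the sketch below swaps in a TensorDataset of random tensors (a stand-in for the cat/dog images, not the article's data) and tags each sample with its original position so the ordering is visible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # make the fake data and the shuffling reproducible

# Hypothetical stand-in for the image dataset: 100 samples of shape 3x32x32
images = torch.randn(100, 3, 32, 32)
labels = torch.arange(100) % 2   # fake 0/1 class labels
positions = torch.arange(100)    # each sample's original index, to observe ordering
dataset = TensorDataset(images, labels, positions)

train_loader = DataLoader(dataset, batch_size=10, shuffle=True)
val_loader = DataLoader(dataset, batch_size=10, shuffle=False)

print(len(train_loader))  # 10

for epoch in range(2):
    seen = []
    for batch_images, batch_labels, batch_pos in train_loader:
        # batch_images: [10, 3, 32, 32], batch_labels: [10]; a model would consume them here
        seen.append(batch_pos)
    order = torch.cat(seen)
    print(epoch, order[:5].tolist())  # a fresh permutation of 0..99 each epoch

# With shuffle=False the order never changes: the first batch is always samples 0..9
_, _, first_pos = next(iter(val_loader))
print(first_pos.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Every sample still appears exactly once per epoch; only the order, and therefore the composition of each batch, changes.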

That covers both Dataset and DataLoader.

Finally, here is the complete code for you to copy:

import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # make only GPU 0 visible to this process


### Data paths (assumed layout: each directory contains Cat/ and Dog/ subfolders)
split_dir = os.path.join('trains', 'data')
train_dir = os.path.join(split_dir, 'train')
valid_dir = os.path.join(split_dir, 'valid')

### Data augmentation and normalization
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

## Define a Dataset class
class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {"Cat": 0, "Dog": 1}          # folder name -> class index
        self.data_info = self.get_image_info(data_dir)  # list of (image path, label)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here: convert to tensor, etc.
        return img, label

    def __len__(self):
        return len(self.data_info)

    def get_image_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            for sub_dirs in dirs:
                img_names = os.listdir(os.path.join(root, sub_dirs))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                for img_name in img_names:
                    path_img = os.path.join(root, sub_dirs, img_name)
                    label = self.label_name[sub_dirs]
                    data_info.append((path_img, int(label)))
        return data_info

Batch_size = 16  # must be a positive integer; batch_size=0 makes DataLoader raise a ValueError
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
val_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

train_loader = DataLoader(dataset=train_data, batch_size=Batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=Batch_size, shuffle=False)


Original post: blog.csdn.net/weixin_53374931/article/details/130091654