pytorch --数据加载之 Dataset 与DataLoader详解

闲来无事啊,也实在是不想科研了,就写一下这篇文章:

相信很多小伙伴和我一样啊,在刚开始入门pytorch的时候,对于基本的pytorch训练流程已经掌握差不多了,也已经通过一些b站教程什么学会了怎么读取数据,怎么搭建网络,怎么训练等一系列操作了:还没有这方面基础的同学们,我推荐一下我学的,b站的刘二大人的pytorch入门视频

看完这个视频,你应该就算是对于pytorch的使用达到入门级了

我之前上完课,也顺便把代码搞出来了,给大家分享一下:

http://t.csdn.cn/xZ8Gx

但是问题来了,刚看完视频的时候,我没有任何参考资料的时候,我想开始自己用pytorch写一个网络模型的时候,脑子里有点懵,不知道该如何下手,所以呢,想通过这篇博客,梳理pytorch训练流程的同时,也给兄弟们带来一点总结

首先呢:用一个猫和狗的二分类问题来给大家做演示

备注::我默认大家已经配置好环境了哈

首先就是寻找样本嘛,随便去网上或者一些开源的数据集都可以找到类似的数据集

在这里给大家列举一些:

http://academictorrents.com

https://github.com/awesomedata/awesome-public-datasets

https://blog.csdn.net/u012735708/article/details/82682673

https://www.cnblogs.com/ansang/p/8137413.html

http://vision.stanford.edu/resources_links.html

http://slazebni.cs.illinois.edu

数据收集好了之后

 分别将图片放入下面的文件夹中,注意别放混了,不然对于网络的训练会有很大影响

接下来就是pytorch的导入数据流程了

split_dir = os.path.join('trains', 'data')
train_dir = os.path.join(split_dir, 'Dog')
valid_dir = os.path.join(split_dir, 'Cat')

这里用的是os库中的os.path.join函数,输入的就是文件夹的路径

然后就是pytorch中的Dataset设置:刚开始呢,都需要去定义这一个Dataset类

class RNMataset(Dataset):
    def __init__(self,data_dir,transform = None):
        self.label_name = {"Cat":0,"Dog":1}
        self.data_info = self.get_image_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img,label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # 在这里做transform,转为tensor等等

        return img,label
    def __len__(self):
        return len(self.data_info)
    def get_image_info(self,data_dir):
        data_info = list()
        for root , dirs,_ in os.walk(data_dir):
            for sub_dirs in dirs :
                img_names = os.listdir(os.path.join(root, sub_dirs))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dirs, img_name)
                    label = self.label_name[sub_dirs]
                    data_info.append((path_img, int(label)))
        return data_info

可以细看一下,主要就三个函数,并且定义cat 和dog 的代表数字:

class RNMataset(Dataset):
    def __init__(self,data_dir,transform = None):
        self.label_name = {"Cat":0,"Dog":1}
        self.data_info = self.get_image_info(data_dir)
        self.transform = transform

    def __getitem__(self, index)
    ##主要获取摸个图像的索引以及文件

    def __len__(self):
    ##主要获取输入文件的个数
        
    def get_image_info(self,data_dir):
    ##主要将文件的索引和文件名放入一个列表中返回
      

可以看我的注释,大概了解一下每个函数的含义,当然封装好了之后过两天就忘记是用来干啥的了哈哈哈哈哈哈哈 

需要注意的是:我封装的这个Dataset类只适用于:

1、同一种类在一个文件夹内

2、文件夹名称的就是该类名

好了,到目前为止,这就基本定义好了一个Dataset类

剩下来的就很简单了:

一般在训练过程中,针对数据集的数据增强这一块一般在定义类之前就定义好,因为在实例化的时候,需要用到:

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),

])
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

以上就是数据增强部分:

训练集和测试集的数据增强是不一样的,大家可以看一下具体代码:

也有很多其他的数据增强方式,可以参照官方文档去查看

顺便说一下,

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

这两行是干嘛用的呢

transforms.Normalize(norm_mean, norm_std)

大家可以看到,在Normalize这边用到了这两个值,主要是对图像进行归一化的处理,方便网络优化的,简单来说,做这个处理,网络就更容易拟合

回归正题:

数据增强部分做完了,就要开始做数据实例化了:

train_data = RMBDataset(data_dir=train_dir, transform = train_transform)
val_data = RMBDataset(data_dir=valid_dir, transform = valid_transform)

可以注意一下参数:一个就是文件的路径,另一个就是数据增强的部分

到此为止,我们的Dataset就讲完了,

那么既然数据都已经导入进来了,DataLoader又是干嘛用的呢?

看名字就大概知道了:数据加载,Datase类把数据全部加载进来了,我们不是一次性把数据全喂给网络

而是:一次Epoch,一次全部的数据,那一次Epoch又分为多少次呢?

这就取决于Batch_size是多大,加入数据总共有100个,Batch_size是10,那一次Epoch就分成了十次输入数据

所以DataLoader其实就是把数据分批输入网络的进行训练

train_loader = DataLoader(dataset=train_data,batch_size=Batch_size,shuffle=True)
val_loader = DataLoader(dataset=val_data,batch_size=Batch_size,shuffle=False)
shuffle这个参数是干嘛的呢,就是每次输入的数据要不要打乱,一般在训练集打乱,增强泛化能力

验证集就不打乱了

至此,Dataset 与DataLoader就讲完了

最后附上全部代码,方便大家复制:

import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from torchvision import transforms

os.environ["CUDA_VISIBLE_DEVICES"] = '0'



###数据读取
split_dir = os.path.join('线性回归', '阿里天池大赛蒸汽预测')
train_dir = os.path.join(split_dir, 'zhengqi_test.txt')
valid_dir = os.path.join(split_dir, 'zhengqi_train.txt')

###数据增强,翻转,裁剪等
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),

])
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
##定义一个datast的类
class RMBDataset(Dataset):
    def __init__(self,data_dir,transform = None):
        self.label_name = {"1":0,"100":1}
        self.data_info = self.get_image_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img,label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # 在这里做transform,转为tensor等等

        return img,label
    def __len__(self):
        return len(self.data_info)
    def get_image_info(self,data_dir):
        data_info = list()
        for root , dirs,_ in os.walk(data_dir):
            for sub_dirs in dirs :
                img_names = os.listdir(os.path.join(root, sub_dirs))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dirs, img_name)
                    label = self.label_name[sub_dirs]
                    data_info.append((path_img, int(label)))
        return data_info
Batch_size = 0
train_data = RMBDataset(data_dir=train_dir, transform = train_transform)
val_data = RMBDataset(data_dir=valid_dir, transform = valid_transform)

train_loader = DataLoader(dataset=train_data,batch_size=Batch_size,shuffle=True)
val_loader = DataLoader(dataset=val_data,batch_size=Batch_size,shuffle=False)

猜你喜欢

转载自blog.csdn.net/weixin_53374931/article/details/130091654