使用pytorch加载数据集和对数据集进行处理

目录

1.torchvision中加载数据集

2.重写Dataset类加载数据集

3.transforms

4.Dataloader对数据进一步处理


1.torchvision中加载数据集

官方文档给出的数据

下面以CIFAR数据集为例子:

torchvision.datasets.CIFAR10(root: str, train: bool = True,
 transform: Optional[Callable] = None, target_transform: 
Optional[Callable] = None, download: bool = False)
  • root:表示数据集的路径
  • train:表示是否为训练集,为True表示为训练数据集,否则为测试集。
  • transform:表示对数据集进行转换,下面已经对该功能进行了说明。
  • target_transform:表示对target进行数据转换。
  • download:是否下载,如果为True的话,表示从网上进行下载该数据集,否则从已有的文件目录下面获取。
import os
import torch
import numpy as np
from PIL import Image
from torchvision import datasets,transforms

#数据集的预处理
transform=transforms.Compose([
    transforms.ToTensor()
])

#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
train_data=datasets.CIFAR10(root=root,train=True,transform=transform,download=True)
test_data=datasets.CIFAR10(root=root,train=False,transform=transform,download=True)

print('trainSize: {}'.format(len(train_data)))
print('testSize: {}'.format(len(test_data)))

#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
test_data=datasets.CIFAR10(root=root,train=False,download=True)

#显示图片的类别
print('图片包含类别: {}'.format(train_data.classes))

#显示图片
imgOne,target=test_data[0]
imgOne.show()
#查看图片所属类别
print('class: {}'.format(test_data.classes[target]))

 

2.重写Dataset类加载数据集

官方文档torch.utls.data.Dataset

 

import os
import pathlib

from PIL import Image
from torch.utils.data import Dataset

class myDataset(Dataset):
    def __init__(self,img_path):
        self.data_dir=pathlib.Path(img_path)
        self.dataset=list(self.data_dir.glob('*/*.jpg'))

    #根据索引index获取数据,index是根据数据集的顺序来获取的
    def __getitem__(self, index):
        img=self.dataset[index]
        imgTo=Image.open(img)
        return imgTo
    
    #获取数据集的大小
    def __len__(self):
        #统计flower_photos文件夹下面所有的图片数据集数量
        self.len=len(list(self.data_dir.glob('*/*jpg')))
        return self.len

if __name__ == '__main__':
    mydataset=myDataset(img_path=r'E:\myDataset\flower_photos')
    print('dataSize: {}'.format(len(mydataset)))
    
    #获取第1张图片数据
    img=mydataset[0]
    print('imgsize: {}'.format(img.size))
    #显示图片
    img.show('img')

 

3.transforms

打开transforms.py文件可以看到其中包含的对数据的处理方法:(关于这些方法要使用的时候都可以直接查询)

torchvision官网查看功能:

https://pytorch.org/vision/stable/transforms.html
 

transforms.compose使用

#compose中包含一个数组,数组中包含的是对图片数据集进行处理的过程
#比如下面,首先对一张图片进行中心的裁剪,其次将PIL数据类型转换为Tensor数据类型,
#最后是将数据转换为浮点类型
transf=transforms.Compose([
    transforms.CenterCrop(10),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float)
])
print('imgSize: {}'.format(img.size))
print(type(transf(img)))


如果读者不使用.compose的话,也可以使用下面的方法一步一步的进行数据转化:

import os
import torch
from PIL import Image
from torchvision import transforms,datasets

img_path="myDataset/flower_photos/daisy/5547758_eea9edfd54_n.jpg"
#读取图片数据
img=Image.open(img_path)
#显示图片大小
print('imgSize: {}'.format(img.size))
#第一步:对图片数据进行中心裁剪
centerCut=transforms.CenterCrop(100)
img_cut=centerCut(img)
img_cut.show('img_cut')
print('imgCutSize: {}'.format(img_cut.size))

#第二步:将图片数据集转换为Tensor
ToTensor=transforms.ToTensor()
img_ToTensor=ToTensor(img_cut)
print(type(img_ToTensor))

#第三步:将Tensor数据转换为浮点类型
FloatData=transforms.ConvertImageDtype(dtype=torch.float)
img_Float=FloatData(img_ToTensor)
print(type(img_Float))

关于上面一些自己比较常用的一些方法 :但是读者应该注意的是,在Compose中使用这些方法时,对数据的处理先后顺序注意,因为有些方法要传入的是Tensor数据类型,所以将数据转换为Tensor类型方法可能得放在其他方法的前面,注意报错的问题所在。

transform=transforms.Compose([
    transforms.Resize(size=[224,224]),
    transforms.CenterCrop(100),
    transforms.ToTensor(),
    #output[channel] = (input[channel] - mean[channel]) / std[channel],由于图片是三通道的,所以平均值和方差都是分别给出三个值
    transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]),
    #p表示水平翻转的概率
    transforms.RandomHorizontalFlip(p=0.5),
    #垂直翻转
    transforms.RandomVerticalFlip(p=0.5),
    #随机旋转,degrees表示旋转度数,center表示旋转中心坐标,还有其他的参数可以自行选择
    transforms.RandomRotation(degrees=45,center=[50,50],)
])

4.Dataloader对数据进一步处理

官网解释:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

以下给出的是一些常见设置参数:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, num_workers=0)
  • dataset:自己的数据集。
  • batch_size:每一次加载的数据量。
  • shuffle:是否随机打散;当为True时,随机打散,否则默认。
  • num_workers:加载数据所使用的进程数,如果为0,表示默认使用主进程。
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms,datasets

#数据集的预处理
transform=transforms.Compose([
    transforms.ToTensor()
])

#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
train_data=datasets.CIFAR10(root=root,train=True,transform=transform,download=True)
test_data=datasets.CIFAR10(root=root,train=False,transform=transform,download=True)
#获取图片类别
classes=train_data.classes

#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

for data in train_loader:
    imgs,targets=data
    print('imgs: {}'.format(imgs.shape))
    print('target: {}'.format(targets))
    #打印前四张打包的图片类别
    for stop,i in enumerate(targets):
        print('target[{}]---->{}'.format(i,classes[i]))

猜你喜欢

转载自blog.csdn.net/Keep_Trying_Go/article/details/128482455