pytorch构建自己的图像数据集

pytorch构建自己的图像数据集

数据准备

  Pytorch读取和载入数据有专门的DatasetDateloader类,但是当我们想读取自己的数据集时,Dataset类就不能用了,因此这篇博客教大家如何创建自己的数据集。在开始工作之前需要准备好自己的图像数据集,这里使用cifar10数据集为例,cifar10是一个十分类的公开数据集,拥有6w张32*32的图像,该数据集结构如下:

|-cifar10
|-----train
|---------airplane
|---------automobile
|---------…
|-----test
|---------airplane
|---------automobile
|---------…
  你可以把自己的数据集也按照这样分类,数据集准备好后,我们就可以进行下一步处理了。总体的思路是把各个图像的路径和其对应的标签构成一个列表,这样就可以利用pytorch自带的Dataloder类进行读取。

获取数据

  利用glob包可以很方便的获取我们想要的路径,首先我们需要导入相应的包:

import glob

  分别读取训练集和测试集的路径:

train_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\train\*\*.png')
test_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\test\*\*.png')#路径前的r为转义标志,否则路径可能报错

  这里大家可以print一下看看输出的是什么。路径获取后,我们就要给每一个路径赋予一个标签(0,1,2,…),首先设定一下有哪些类,注意类的名称一定要和文件夹的名称对应,cifar有10类:

species = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

  接着创建两个空列表,把标签按照路径的顺序进行排列:

train_labels = []
test_labels = []  # 用于生成标签
# 对所有图片路径进行迭代
# 为训练集添加标签
for img in train_imgs_path:
# 区分出每个img,应该属于什么类别
	for i, c in enumerate(species):
		if c in img:
			train_labels.append(i)
# 为测试集添加标签
for img in test_imgs_path:   
	for i, c in enumerate(species):
		if c in img:
			test_labels.append(i)  # 为对应的数据集增加标签

  实现的过程很简单,就是使用之前设定的类别名称在路径里搜索,从而为对应的类别赋予标签。到此我们已经拥有了训练数据和测试数据的路径及其对应的标签,下一步就是调用pytorch Dataset类了。

重写Dataset类

  这里我们需要重写pytorch自带的Dataset类,便于读取我们自己的数据,同时还需要导入相应的库,代码在最后,下面是重写后的Dataset类:

class Mydatasetpro(torch.utils.data.Dataset):
        # 初始化函数,得到数据
    def __init__(self, data_root, data_label, transform):
        self.data = data_root
        self.label = data_label
        self.transforms = transform
        # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        pil_img = Image.open(data)
        data = self.transforms(pil_img)
        return data, labels
        # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

数据载入

  重写之后,我们就可以利用pytorch自带的DataLoader进行数据的读取和载入了:

batchsz=32 #这里设置批次数量
data_train = Mydatasetpro(train_imgs_path, train_labels, transform)  # 训练数据读取
data_train_loader = DataLoader(data_train, batch_size=batchsz, shuffle=True)  # 训练数据载入,训练时数据标签要打乱
data_test = Mydatasetpro(test_imgs_path, test_labels, transform)  # 测试数据读取
data_test_loader = DataLoader(data_test, batch_size=batchsz, shuffle=False)  # 测试数据载入

  我们得到的data_train_loaderdata_test_loader 就是最终处理好的数据,可以直接输入后续的模型。

代码

  最后放一下整体的代码如下,在代码里我把相应的过程用函数data_get封装了。有问题可以留言哈!

from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import glob
from PIL import Image
import matplotlib.pyplot as plt
# 重构Dataset类
class Mydatasetpro(torch.utils.data.Dataset):
        # 初始化函数,得到数据
    def __init__(self, data_root, data_label, transform):
        self.data = data_root
        self.label = data_label
        self.transforms = transform
        # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        pil_img = Image.open(data)
        data = self.transforms(pil_img)
        return data, labels
        # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

def data_get(train_imgs_path,test_imgs_path,species,batchsz,transform):
    train_labels = []
    test_labels = []  # 用于生成标签
    # 对所有图片路径进行迭代
    for img in train_imgs_path:
        # 区分出每个img,应该属于什么类别
        for i, c in enumerate(species):
            if c in img:
                train_labels.append(i)
    for img in test_imgs_path:
        # 区分出每个img,应该属于什么类别
        for i, c in enumerate(species):
            if c in img:
                test_labels.append(i)  # 为对应的数据集增加标签
    data_train = Mydatasetpro(train_imgs_path, train_labels, transform)  # 训练数据读取
    data_train_loader = DataLoader(data_train, batch_size=batchsz, shuffle=True)  # 训练数据载入,训练时数据标签要打乱

    data_test = Mydatasetpro(test_imgs_path, test_labels, transform)  # 测试数据读取
    data_test_loader = DataLoader(data_test, batch_size=batchsz, shuffle=False)  # 测试数据载入
    return data_train_loader,data_test_loader
  
if __name__ == '__main__':
    # 训练集和测试集数据路径
    train_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\train\*\*.png')
    test_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\test\*\*.png')

    # 输入类别
    species = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    # 图像变换
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])
    batchsz = 128  # 定义批处理量
    data_train_loader, data_test_loader=data_get(train_imgs_path,test_imgs_path,species,batchsz,transform)
    imgs_batch, labels_batch = next(iter(data_train_loader))  # 迭代方法获取批数据
    print(imgs_batch.shape)

    # 测试下数据集里的图片
    plt.figure(figsize=(12, 8))
    for i, (img, label) in enumerate(zip(imgs_batch[:32], labels_batch[:8])):
        img = img.permute(1, 2, 0).numpy()
        plt.subplot(2, 4, i+1)
        plt.xlabel(species[label.numpy()])
        plt.imshow(img)
    plt.show()

猜你喜欢

转载自blog.csdn.net/weixin_44598249/article/details/128421218