Pytorch学习笔记一:制作、加载自己的图像数据集

Pytorch学习笔记一:制作、加载自己的图像数据集



前言

首先介绍如何用pytorch加载网络现有数据集,然后介绍如何制作自己的图像数据集并批量读取来训练自己的网络。


提示:以下是本篇文章正文内容,下面案例可供参考

一、下载数据集

使用Pytorch进行读取本地的MINIST数据集并进行装载

# 训练数据和测试数据的下载
trainDataset = torchvision.datasets.MNIST( # torchvision可以实现数据集的训练集和测试集的下载
  root="./data", # 下载数据,并且存放在data文件夹中
  train=True, # train用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分;如果设置为False,则说明载入的是该数据集的测试集部分。
  transform=transforms.ToTensor(), # 数据的标准化等操作都在transforms中,此处是转换
  download=True 
)
 
testDataset = torchvision.datasets.MNIST(
  root="./data",
  train=False,
  transform=transforms.ToTensor(),
  download=True
)

二、加载自己的数据集

1.制作数据集

训练神经网络需要标准输入图像和它的真值标签。
在分类问题中,比如猫、狗、船、车等等,我们可以用数字代表不同的分类。可以制作一个txt文档用于存放输入图像的地址和它对应的标签数字。
我现在有个任务需要以图像作为输入,以另一张处理过后的图像作为它的真值,所以我在txt文本下面写的是它们的路径。在项目路径下新建了一个train文件夹用于放训练图片,并在train文件夹下新建一个训练的txt用于标注训练图像和标签图像
在这里插入图片描述

2.加载数据集

Dataset类

PyTorch读取图片,主要是通过Dataset类是Pytorch中所有数据集加载类中应该继承的父类。我们通过继承改写Dataset类来读取自己的图像数据集。其中以下三个函数必须改写:
__init__方法里面进行读取数据文件

__getitem__方法进行支持下标访问

__len__方法返回自定义数据集的大小,方便后期遍历

class OpticalSARDataset(Data.Dataset):
    """
      定义自己的数据集、读取数据、初始化数据
    """

    def __init__(self, data_dir, part):
        # 所有图片的绝对路径
        assert part in ["train", "val"]
        self.image_dir = os.path.join(data_dir, part)
        self.img_names = []
        self.label_names = []

        with open(os.path.join(data_dir, part, "label.txt")) as f:
            while True:
                il = f.readline(1500)  # 如果样本数据名称大于1500,修改该值
                if not il:
                    break
                a = il.split(sep=' ')
                self.img_names.append(a[0])
                self.label_names.append(a[1][0:-1])  # remove '\n'
        self.samples_num = len(self.img_names)
        # print(self.samples_num)

        self.transform = torchvision.transforms.Compose([
            # 将 PIL 图片转换成位于[0.0, 1.0]的floatTensor, shape (C x H x W)
            torchvision.transforms.ToTensor()])

    def __len__(self):
        # 返回图像的数量
        return self.samples_num

    def __getitem__(self, index):
        tp_img = Image.open(os.path.join(self.image_dir,  self.img_names[index])
                                 ).convert('RGB')
        tp_label = Image.open(os.path.join(self.image_dir, self.label_names[index])
                               ).convert('RGB')
        # PIL.Image.open 读取的图片数据是RGB格式;
        tp_img = cv2.cvtColor(np.asarray(tp_img), cv2.COLOR_RGB2BGR)
        tp_label = cv2.cvtColor(np.asarray(tp_label), cv2.COLOR_RGB2BGR) # 转换为BGR便于cv2.imshow,跟下面imshow之前RGB2BGR只用一种方法,这里统一为cv2的BGR格式
        img = self.transform(tp_img)
        label = self.transform(tp_label)


        sample = {
    
    
            "label": label,  # shape
            "image": img  # shape: (3, *image_size)
        }


        return sample

定义数据集

# 利用之前创建好的OpticalSARDataset类去创建数据对象
train_dataset = OpticalSARDataset(data_dir, 'train')  # 训练数据集

Dataloader类

之前所说的Dataset类是读入数据集数据并且对读入的数据进行了索引。
但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:
可以分批次读取:batch-size
可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
可以并行加载数据(利用多核处理器加快载入数据的效率)
Dataloader这个类并不需要我们自己设计代码,我们只需要利用DataLoader类读取我们设计好的类。

实例化数据集

# 利用dataloader读取我们的数据对象,并设定batch-size和工作现场
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=0)
batch = iter(train_iter).next()
print(batch["image"].shape, batch["label"].shape)
print(batch["image"][0].shape)

总结

参考博客:
定义自己的数据集
pytorch加载自己数据集
设计自己的数据
训练自己数据完整步骤
Dataset类



猜你喜欢

转载自blog.csdn.net/qq_43173239/article/details/108948228