pytorch加载自己的数据集

通用数据加载器

官方给出的,可以不局限于给定的数据集,加载自己的数据集。

CLASS torchvision.datasets.DatasetFolder(
		root: str, 
		loader: Callable[[str], Any], 
		extensions: Union[Tuple[str, ...], NoneType] = None, 
		transform: Union[Callable, NoneType] = None, 
		target_transform: Union[Callable, NoneType] = None, 
		is_valid_file: Union[Callable[[str], bool], NoneType] = None
		)None
		

参数含义:

  • root(string)–根目录路径。
  • loader(callable)–在给定路径的情况下加载样本的函数。
  • extensions(tuple [string])–允许的扩展名列表。 扩展名和is_valid_file不应同时传递。
  • transform (callable, optional)–接收样本并返回转换版本的函数/转换。 例如对图像进行transforms.RandomCrop。
  • target_transform(callable, optional)–接收目标并对其进行转换的函数/转换。
  • is_valid_file –接受文件路径并检查文件是否为有效文件(用于检查损坏的文件)的函数,不应同时传递扩展名和is_valid_file。

文件夹组织:

  • 应以如下结构组织文件
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/[...]/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/[...]/asd932_.ext

通用图像数据加载器

官方给出的,可以不局限于给定的图像数据集,加载自己的图像数据集。

CLASS torchvision.datasets.ImageFolder(
		root: str, 
		transform: Union[Callable, NoneType] = None, 
		target_transform: Union[Callable, NoneType] = None, 
		loader: Callable[[str], Any] = <function default_loader>, 
		is_valid_file: Union[Callable[[str], bool], NoneType] = None
		)
		

参数含义:

  • root(string)–根目录路径。
  • transform (callable, optional)–接收PIL图像并返回转换版本的函数/转换。 例如对图像进行transforms.RandomCrop。
  • target_transform(callable, optional)–接收目标并对其进行转换的函数/转换。
  • loader(callable)–在给定路径的情况下加载样本的函数。
  • is_valid_file –接受文件路径并检查文件是否为有效文件(用于检查损坏的文件)的函数。

文件夹组织:

  • 应以如下结构组织文件
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png

附:以上类的源代码

自己定义的数据加载器

  1. 常规思路
# 读取文件位置
def get_path('path-str'):
    ...
    return file_path

# 读取图片
def loader_img(file_path):
    # 根据图片的位置读取图片并返回读取的图片和标签
    # 对图片进行处理
    ...
    return imgs_list, label_list

# 获取batchsize大小的数据
def get_train_data(imgs_list,label_list,batchsize):
    ...
    return img[1],img[2],...,img[batchsize]
    

常规思路无法将加载出来的数据集使用pytorch的DataLoader加载,无法以batch的形式去训练,故可以按照pytorch中的Dataset类,写一个自己的类。

  1. 模仿pytorch中的类,定义自己的类
class MyDataset(torch.utils.data.Dataset): # 需要继承torch.utils.data.Dataset
    def __init__(self):
        # 初始化文件路径或文件名列表。
        # 初始化该类的一些基本参数。
        pass
        
    def __getitem__(self, index):
        # TODO1.从文件中读取一个数据(例如,plt.imread)。
        #2.预处理数据(例如torchvision.Transform)。
        #3.返回数据对(例如图像和标签)。
        # 这里需要注意的是,第一步:read one data,是一个data
        pass
        
    def __len__(self):
        # 返回数据集的总大小。
        

这种方法中,标签信息一般在文件名中,可以使用函数将文件名的标签信息存储起来。

  1. 一种更加具有鲁棒性的类的实现

gen_txt函数

# coding:utf-8
import os
'''
    为数据集生成对应的txt文件
'''

train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")

valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")


def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):         # 若不是png文件,跳过
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()


if __name__ == '__main__':
    gen_txt(train_txt_path, train_dir)
    gen_txt(valid_txt_path, valid_dir)

MyDataset类

扫描二维码关注公众号,回复: 13451997 查看本文章
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)
        

方法就是先使用gen_txt函数生成数据集的txt文档,在MyDataset类中使用,生成DataLoader可以直接加载的数据集。

代码来源:Pytorch模型训练实用教程

  1. 标签为图像的情况

这个问题困扰了我很长时间,很感谢好心人,直接附上链接:Pytorch 构建自己的数据集 输入与标签皆为图片

猜你喜欢

转载自blog.csdn.net/qq_45510888/article/details/115508305