Pytorch学习(五) --- torchvision.ImageFolder()的用法

在使用pytorch做深度学习任务的数据加载时,常用的方式是使用torchvision.Dataset类定义数据读取,然后使用torch.utils.data.DataLoader定义数据加载器。该部分内容见Pytorch学习(一)

不过,有些分类数据的文件目录组织形式如下:
在这里插入图片描述
即默认你的数据集已经自觉按照要分配的类型分成了不同的文件夹,一种类型的文件夹下面只存放一种类型的图片。
这时候,定义数据读取时,使用 torchvision包中的ImageFolder类会比Dataset类会更方便。

ImageFolder

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

一个通用数据加载器,其中图像以这种方式排列:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

参数:
root (string) – 指定图片存储的路径

transform (callable, optional) – 一个transform函数,接受PIL.Image图像并返回一个转换后的图片

target_transform (callable, optional) –一个函数,输入为target,输出对其的转换。

loader (callable, optional) – A function to load an image given its path.

is_valid_file – 该函数获取图像文件的路径并检查该文件是否为有效文件

成员变量:

  • self.classes - 用一个list保存 类名

  • self.class_to_idx - 类名对应的 索引

    扫描二维码关注公众号,回复: 11194800 查看本文章
  • self.imgs - 保存(img-path, class) tuple的list。

return self.imgs即相当于Dataset类中的return (img, target)

例子:


# 指定读取的图片路径
train_root = './train/
# transform函数组合
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224,scale=(0.6,1.0),ratio=(0.8,1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])


# 使用ImageFolder读取数据
all_data =  torchvision.datasets.ImageFolder(
        root=train_root,
        transform=train_transform
    )


# 定义数据加载器
train_set = torch.utils.data.DataLoader(
    all_data,
    batch_size=BTACH_SIZE,
    shuffle=True
)

参考
https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder
https://www.jb51.net/article/180916.htm

原创文章 96 获赞 24 访问量 3万+

猜你喜欢

转载自blog.csdn.net/c2250645962/article/details/105291782