2. PyTorch中数据的读取 - Dataset

Dataset demo

    1. 前置基础知识
      os库的使用
        import os
        dir_path = "dataset/train/ants"
        img_path_list = os.listdir(dir_path)
        #此时img_path_list为文件列表
        img_path_list[0]
    
        >>> 'xxx.jpg'
    
    1. 本文使用的数据集样式
      数据集格式

其中 i m a g e image 文件夹中为jpg格式图片, l a b e l label 文件夹中为txt文本。
一般设置 r o o t _ d i r root\_dir 为所有数据集的根目录, 设置 t r a i n i m g train_img 为数据相对根目录的地址,设置 l a b e l _ d i r label\_dir 为标签相对根目录的地址(名字随意,自己能区分就好)。
使用 o s . p a t h . j o i n ( r o o t _ d i r , l a b e l _ d i r ) os.path.join(root\_dir, label\_dir) 函数将他们拼接起来。

    1. demo
from torch.utils.data import Dataset
import os
from PIL import Image


class MyData(Dataset):

    # 数据类初始化
    def __init__(self, root_dir, train_img_dir_name, label_dir):
        # 定义数据主文件夹地址
        self.root_dir = root_dir
        # 定义标签完整地址(根据不同数据做调整)
        self.label_dir = os.path.join(self.root_dir, label_dir)
        # 定义图片文件夹路径名
        self.train_img_dir_name = train_img_dir_name
        # 地址拼接
        self.path = os.path.join(self.root_dir, self.train_img_dir_name)
        # 数据列表
        self.img_path = os.listdir(self.path)
        self.img_label = os.listdir(self.label_dir)

    # 得到单个数据(图片)
    def __getitem__(self, idx):
        # 从列表中得到单一数据的name
        img_name = self.img_path[idx]
        # 获取标签
        img_label = self.img_label[idx]
        # 得到数据地址
        img_item_path = os.path.join(self.root_dir, self.train_img_dir_name, img_name)
        # 使用PIL中的Image库打开图片
        img = Image.open(img_item_path)
        # 获取标签,从对应的txt文件中读取标签名
        img_item_label = os.path.join(self.label_dir, img_label)
        self.label = ""
        with open(img_item_label, "r") as f:
            self.label = f.read()
        return img, self.label

    # 得到数据大小
    def __len__(self):
        return len(self.img_path)


# 创建一个实例
root_dir = 'dataset/train/'
ants_train_img = 'ants_image'
ants_label_dir = "ants_label"
ants_dataset = MyData(root_dir, ants_train_img, ants_label_dir)

# 创建第二个实例
bees_train_img = 'bees_image'
bees_label_dir = 'bees_label'
bees_dataset = MyData(root_dir, bees_train_img, bees_label_dir)

# 测试是否得到蚂蚁图片
# img, label = ants_dataset[0]
# img.show()
# print(label)
# print(len(ants_dataset))

# 测试是否得到蜜蜂图片
# img, label = bees_dataset[0]
# img.show()
# print(label)
# print(len(bees_dataset))

# 定义完整训练数据集
train_dataset = ants_dataset + bees_dataset
print(len(train_dataset))

# 测试
img, label = train_dataset[124]
img.show()
print(label)
print(len(train_dataset))

发布了13 篇原创文章 · 获赞 0 · 访问量 91

猜你喜欢

转载自blog.csdn.net/qq_35283167/article/details/104545815
今日推荐