03_PyTorch model training [Reading a dataset with the Dataset class]

PyTorch reads images mainly through the Dataset class, so let's first take a brief look at it. Dataset
is the base class of all datasets; every custom dataset must inherit from it, much like an abstract base
class in C++.

 

Here we focus on the __getitem__ method: it receives an index and returns the image data and its label. The
index usually refers to a position in a list, where each element holds the path and label information of one
image.
How is this list built? The usual approach is to store the image paths and labels in a txt file,
and then read them back from that txt file.
The basic process of reading your own data is therefore:
1. Create a txt file that stores the path and label of each image
2. Convert this information into a list, where each element corresponds to one sample
3. Read the data and label via the __getitem__ method and return them
In short, getting PyTorch to read your own dataset takes only two steps:
1. Build an index of the image data
2. Write a Dataset subclass
1. Code to generate the txt files
import os

'''
    Generate the corresponding txt file for each dataset split
'''
base_dir = "E:/pytorch_learning"  # change this to the absolute path of the directory containing Data
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
train_dir = os.path.join(base_dir, "Data", "train")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
valid_dir = os.path.join(base_dir, "Data", "valid")
print(train_txt_path)
print(train_dir)
print(valid_txt_path)
print(valid_dir)
def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')

    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # get the sub-folder names under img_dir
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # absolute path of each class folder
            img_list = os.listdir(i_dir)                    # all files inside the class folder
            for img_name in img_list:
                if not img_name.endswith('png'):            # skip anything that is not a png file
                    continue
                label = img_name.split('_')[0]              # label is the filename prefix before '_'
                img_path = os.path.join(i_dir, img_name)
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)

2. Result
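To make the result concrete, here is a sketch (with a hypothetical path) of what one line of the generated txt looks like and how it will later be parsed back:

```python
# One line of the generated txt is "<image path> <label>", space-separated.
# The path below is hypothetical; the label is the filename prefix before '_'.
sample_line = "E:/pytorch_learning/Data/train/0/0_12.png 0\n"
path, label = sample_line.rstrip().split()
print(path, label)
```

Note that this space-separated format assumes the image paths themselves contain no spaces; a path with a space would break the later `split()` step.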

 

3. Dataset class code

from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()  # rstrip with no argument strips trailing whitespace (including the newline)
            words = line.split()  # split on whitespace by default
            imgs.append((words[0], int(words[1])))
        fh.close()

        self.imgs = imgs        # the key step is building this list; DataLoader supplies an index, and __getitem__ reads the image data
        # transform is a Compose object holding a list of image-processing operations,
        # e.g. mean subtraction, division by the standard deviation, random cropping,
        # rotation, flipping, affine transforms, and so on.
        # So once an image is read in, it goes through this processing (data augmentation)
        # and finally becomes the input to the model. Note that PyTorch's data augmentation
        # processes the original image rather than generating an extra copy on disk:
        # the augmented result "replaces" the original for that sample.
        self.target_transform = target_transform
        self.transform = transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        # read the image
        img = Image.open(fn).convert('RGB')     # pixel values 0~255; transforms.ToTensor divides by 255, mapping them to 0~1

        if self.transform is not None:
            img = self.transform(img)   # apply the transform here, e.g. conversion to a tensor

        return img, label

    def __len__(self):
        return len(self.imgs)
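The core of MyDataset is the `self.imgs` list built in `__init__`. Its parsing and indexing logic can be checked in isolation, without torch or any image files (the paths below are hypothetical):

```python
# Simulate two lines read from the txt file (hypothetical paths).
lines = ["E:/pytorch_learning/Data/train/0/0_12.png 0",
         "E:/pytorch_learning/Data/train/1/1_7.png 1"]

# Same parsing as MyDataset.__init__: one (path, int(label)) tuple per sample.
imgs = []
for line in lines:
    words = line.rstrip().split()
    imgs.append((words[0], int(words[1])))

# __getitem__(index) simply looks up this list before opening the image.
fn, label = imgs[1]
print(fn, label)
```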

4. DataLoader

Once MyDataset is built, the remaining work is handed over to DataLoader. DataLoader triggers
MyDataset's __getitem__ method to read the data and label of one image at a time, then stitches the samples
into a batch, which becomes the actual input to the model.
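What DataLoader does here can be mimicked in plain Python: draw indices, fetch each sample as `__getitem__` would, and collate the samples into a batch. This is only a sketch of the idea (with stand-in samples and no shuffling); in real code you would use `torch.utils.data.DataLoader`.

```python
# Stand-in dataset: six (sample, label) pairs, as __getitem__ would return them.
dataset = [("img_%d" % i, i % 2) for i in range(6)]

batch_size = 4
batches = []
# Walk the index range in fixed-size chunks, fetch each sample,
# then collate data and labels into parallel tuples.
for start in range(0, len(dataset), batch_size):
    samples = [dataset[i] for i in range(start, min(start + batch_size, len(dataset)))]
    imgs, labels = zip(*samples)
    batches.append((imgs, labels))

print(batches[0][1])   # labels of the first batch
```

A real DataLoader additionally stacks tensors along a new batch dimension via its collate function, and can shuffle the indices and load samples in parallel with worker processes.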

 

Origin blog.csdn.net/zhang2362167998/article/details/128807544