03_Entrenamiento del modelo PyTorch [La clase de conjunto de datos lee el conjunto de datos]

PyTorch lee imágenes principalmente a través de la clase Dataset , así que primero entendamos brevemente la clase Dataset . conjunto de datos
La clase existe como la clase base de todos los conjuntos de datos , y todos los conjuntos de datos deben heredarla, de forma similar a la base virtual en C++
amable.

 

Aquí nos enfocamos en la función getitem , getitem recibe un índice y luego devuelve los datos de la imagen y las etiquetas, esto
índice generalmente se refiere al índice de una lista , y cada elemento de esta lista contiene la ruta y la información de la etiqueta de los datos de la imagen
aliento.
Sin embargo, cómo hacer esta lista , el método habitual es almacenar la ruta y etiquetar la información de la imagen en un txt
, y luego lea de ese txt .
Entonces el proceso básico de leer sus propios datos es:
1. Cree un txt que almacene la ruta y la información de la etiqueta de la imagen
2. Convierta esta información en una lista , y cada elemento de la lista corresponde a una muestra
3. A través de la función getitem , lea los datos y las etiquetas y devuelva los datos y las etiquetas
Por lo tanto, para permitir que PyTorch lea su propio conjunto de datos, solo se requieren dos pasos:
1. Haz un índice de datos de imágenes
2. Cree una subclase de conjunto de datos
1. Generar código de Bloc de notas
import os
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
'''
    为数据集生成对应的txt文件
'''
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
train_dir = os.path.join(base_dir, "Data", "train")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
valid_dir = os.path.join(base_dir, "Data", "valid")
print(train_txt_path)
print(train_dir)
print(valid_txt_path)
print(valid_dir)
def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):         # 若不是png文件,跳过
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)

2. Efecto

 

 3. Código de clase del conjunto de datos

class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r') 
        imgs = []
        for line in fh:
            line = line.rstrip() #rstrip函数返回字符串副本,该副本是从字符串最右边删除了参数指定字符后的字符串,不带参数进去则是去除最右边的空格
            words = line.split() #默认以空格为分隔符
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        # transform 是一个 Compose 类型,里边有一个 list,list 中就会定义了各种对图像进行处理的操作,
        #可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作
        #在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),
        #最终变成输入模型的数据。这里就有一点需要注意,PyTorch 的数据增强是将原始图片进行了处理
        #并不会生成新的一份图片,而是“覆盖”原图
        self.target_transform = target_transform
        self.transform = transform 

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        #对图片进行读取
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)

4.carga de datos

Cuando se crea Mydataset , las operaciones restantes se transferirán a DataLoder . En DataLoder , se activará
La función getiterm en Mydataset lee los datos y las etiquetas de una imagen y los une en un lote para devolverlos, como
La entrada real del modelo.

 

Supongo que te gusta

Origin blog.csdn.net/zhang2362167998/article/details/128807544
Recomendado
Clasificación