Explicación detallada de Dataset y DataLoader en torch.utils.data de Pytorch

En el proceso de nuestro aprendizaje profundo, es inevitable utilizar conjuntos de datos, entonces, ¿cómo se cargan los conjuntos de datos en nuestro modelo para el entrenamiento? En el pasado, la mayoría de nuestros principiantes deben haber utilizado el código directamente en Internet, pero aún no está claro cuál es el principio subyacente. Así que hoy haré un análisis detallado de la función Conjunto de datos incorporada y la función Conjunto de datos personalizada.

prefacio

torch.utils.dataes PyTorchun módulo proporcionado para procesar y cargar datos. Este módulo proporciona un conjunto de clases y funciones de utilidad para crear, manipular y cargar conjuntos de datos de forma masiva.

A continuación se muestran torch.utils.dataalgunas clases y funciones de uso común en el módulo:

  • Dataset: Define una clase de conjunto de datos abstracto y los usuarios pueden crear sus propios conjuntos de datos heredándolos de esta clase. DatasetLa clase proporciona dos métodos que deben implementarse: __getitem__para acceder a muestras individuales y __len__para devolver el tamaño del conjunto de datos.
  • TensorDataset: Heredado de la Datasetclase, utilizado para empaquetar datos tensoriales en un conjunto de datos. Toma varios tensores como entrada y determina el tamaño del conjunto de datos de acuerdo con el tamaño del primer tensor de entrada.
  • DataLoader: Clase de cargador de datos, utilizada para cargar conjuntos de datos por lotes. Acepta un objeto de conjunto de datos como entrada y proporciona varias funciones de preprocesamiento y carga de datos, como configurar el tamaño del lote, carga de datos multiproceso y mezcla de datos, etc.
  • Subset: la clase de subconjunto del conjunto de datos, que se utiliza para seleccionar las muestras especificadas del conjunto de datos.
  • random_split: divide aleatoriamente un conjunto de datos en varios subconjuntos; puede especificar la proporción de división o el tamaño de cada subconjunto.
  • ConcatDataset: une varios conjuntos de datos para formar un conjunto de datos más grande.
  • get_worker_info: Obtenga la información del proceso del cargador de datos actual.

Además de las clases y funciones anteriores, torch.utils.datatambién se proporcionan algunas herramientas de preprocesamiento de datos de uso común, como recorte aleatorio, rotación aleatoria, estandarización, etc.

A través de torch.utils.datalas clases y funciones proporcionadas por el módulo, puede cargar, procesar y cargar datos por lotes fácilmente, lo que facilita el entrenamiento y la verificación del modelo. Sin embargo, las dos clases que utilizamos con más frecuencia son Datasetlas clases y DataLoader.

1. Clase de conjunto de datos personalizado

torch.utils.data.DatasetEs una clase abstracta que se utiliza para representar conjuntos de datos en PyTorch y se utiliza para definir el método de acceso y el número de muestras de conjuntos de datos.

La clase Dataset es una clase base, podemos crear una clase de conjunto de datos personalizada heredando esta clase e implementando los dos métodos siguientes:

getitem (self, index): según el índice dado, devuelve los datos de muestra correspondientes. El índice puede ser un número entero, lo que significa obtener muestras en orden, o pueden ser otros métodos, como obtener muestras por nombre de archivo, etc.
len (self): devuelve el número de muestras en el conjunto de datos.

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # 根据索引获取样本
        return self.data[index]

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 根据索引获取样本
sample = dataset[2]
print(sample)
# 3

El ejemplo de código anterior implementa principalmente un 自定义Dataset数据集类método, que generalmente se define cuando necesitamos entrenar nuestros propios datos. Pero en general, como principiantes en el aprendizaje profundo, usamos MNIST, CIFAR-10 内置数据集, etc. En este momento, no necesitamos definir la clase Dataset nosotros mismos. En cuanto a por qué, lo explicaremos en detalle a continuación.

2、torchvision.conjuntos de datos

Si desea utilizar los conjuntos de datos integrados en PyTorch, normalmente torchvision.datasetslo hace a través de módulos. torchvision.datasetsEl módulo proporciona muchos conjuntos de datos de visión por computadora de uso común, como MNIST, CIFAR10, ImageNet, etc.

A continuación se muestra un código de muestra que utiliza conjuntos de datos integrados:

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

En el código anterior, lo que hemos implementado es la carga y el uso de un conjunto de datos MNIST (dígitos escritos a mano) integrado. Como puede ver, no utilizamos las clases mencionadas anteriormente aquí torch.utils.data.Dataset, ¿por qué?

Esto se debe a que en torchvision.datasetsel módulo, la clase de conjunto de datos incorporada ya implementa torch.utils.data.Datasetla interfaz y devuelve directamente un objeto de conjunto de datos utilizable. Por lo tanto, cuando utilizamos el conjunto de datos integrado, podemos crear una instancia directa de la clase del conjunto de datos integrado sin heredar explícitamente la torch.utils.data.Datasetclase.

Las implementaciones de clases de conjuntos de datos integrados, como , torchvision.datasets.MNISTya contienen definiciones __getitem__y __len__métodos, que nos permiten obtener muestras y determinar el tamaño del conjunto de datos directamente desde el objeto del conjunto de datos integrado. De esta manera, cuando utilizamos el conjunto de datos integrado, podemos pasar directamente el objeto del conjunto de datos integrado torch.utils.data.DataLoaderpara la carga de datos y el procesamiento por lotes.

Detrás de los conjuntos de datos integrados, todavía se torch.utils.data.Datasetimplementan en función de clases. Solo por conveniencia y para proporcionar más funciones, PyTorch encapsula estos conjuntos de datos de uso común en clases de conjuntos de datos integrados.

Con este fin, fui al sitio web oficial de Pytorch para verificar el código de carga del conjunto de datos incorporado, como se muestra en la siguiente figura:
inserte la descripción de la imagen aquí
Se puede ver que la clase de conjunto de datos Dataset está incorporada.

3 、 Cargador de datos

torch.utils.data.DataLoaderEs una clase de herramienta para cargar datos por lotes en PyTorch. Acepta un objeto de conjunto de datos (como torch.utils.data.Datasetuna subclase de) y proporciona varias funciones, como carga de datos, procesamiento por lotes, mezcla de datos, etc.

Los siguientes son torch.utils.data.DataLoaderparámetros y funciones de uso común de:

  • dataset: Objeto de conjunto de datos, que puede ser torch.utils.data.Datasetun objeto de subclase de .
  • batch_size: número de muestras por lote, el valor predeterminado es 1.
  • shuffle: Ya sea para mezclar los datos, el valor predeterminado es False. Los datos se mezclan en cada época.
  • num_workers: Cuántos procesos secundarios se utilizan para cargar datos. El valor predeterminado es 0, lo que significa cargar datos en el proceso principal. De hecho, está configurado en 0 en el sistema Windows, pero se puede configurar en un número mayor que 0 en Linux.
  • collate_fn: Función para procesar cada muestra antes de devolver los datos del lote. En caso afirmativo None, utilice torch.utils.data._utils.collate.default_collatela función de procesamiento de forma predeterminada.
  • drop_last: Si se descartan los datos cuyo último tamaño de muestra es inferior a un lote, el valor predeterminado es False.
  • pin_memory: Ya sea para almacenar los datos cargados en la memoria fija correspondiente a CUDA, el valor predeterminado es False.
  • prefetch_factor: Factor de captación previa, utilizado para captar previamente datos en el dispositivo; el valor predeterminado es 2.
  • persistent_workers: si es verdadero True, utilice un subproceso persistente para cargar datos en cada época; el valor predeterminado es False.

El código de muestra es el siguiente:

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# 使用数据加载器迭代样本
for images, labels in train_loader:
    # 训练模型的代码
    ...

4、torchvision.transforms

torchvision.transformsEl módulo es un módulo funcional para el preprocesamiento de datos de imágenes en PyTorch. Proporciona una serie de funciones de transformación para realizar diversas operaciones comunes de transformación y aumento de datos al cargar, entrenar o inferir datos de imágenes. A continuación se ofrecen explicaciones detalladas de algunas funciones de conversión utilizadas habitualmente:

  1. Cambiar tamaño: cambiar el tamaño de la imagen

    • Resize(size): Cambia el tamaño de la imagen a las dimensiones dadas. Puede aceptar un número entero como tamaño del lado más corto, o una tupla o lista como tamaño objetivo de la imagen.
  2. ToTensor: convierte una imagen en un tensor

    • ToTensor(): Convierte una imagen en un tensor, asignando valores de píxeles que van de 0-255 a 0-1. Adecuado para pasar datos de imágenes a modelos de aprendizaje profundo.
  3. Normalizar: normalizar los datos de la imagen

    • Normalize(mean, std): normaliza los datos de la imagen. La media y el estándar pasados ​​son la media y la desviación estándar para la normalización del valor de los píxeles. Cabe señalar que la media y el estándar deben corresponder al conjunto de datos utilizado anteriormente.
  4. RandomHorizontalFlip: imagen volteada horizontal aleatoriamente

    • RandomHorizontalFlip(p=0.5): voltea aleatoriamente la imagen horizontalmente con una probabilidad determinada. La probabilidad p controla la probabilidad de invertir y el valor predeterminado es 0,5.
  5. RandomCrop: recorta una imagen aleatoriamente

    • RandomCrop(size, padding=None): recorta aleatoriamente una imagen a un tamaño determinado. Se puede proporcionar una tupla o un número entero como tamaño objetivo y, opcionalmente, valor de relleno.
  6. ColorJitter: fluctuación del color

    • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0): ajusta aleatoriamente el brillo, el contraste, la saturación y el tono de la imagen. La apariencia de la imagen se puede ajustar configurando diferentes parámetros.

Cuando lo usamos, a menudo usamos transforms.Composepara combinar estas operaciones de procesamiento de datos. Cuando lo usamos, simplemente llamamos a la combinación directamente.

El código de muestra es el siguiente:

from torchvision import transforms

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)
    transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])

# 对图像进行预处理
image = transform(image)

5. Definición de clase de conjunto de datos Dataset en clasificación de imágenes

Tome el conjunto de datos de enfermedades oculares como ejemplo (para obtener más detalles, consulte el caso básico de la práctica de aprendizaje profundo: Reconocimiento de enfermedades oculares de la red neuronal convolucional (CNN) basado en SqueezeNet | Ejemplo 1 ), en el que generamos el tren después de etiquetar el conjunto de datos. y archivos valid.txt, hay dos columnas en este archivo, la primera columna es la ruta del conjunto de datos y la segunda columna es la etiqueta (es decir, la categoría) del conjunto de datos, de la siguiente manera: En este momento ,
inserte la descripción de la imagen aquí
podemos definir nuestra propia clase de lectura de conjunto de datos, el código específico es el siguiente:

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

transform_BZ = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5]
)


class MyDataset(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

        self.train_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.RandomHorizontalFlip(),  # 随机左右翻转图像
            transforms.RandomVerticalFlip(),  # 随机上下翻转图像
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])
        self.val_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])

    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))
        return imgs_info

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]

        img_path = os.path.join('', img_path)
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs_info)

Después de definir nuestra propia clase de lectura del conjunto de datos, podemos pasar nuestro archivo txt para preprocesar y leer el conjunto de datos. En nuestra clase de conjunto de datos personalizado, los tres métodos más importantes son __init__(), getitem () y __len__(), los cuales son indispensables. Al mismo tiempo, la operación de mejora de datos mediante transformaciones no es necesaria. Esta es solo una forma de mejorar el rendimiento del modelo, pero nuestro proceso de entrenamiento del modelo actual generalmente agrega operaciones de mejora de datos .

# 加载训练集和验证集
train_data = MyDataset(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)
test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)

Arriba, cargamos nuestro archivo train.txt y nuestro archivo valid.txt respectivamente a través de la clase MyDataset personalizada (el siguiente parámetro True significa que queremos mejorar los datos del conjunto de entrenamiento, mientras que False significa mejorar los datos del conjunto de verificación) . Luego, usamos nuestro DataLoader para cargar por lotes el conjunto de datos y luego podemos arrojar directamente los datos cargados train_dl al test_dlmodelo para entrenar.


Ejemplos específicos pueden referirse a:

Supongo que te gusta

Origin blog.csdn.net/m0_63007797/article/details/132385283
Recomendado
Clasificación