Explicación detallada de Dataset y DataLoader

Conjunto de datos y cargador de datos

1. Explicación oficial (Google Translate):
el código que procesa muestras de datos puede volverse desordenado y difícil de mantener; idealmente, queremos que nuestro código de conjunto de datos esté separado de nuestro código de entrenamiento modelo para una mejor legibilidad y modularidad.
PyTorch proporciona dos primitivas de datos: torch.utils.data.DataLoadery torch.utils.data.Datasetque nos permiten usar conjuntos de datos precargados, así como nuestros propios datos. El conjunto de datos almacena las muestras y sus etiquetas correspondientes , y el cargador de datos envuelve un conjunto de datos de objetos iterables para facilitar el acceso a las muestras.
2. Dataset
es una plantilla para todos los conjuntos de datos utilizados por todos los desarrolladores para entrenamiento y pruebas.
Conjunto de datos define el contenido del conjunto de datos, que es equivalente a una estructura de datos similar a una lista con una cierta longitud, y puede usar el índice para obtener los elementos del conjunto de datos.
DataLoader define un método para cargar conjuntos de datos por lotes. Es un objeto iterable que implementa el método __iter__ y genera un lote de datos en cada iteración.
3. DataLoader
DataLoader puede controlar el tamaño del lote, el método de muestreo de elementos en el lote y el método de clasificación de los resultados del lote en el formulario de entrada requerido por el modelo, y puede usar múltiples procesos para leer datos.
En la mayoría de los casos, solo necesitamos implementar el método __len__ y el método __getitem__ de Dataset, puede crear fácilmente su propio conjunto de datos y cargarlo con la canalización de datos predeterminada.

1. Conjunto de datos personalizado

La clase Dataset personalizada debe heredar la clase DataSet oficial de pytorch y también debe implementar tres funciones: __init__ , __len__ y __getitem__ .
init: inicialización (generalmente necesita pasar la ruta del archivo del conjunto de datos , en qué ruta guardar el archivo , función de preprocesamiento )
len: devuelve el tamaño del conjunto de datos
getitem: devuelve las características y etiquetas de la muestra de acuerdo con el índice .

import os.path

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class MyImageDataset(Dataset):
    def __init__(self, annotations_file, data_dir, transform=None, target_transform=None):
        # annotations_file:文件路径
        # data_dir: 将文件保存到哪个路径
        self.data_label = pd.read_csv(annotations_file)
        self.data_dir = data_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # 返回数据集总的大小
        return len(self.data_label)

    def __getitem__(self, item):
        data_name = os.path.join(self.data_dir, self.data_label.iloc[item, 0])
        image = read_image(data_name)
        # 对特征进行预处理
        label = self.data_label.iloc[item, 1]
        if self.transform:
            image = self.transform(image)
        # 对标签进行预处理
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

De hecho, solo necesitamos modificar los cuatro parámetros de annotations_file, data_dir, transform (preprocesamiento de características), target_transform (preprocesamiento de etiquetas) .
El conjunto de datos procesa solo una muestra a la vez y devuelve una característica y la etiqueta correspondiente a la característica

2. Use DataLoaders para preparar datos para el entrenamiento

Recupere las características del conjunto de datos de nuestro conjunto de datos y etiquete una muestra a la vez. Cuando entrenamos un modelo, normalmente queremos pasar muestras en " mini lotes ", reorganizar en cada época (cuántas veces por iteración) para reducir el sobreajuste del modelo y usar el multiprocesamiento de Python para acelerar la recuperación de datos.

batch_size: el número de muestras seleccionadas para un entrenamiento
shuffle = True: ordena aleatoriamente los datos después de cada ciclo de entrenamiento

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

3. Datos iterativos

Hemos cargado este conjunto de datos en el DataLoader y podemos iterar sobre el conjunto de datos según sea necesario . Cada iteración a continuación devuelve un lote de train_features y train_labels (con batch_size=64 características y etiquetas, respectivamente).
El método iter() obtiene un iterador.
El método next() obtiene las funciones y las etiquetas a su vez.

train_features, train_labels = next(iter(train_dataloader))

Supongo que te gusta

Origin blog.csdn.net/weixin_51799151/article/details/123968508
Recomendado
Clasificación