Directorio de artículos
Conjunto de datos y cargador de datos
1. Explicación oficial (Google Translate):
el código que procesa muestras de datos puede volverse desordenado y difícil de mantener; idealmente, queremos que nuestro código de conjunto de datos esté separado de nuestro código de entrenamiento modelo para una mejor legibilidad y modularidad.
PyTorch proporciona dos primitivas de datos: torch.utils.data.DataLoader
y torch.utils.data.Dataset
que nos permiten usar conjuntos de datos precargados, así como nuestros propios datos. El conjunto de datos almacena las muestras y sus etiquetas correspondientes , y el cargador de datos envuelve un conjunto de datos de objetos iterables para facilitar el acceso a las muestras.
2. Dataset
es una plantilla para todos los conjuntos de datos utilizados por todos los desarrolladores para entrenamiento y pruebas.
Conjunto de datos define el contenido del conjunto de datos, que es equivalente a una estructura de datos similar a una lista con una cierta longitud, y puede usar el índice para obtener los elementos del conjunto de datos.
DataLoader define un método para cargar conjuntos de datos por lotes. Es un objeto iterable que implementa el método __iter__ y genera un lote de datos en cada iteración.
3. DataLoader
DataLoader puede controlar el tamaño del lote, el método de muestreo de elementos en el lote y el método de clasificación de los resultados del lote en el formulario de entrada requerido por el modelo, y puede usar múltiples procesos para leer datos.
En la mayoría de los casos, solo necesitamos implementar el método __len__ y el método __getitem__ de Dataset, puede crear fácilmente su propio conjunto de datos y cargarlo con la canalización de datos predeterminada.
1. Conjunto de datos personalizado
La clase Dataset personalizada debe heredar la clase DataSet oficial de pytorch y también debe implementar tres funciones: __init__ , __len__ y __getitem__ .
init: inicialización (generalmente necesita pasar la ruta del archivo del conjunto de datos , en qué ruta guardar el archivo , función de preprocesamiento )
len: devuelve el tamaño del conjunto de datos
getitem: devuelve las características y etiquetas de la muestra de acuerdo con el índice .
import os.path
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
class MyImageDataset(Dataset):
def __init__(self, annotations_file, data_dir, transform=None, target_transform=None):
# annotations_file:文件路径
# data_dir: 将文件保存到哪个路径
self.data_label = pd.read_csv(annotations_file)
self.data_dir = data_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
# 返回数据集总的大小
return len(self.data_label)
def __getitem__(self, item):
data_name = os.path.join(self.data_dir, self.data_label.iloc[item, 0])
image = read_image(data_name)
# 对特征进行预处理
label = self.data_label.iloc[item, 1]
if self.transform:
image = self.transform(image)
# 对标签进行预处理
if self.target_transform:
label = self.target_transform(label)
return image, label
De hecho, solo necesitamos modificar los cuatro parámetros de annotations_file, data_dir, transform (preprocesamiento de características), target_transform (preprocesamiento de etiquetas) .
El conjunto de datos procesa solo una muestra a la vez y devuelve una característica y la etiqueta correspondiente a la característica
2. Use DataLoaders para preparar datos para el entrenamiento
Recupere las características del conjunto de datos de nuestro conjunto de datos y etiquete una muestra a la vez. Cuando entrenamos un modelo, normalmente queremos pasar muestras en " mini lotes ", reorganizar en cada época (cuántas veces por iteración) para reducir el sobreajuste del modelo y usar el multiprocesamiento de Python para acelerar la recuperación de datos.
batch_size: el número de muestras seleccionadas para un entrenamiento
shuffle = True: ordena aleatoriamente los datos después de cada ciclo de entrenamiento
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
3. Datos iterativos
Hemos cargado este conjunto de datos en el DataLoader y podemos iterar sobre el conjunto de datos según sea necesario . Cada iteración a continuación devuelve un lote de train_features y train_labels (con batch_size=64 características y etiquetas, respectivamente).
El método iter() obtiene un iterador.
El método next() obtiene las funciones y las etiquetas a su vez.
train_features, train_labels = next(iter(train_dataloader))