Notas de estudio de PyTorch (17) -- introducción al uso de torchvision.transforms

Notas de estudio de PyTorch (17): introducción al uso de torchvision.transforms

    Esta publicación de blog son las notas de estudio de PyTorch, el registro de contenido número 17, que registra principalmente el uso de torchvision.transforms.

1. El origen del problema

    Al leer el código de la aplicación de ResNet, encontré el siguiente pequeño fragmento de código. Este código aparece antes de leer la información de la imagen. ¿Cuál es la función específica de este código? Es necesario que los principiantes descubran el significado específico de este código.

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

2. El uso específico de torchvision.transforms

    Hay un paquete muy importante y útil en el marco PyTorch: torchvision, que se compone principalmente de tres subpaquetes, a saber: torchvision.datasets, torchvision.models, torchvision.transforms. El código anterior usa torchvision.transformseste paquete.

    La biblioteca de herramientas de torchvision que se usa aquí es un paquete de procesamiento de imágenes de uso común en el marco de pytorch, que se puede usar para generar conjuntos de datos de imágenes y videos (torchvision.datasets), hacer un preprocesamiento de imágenes (torchvision.transforms) e importar modelos pre-entrenados. (torchvision. models), y generar gráficos y guardar imágenes (torchvision.utils).
    Entre ellas, la función transforms para preprocesar la imagen puede ser: 归一化(normalize), 尺寸剪裁(resize), 翻转(flip)etc.
    Los pasos anteriores suelen ser una serie de operaciones reales. En este momento, se puede usar Compose para conectar estas operaciones de preprocesamiento de imágenes.
    Como en el código anterior, la operación aquí es:
    transforms.ToTensor() , que convierte una imagen PIL en un tensor. Es decir, ( H ∗ W ∗ C ) (H\ast W\ast C)( HWC ) La imagen PIL en el rango [0,255] se convierte en( C ∗ H ∗ W ) (C\ast H\ast W)( CHW ) antorcha.tensor en el rango [0,1].
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), normaliza la imagen con media [0.485, 0.456, 0.406] y desviación estándar [0.229, 0.224, 0.225].

3. Otros usos de torchvision.transforms

    Las características adicionales de la función de transformación incluyen:

    Redimensionar: Redimensionar la imagen dada al tamaño dado.

    ToPILImage: Convierte torch.tensor a imagen PIL.

    CenterCrop: utilice el punto central de la imagen de entrada como centro para realizar la operación de recorte del tamaño especificado.

    RandomCrop: la operación de recorte del tamaño especificado se realiza alrededor de la posición aleatoria de la imagen de entrada.

    RandomHorizontalFlip: voltea la imagen PIL dada horizontalmente con una probabilidad de 0,5.

    RandomVerticalFlip: voltea la imagen PIL dada verticalmente con una probabilidad de 0.5.

    RandomResizedCrop: recorta aleatoriamente la imagen dada a diferentes tamaños y relaciones de aspecto, y luego escala la imagen recortada al tamaño especificado (con un parámetro n).

    Escala de grises: convierte una imagen determinada en una imagen en escala de grises.

    RandomGrayscale: convierte una imagen en una imagen en escala de grises con una probabilidad especificada.

    FiveCrop: Recorta 5 imágenes de un tamaño específico a partir de una imagen de entrada, incluidas 4 imágenes de esquina y un centro.

    TenCrop: Recorta 10 imágenes del tamaño especificado. El método es voltear la imagen de entrada horizontal o verticalmente sobre la base de FiveCrop y luego realizar la operación FiveCrop, de modo que una imagen pueda obtener 10 imágenes recortadas.

    Relleno: Rellena los píxeles de "relleno" en todos los lados de la imagen dada con el valor de "relleno".

    ColorJitter: modifica el brillo, el contraste, la saturación y el tono de una imagen.

    Lambda: Realiza la transformación especificada por sus parámetros.

    Para obtener una introducción detallada de los cuatro paquetes anteriores y sus funciones específicas, consulte la documentación en chino de Pytorch .

    La implementación del código puede referirse a la implementación del código de github .

4. Complementar otras funciones del módulo torchvision

    torchvision es una biblioteca de herramientas para la manipulación de imágenes independiente de PyTorch que actualmente incluye seis módulos:

    1) torchvision.datasets: varios conjuntos de datos visuales de uso común, que se pueden descargar y cargar, y cómo escribir su propio conjunto de datos.

     2) torchvision.models: modelos clásicos, como AlexNet, VGG, ResNet, etc., y parámetros entrenados.

     3) torchvision.transforms: operaciones de imagen de uso común, como corte aleatorio, rotación, conversión de tipo de datos, tensor y numpy e intercambio de imágenes PIL, etc.

     4) torchvision.ops: proporciona algunas operaciones de uso común en CV, como NMS, ROI_Align, ROI_Pool, etc.

     5) torchvision.io: proporciona algunas operaciones para entrada y salida, actualmente para escritura y escritura de video.

     6) torchvision.utils: Otras herramientas, como generar una grilla de imágenes, etc.

5. Ejecutar la resolución de errores

    Pregunta 1: el conjunto de datos es una imagen en color, la cantidad de canales es 3, pero la cantidad de canales de entrada en el modelo es 1, es decir, se recibe la imagen gris.En este momento, se informará un error al entrenar el modelo El error específico es:

RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 4, 416, 416] to have 3 channels

    Para resolver el problema del número de canales de entrada, es decir, para modificar la imagen de color de 3 canales en una imagen gris de 1 canal, el método de modificación en este momento es:

修改前:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
修改后:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.Compose([
                                              torchvision.transforms.Grayscale(),
                                              torchvision.transforms.ToTensor()]),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.Compose([
                                             torchvision.transforms.Grayscale(),
                                             torchvision.transforms.ToTensor()]),
                                         download=True)

    Es decir, torchvision.transforms.Grayscale()la operación de sumar uno.
    Pregunta 1:

Supongo que te gusta

Origin blog.csdn.net/weixin_43981621/article/details/121695174
Recomendado
Clasificación