Notas de estudo do PyTorch (17) -- introdução ao uso do archvision.transforms

PyTorch study notes (17) – introdução ao uso do archvision.transforms

    Esta postagem de blog é as notas de estudo do PyTorch, o 17º registro de conteúdo, que registra principalmente o uso de archivion.transforms.

1. A origem do problema

    Ao ler o código do aplicativo ResNet, encontrei o seguinte pequeno trecho de código. Esse código aparece antes de ler as informações da imagem. Qual é a função específica desse código? É necessário que os iniciantes descubram o significado específico deste código

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

2. O uso específico de archivision.transforms

    Existe um pacote muito importante e útil no framework PyTorch: o archivision, que é composto principalmente de três subpacotes, a saber: archivision.datasets, archivision.models, archivision.transforms. O código acima usa torchvision.transformseste pacote.

    A biblioteca de ferramentas do archvision usada aqui é um pacote de processamento de imagem comumente usado na estrutura pytorch, que pode ser usado para gerar conjuntos de dados de imagem e vídeo (torchvision.datasets), fazer algum pré-processamento de imagem (torchvision.transforms) e importar modelos pré-treinados (torchvision. models) e geração de gráficos e salvamento de imagens (torchvision.utils).
    Dentre elas, a função transforms para pré-processar a imagem pode ser: 归一化(normalize), 尺寸剪裁(resize), 翻转(flip)etc.
    As etapas acima geralmente são uma série de operações reais. Neste momento, a composição pode ser usada para conectar essas operações de pré-processamento de imagem.
    Assim como no código acima, a operação aqui é:
    transforms.ToTensor() , que converte uma imagem PIL em um tensor. Ou seja, ( H ∗ W ∗ C ) (H\ast W\ast C)( HCC ) A imagem PIL no intervalo [0,255] é convertida para( C ∗ H ∗ W ) (C\ast H\ast W)( CHW ) tocha.tensor no intervalo [0,1].
    transforms.Normalize([0,485, 0,456, 0,406], [0,229, 0,224, 0,225]) , normaliza a imagem com média [0,485, 0,456, 0,406] e desvio padrão [0,229, 0,224, 0,225].

3. Outros usos do archvision.transforms

    Os recursos adicionais da função de transformação incluem:

    Redimensionar: Redimensione a imagem fornecida para o tamanho especificado.

    ToPILImage: Converte arch.tensor em imagem PIL.

    CenterCrop: Use o ponto central da imagem de entrada como o centro para executar a operação de corte do tamanho especificado.

    RandomCrop: A operação de recorte do tamanho especificado é realizada em torno da posição aleatória da imagem de entrada.

    RandomHorizontalFlip: Inverte a imagem PIL fornecida horizontalmente com 0,5 de probabilidade.

    RandomVerticalFlip: Inverte a imagem PIL fornecida verticalmente com uma probabilidade de 0,5.

    RandomResizedCrop: corta aleatoriamente a imagem fornecida em diferentes tamanhos e proporções e, em seguida, dimensiona a imagem cortada para o tamanho especificado (com um parâmetro n).

    Tons de Cinza: Converte uma determinada imagem em uma imagem em tons de cinza.

    RandomGrayscale: Converte uma imagem em uma imagem em tons de cinza com uma probabilidade especificada.

    FiveCrop: Corte 5 imagens de um tamanho especificado de uma imagem de entrada, incluindo 4 imagens de canto e um centro.

    TenCrop: Recorte 10 imagens do tamanho especificado. O método é inverter a imagem de entrada horizontal ou verticalmente com base no FiveCrop e, em seguida, executar a operação FiveCrop, de modo que uma imagem possa obter 10 imagens cortadas.

    Pad: Preenche os pixels de "preenchimento" em todos os lados da imagem fornecida com o valor de "preenchimento".

    ColorJitter: Modifique o brilho, contraste, saturação e matiz de uma imagem.

    Lambda: Faz a transformação especificada por seus parâmetros.

    Para uma introdução detalhada dos quatro pacotes acima e suas funções específicas, consulte a documentação chinesa do Pytorch .

    A implementação do código pode se referir à implementação do código do github .

4. Complemente outras funções do módulo archvision

    O archvision é uma biblioteca de ferramentas para manipulação de imagens independente do PyTorch. Atualmente, inclui seis módulos:

    1) archvision.datasets: Vários conjuntos de dados visuais comumente usados, que podem ser baixados e carregados, e como escrever seu próprio conjunto de dados.

     2) Torchvision.models: modelos clássicos, como AlexNet, VGG, ResNet, etc., e parâmetros treinados.

     3) Torchvision.transforms: operações de imagem comumente usadas, como corte aleatório, rotação, conversão de tipo de dados, tensor e numpy e troca de imagem PIL, etc.

     4) archvision.ops: Fornece algumas operações comumente usadas em CV, como NMS, ROI_Align, ROI_Pool, etc.

     5) archvision.io: Fornece algumas operações para entrada e saída, atualmente para gravação e gravação de vídeo.

     6) archivision.utils: Outras ferramentas, como gerar uma grade de imagem, etc.

5. Solução de erro de execução

    Questão 1: O conjunto de dados é uma imagem colorida e o número de canais é 3, mas o número de canais de entrada no modelo é 1, ou seja, ele recebe uma imagem cinza. Neste momento, um erro será relatado ao treinando o modelo. O erro específico é:

RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 4, 416, 416] to have 3 channels

    Para resolver o problema do número de canais de entrada, ou seja, para modificar a imagem colorida de 3 canais em uma imagem cinza de 1 canal, o método de modificação neste momento é:

修改前:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
修改后:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.Compose([
                                              torchvision.transforms.Grayscale(),
                                              torchvision.transforms.ToTensor()]),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.Compose([
                                             torchvision.transforms.Grayscale(),
                                             torchvision.transforms.ToTensor()]),
                                         download=True)

    Ou seja, torchvision.transforms.Grayscale()a operação de adicionar um.
    Questão 1:

Acho que você gosta

Origin blog.csdn.net/weixin_43981621/article/details/121695174
Recomendado
Clasificación