PyTorch study notes (17) – introdução ao uso do archvision.transforms
Esta postagem de blog é as notas de estudo do PyTorch, o 17º registro de conteúdo, que registra principalmente o uso de archivion.transforms.
Índice
1. A origem do problema
Ao ler o código do aplicativo ResNet, encontrei o seguinte pequeno trecho de código. Esse código aparece antes de ler as informações da imagem. Qual é a função específica desse código? É necessário que os iniciantes descubram o significado específico deste código
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
2. O uso específico de archivision.transforms
Existe um pacote muito importante e útil no framework PyTorch: o archivision, que é composto principalmente de três subpacotes, a saber: archivision.datasets, archivision.models, archivision.transforms. O código acima usa torchvision.transforms
este pacote.
A biblioteca de ferramentas do archvision usada aqui é um pacote de processamento de imagem comumente usado na estrutura pytorch, que pode ser usado para gerar conjuntos de dados de imagem e vídeo (torchvision.datasets), fazer algum pré-processamento de imagem (torchvision.transforms) e importar modelos pré-treinados (torchvision. models) e geração de gráficos e salvamento de imagens (torchvision.utils).
Dentre elas, a função transforms para pré-processar a imagem pode ser: 归一化(normalize)
, 尺寸剪裁(resize)
, 翻转(flip)
etc.
As etapas acima geralmente são uma série de operações reais. Neste momento, a composição pode ser usada para conectar essas operações de pré-processamento de imagem.
Assim como no código acima, a operação aqui é:
transforms.ToTensor() , que converte uma imagem PIL em um tensor. Ou seja, ( H ∗ W ∗ C ) (H\ast W\ast C)( H∗C∗C ) A imagem PIL no intervalo [0,255] é convertida para( C ∗ H ∗ W ) (C\ast H\ast W)( C∗H∗W ) tocha.tensor no intervalo [0,1].
transforms.Normalize([0,485, 0,456, 0,406], [0,229, 0,224, 0,225]) , normaliza a imagem com média [0,485, 0,456, 0,406] e desvio padrão [0,229, 0,224, 0,225].
3. Outros usos do archvision.transforms
Os recursos adicionais da função de transformação incluem:
Redimensionar: Redimensione a imagem fornecida para o tamanho especificado.
ToPILImage: Converte arch.tensor em imagem PIL.
CenterCrop: Use o ponto central da imagem de entrada como o centro para executar a operação de corte do tamanho especificado.
RandomCrop: A operação de recorte do tamanho especificado é realizada em torno da posição aleatória da imagem de entrada.
RandomHorizontalFlip: Inverte a imagem PIL fornecida horizontalmente com 0,5 de probabilidade.
RandomVerticalFlip: Inverte a imagem PIL fornecida verticalmente com uma probabilidade de 0,5.
RandomResizedCrop: corta aleatoriamente a imagem fornecida em diferentes tamanhos e proporções e, em seguida, dimensiona a imagem cortada para o tamanho especificado (com um parâmetro n).
Tons de Cinza: Converte uma determinada imagem em uma imagem em tons de cinza.
RandomGrayscale: Converte uma imagem em uma imagem em tons de cinza com uma probabilidade especificada.
FiveCrop: Corte 5 imagens de um tamanho especificado de uma imagem de entrada, incluindo 4 imagens de canto e um centro.
TenCrop: Recorte 10 imagens do tamanho especificado. O método é inverter a imagem de entrada horizontal ou verticalmente com base no FiveCrop e, em seguida, executar a operação FiveCrop, de modo que uma imagem possa obter 10 imagens cortadas.
Pad: Preenche os pixels de "preenchimento" em todos os lados da imagem fornecida com o valor de "preenchimento".
ColorJitter: Modifique o brilho, contraste, saturação e matiz de uma imagem.
Lambda: Faz a transformação especificada por seus parâmetros.
Para uma introdução detalhada dos quatro pacotes acima e suas funções específicas, consulte a documentação chinesa do Pytorch .
A implementação do código pode se referir à implementação do código do github .
4. Complemente outras funções do módulo archvision
O archvision é uma biblioteca de ferramentas para manipulação de imagens independente do PyTorch. Atualmente, inclui seis módulos:
1) archvision.datasets: Vários conjuntos de dados visuais comumente usados, que podem ser baixados e carregados, e como escrever seu próprio conjunto de dados.
2) Torchvision.models: modelos clássicos, como AlexNet, VGG, ResNet, etc., e parâmetros treinados.
3) Torchvision.transforms: operações de imagem comumente usadas, como corte aleatório, rotação, conversão de tipo de dados, tensor e numpy e troca de imagem PIL, etc.
4) archvision.ops: Fornece algumas operações comumente usadas em CV, como NMS, ROI_Align, ROI_Pool, etc.
5) archvision.io: Fornece algumas operações para entrada e saída, atualmente para gravação e gravação de vídeo.
6) archivision.utils: Outras ferramentas, como gerar uma grade de imagem, etc.
5. Solução de erro de execução
Questão 1: O conjunto de dados é uma imagem colorida e o número de canais é 3, mas o número de canais de entrada no modelo é 1, ou seja, ele recebe uma imagem cinza. Neste momento, um erro será relatado ao treinando o modelo. O erro específico é:
RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 4, 416, 416] to have 3 channels
Para resolver o problema do número de canais de entrada, ou seja, para modificar a imagem colorida de 3 canais em uma imagem cinza de 1 canal, o método de modificação neste momento é:
修改前:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
transform=torchvision.transforms.torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
修改后:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()]),
download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()]),
download=True)
Ou seja, torchvision.transforms.Grayscale()
a operação de adicionar um.
Questão 1: