Prática de aplicação 1 do PyTorch: Implementando a operação de convolução

ambiente de laboratório

python3.6 + pytorch1.8.0

import torch
print(torch.__version__)
1.8.0

0.Definição de convolução

A operação de convolução refere-se a uma operação matemática entre duas funções f e G. É amplamente utilizada em processamento de sinais, processamento de imagens, aprendizado de máquina e outros campos. No caso discreto, a operação de convolução pode ser expressa como:

( f ∗ g ) [ n ] = ∑ m = − ∞ ∞ f [ m ] g [ n − m ] (f * g)[n] = \sum_{m=-\infty}^{\infty}f[ m]g[nm]( fg ) [ n ]=m = f [ m ] g [ n-eu ]

Entre eles, fff eggg é uma função discreta,∗ * representa a operação de convolução,nnn é uma variável discreta. A operação de convolução pode ser vista como a conversão da funçãoggvamos juntonninversão do eixo n , depois translação, cada vez e função fff é multiplicado e somado e finalmente obtemos uma nova função. Esta operação pode realizar funções como filtragem de sinal e extração de recursos, e é uma operação básica muito importante no processamento de sinal digital.

1. Use operações de tensor para implementar convolução

1.1 função de desdobramento

As funções PyTorch unfoldsão usadas para expandir tensores. torch.unfold()Pode ser entendido como a operação de expansão de um tensor de alta dimensão em uma matriz bidimensional. Ou seja, o tensor original é expandido em uma matriz bidimensional ao longo das dimensões especificadas, onde a primeira dimensão corresponde à dimensão do tensor original e a segunda dimensão corresponde à posição expandida.

O protótipo da função é o seguinte:

torch.unfold(input, dimension, size, step)

Descrição do parâmetro:

  • input (Tensor) – o tensor para desembrulhar
  • dimensão (int) – A dimensão ao longo da qual expandir
  • size (int) – tamanho da janela expandida
  • step (int) – o tamanho do passo entre duas janelas adjacentes

1.2 Fragmentação de tensor

import torch
a = torch.arange(16).view(4, 4)
a
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
b = a.unfold(0, 3, 1)
b
tensor([[[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[ 4,  8, 12],
         [ 5,  9, 13],
         [ 6, 10, 14],
         [ 7, 11, 15]]])
b.shape
torch.Size([2, 4, 3])
c = b.unfold(1, 3, 1)
c
tensor([[[[ 0,  1,  2],
          [ 4,  5,  6],
          [ 8,  9, 10]],

         [[ 1,  2,  3],
          [ 5,  6,  7],
          [ 9, 10, 11]]],


        [[[ 4,  5,  6],
          [ 8,  9, 10],
          [12, 13, 14]],

         [[ 5,  6,  7],
          [ 9, 10, 11],
          [13, 14, 15]]]])
c.shape
torch.Size([2, 2, 3, 3])

programa completo

import torch
a = torch.arange(16).view(4, 4)
b = a.unfold(0, 3, 1)
c = b.unfold(1, 3, 1)
c.shape
torch.Size([2, 2, 3, 3])

Este código define três variáveis. Suponha que os nomeemos respectivamente a, be c, então:

  • A variável aé um tensor 4x4 contendo valores inteiros de 0 a 15, que é implementado por torch.arange(16)duas .view(4, 4)chamadas de função.
  • A variável é um tensor obtido bdobrando a variável . Especificamente , é obtido expandindo a variável ao longo da 0ª dimensão (ou seja, linha) e tomando um subtensor com tamanho de janela de 3 e tamanho de passo de 1. resultado. Então, se imprimirmos o tensor , obteremos:aab
tensor([[[ 0,  1,  2],
         [ 4,  5,  6],
         [ 8,  9, 10],
         [12, 13, 14]],

        [[ 1,  2,  3],
         [ 5,  6,  7],
         [ 9, 10, 11],
         [13, 14, 15]]])

Entre eles, o valor do primeiro subtensor é [[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]]e o valor do segundo subtensor é [[1, 2, 3], [5, 6, 7], [9, 10, 11], [13, 14, 15]]. Observe que a forma deste tensor é (2, 4, 3), ou seja, ele contém 2 subtensores, cada um deles com forma (4, 3).

  • A variável cé bobtida realizando uma operação semelhante na variável, mas é expandida na 1ª dimensão (ou seja, coluna) e assume um subtensor. Especificamente, é o resultado da expansão da variável bao longo da 1ª dimensão (ou seja, coluna) e da obtenção de um subtensor com tamanho de janela 3 e passo 1. Então, se imprimirmos o tensor c, obteremos:
tensor([[[[ 0,  1,  2],
          [ 4,  5,  6],
          [ 8,  9, 10]],

         [[ 1,  2,  3],
          [ 5,  6,  7],
          [ 9, 10, 11]]],


        [[[ 4,  5,  6],
          [ 8,  9, 10],
          [12, 13, 14]],

         [[ 5,  6,  7],
          [ 9, 10, 11],
          [13, 14, 15]]]])

Entre eles, o valor do primeiro subtensor é [[[0, 1, 2], [4, 5, 6], [8, 9, 10]], [[1, 2, 3], [5, 6, 7], [9, 10, 11]]]e o valor do segundo subtensor é [[[4, 5, 6], [8, 9, 10], [12, 13, 14]], [[5, 6, 7], [9, 10, 11], [13, 14, 15]]]. Observe que a forma deste tensor é (2, 2, 3, 3), ou seja, ele contém 2 subtensores, cada um deles com forma (2, 3, 3).

2. Implementar operação de convolução

2.1 Escrever função de convolução

programa completo

import torch
def conv2d(x, weight, bias, stride, pad):
    n, c, h, w = x.shape
    d, c, k, j = weight.shape
    
    x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
    x_pad[:, :, pad:-pad, pad:-pad] = x
    
    x_pad = x_pad.unfold(2, k, stride)
    x_pad = x_pad.unfold(3, j, stride)
    
    out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
    out = out + bias.view(1, -1, 1, 1)
    return out

Esta função implementa uma operação de convolução bidimensional. A função é analisada em detalhes abaixo:

  1. Parâmetros de entrada:
  • x: Tensor de entrada com dimensões (batch_size, in_channels, input_height, input_width).
  • peso: tensor de kernel de convolução com dimensões (out_channels, in_channels, kernel_height, kernel_width).
  • polarização: Tensor de termo de polarização com dimensões (out_channels,).
  • Stride: O tamanho do passo do movimento do kernel de convolução, que pode ser um número ou uma tupla de comprimento 2, indicando o tamanho do passo nas direções horizontal e vertical, respectivamente.
  • pad: O número de zeros a serem preenchidos ao redor do tensor de entrada.
  1. Preenchimento parcial:
  • Antes de realizar a operação de convolução, o tensor de entrada precisa ser preenchido de acordo com o bloco fornecido para evitar que o kernel de convolução saia do intervalo na borda do tensor.
  • Use x_pad na função para representar o tensor de entrada preenchido.
  • Implementação específica: Divida o tensor de entrada x em vários tensores com formas de (kernel_height, kernel_width)nas 2ª e 3ª dimensões (dimensões de altura e largura),respectivamente.O comprimento do salto entre cada tensor é dado por A passada é determinada e depois expandida em a 2ª e 3ª dimensões respectivamente. Desta forma, cada tensor expandido pode ser considerado como o resultado da convolução local de um núcleo de convolução bidimensional atuando em x. Esses resultados locais são emendados novamente de acordo com a 2ª e 3ª dimensões para obter um novo tensor x_pad.
  1. Operação de convolução:
  • Use a função einsum para realizar uma operação de convolução no novo tensor x_pad.

  • O primeiro parâmetro de einsum representa a regra de operação, onde ndhw representa as dimensões do tensor de saída final (batch_size, out_channels, output_height, output_width), nchw e dckj representam as dimensões dos dois tensores de entrada x_pad e peso, onde ckj respectivamente representa input_channels, kernel_height e kernel_width.

  • A forma final do tensor de saída é (batch_size, out_channels, output_height, output_width), e um termo de polarização de polarização é adicionado a cada posição.

  • torch: Biblioteca PyTorch

  • einsum: Notação de soma de Einstein, convenção de soma de Einstein, uma notação simples para soma de tensores.

  • 'nchwkj,dckj->ndhw': Notação de soma de Einstein, o tensor da esquerda é x_pade o tensor da direita é weight. No tensor à esquerda, n, c, representam respectivamente o tamanho do lote, número de canais, altura e largura h. wNo tensor à direita,,,, drepresentam respectivamente o número de canais de saída, o número de canais de entrada, a altura do kernel de convolução e a largura do kernel de convolução c. O significado desta fórmula é realizar uma operação de convolução na soma e gerar o tensor de resultado, cuja forma é .kjx_padweight(batch_size, output_channels, height, width)

  • x_pad: Tensor de entrada com forma (batch_size, input_channels, input_height, input_width).

  • weight: Tensor de kernel de convolução com forma (output_channels, input_channels, kernel_height, kernel_width).

  1. Resultados de retorno:
  • Retorna o tensor de saída obtido após a operação de convolução.

2.2 Análise de exemplos de funções de convolução escritas

# 设置测试数据
x = torch.randn(2, 3, 5, 5, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
x, weight, bias
(tensor([[[[-0.4888,  1.0257,  0.0312, -0.9026, -0.9060],
           [ 0.2071, -0.4962, -0.1658,  1.0919,  0.3785],
           [-0.4654,  1.5442,  0.6005,  0.3594, -2.6207],
           [ 0.5830,  0.0533,  0.5719,  1.5413,  0.5949],
           [-0.9152, -0.2114, -0.4888, -0.0065, -0.9767]],
 
          [[ 0.4706, -0.1108, -0.1563, -1.7946, -0.8533],
           [-0.2119,  0.3165, -2.2668, -0.8956,  1.0617],
           [-0.7809, -0.2120, -0.8592, -0.5057,  0.7954],
           [-2.8820, -0.6888,  0.4450, -0.3586, -0.9477],
           [ 0.6244,  0.4303,  1.4739,  0.2740,  1.6605]],
 
          [[-0.1501,  0.6234, -1.6086,  0.1693,  0.4932],
           [ 1.0611, -1.0938,  0.1695,  1.0193,  0.4263],
           [ 1.4681, -0.1552, -0.0667, -0.7293,  1.0816],
           [ 0.8972,  1.1683, -1.4757,  0.4421, -0.0355],
           [-2.1331,  1.4847,  0.1378, -1.6907, -0.1350]]],
 
 
         [[[-1.3853,  1.6396,  0.3436,  0.3841,  0.2355],
           [-0.2206, -0.5087, -1.6956,  1.3205,  0.7058],
           [ 0.0993,  0.3533, -0.2086,  0.2969,  0.2627],
           [ 0.3752,  0.0304,  1.2487,  1.3963, -0.0063],
           [-1.3758,  0.5088, -1.3849,  1.3050,  0.4150]],
 
          [[ 0.2824, -2.8634, -0.1016, -0.1627,  1.7081],
           [ 0.1406,  0.2220, -0.6005,  0.2997, -0.1846],
           [ 1.6700,  0.5787,  0.6561, -0.0236,  1.7743],
           [ 2.1429, -0.2838, -0.0527,  0.3504, -0.3444],
           [-0.9409, -0.4734, -0.4060, -0.5088, -1.8518]],
 
          [[-2.2152,  0.2104, -0.3302,  0.2036, -0.9443],
           [-0.6576, -0.4455,  0.5117, -2.0058, -1.3985],
           [-0.5688,  1.2338, -0.1832,  0.1760,  0.4506],
           [-0.6563,  0.4021, -1.6210,  0.5582, -0.9238],
           [-1.0506, -0.9638,  0.7453, -0.3535, -0.3536]]]], requires_grad=True),
 tensor([[[[ 0.3069,  0.2079, -0.2952],
           [ 1.7681,  1.1056, -1.0555],
           [ 1.5845,  0.8294,  0.6588]],
 
          [[ 0.2574,  0.5007,  0.2912],
           [-0.0210,  0.6593, -0.9691],
           [-0.2918,  0.5695, -1.1242]],
 
          [[ 0.7327, -0.3453,  0.7041],
           [-0.2236, -1.7762,  0.0190],
           [-1.0927, -2.9369,  0.1768]]],
 
 
         [[[-2.3830, -1.4807,  1.8573],
           [ 1.0097, -0.9640,  1.0361],
           [-0.5222, -1.0386, -0.4016]],
 
          [[ 0.5071,  1.1433, -0.1194],
           [-0.0133, -0.3878, -0.1853],
           [ 0.3456, -0.6502,  0.2221]],
 
          [[-1.7672, -0.0469, -0.5996],
           [-0.2080, -1.6209,  0.4120],
           [ 0.8404, -1.6748, -0.7170]]],
 
 
         [[[ 0.2850,  0.1691, -0.9228],
           [ 0.7234,  0.5582, -0.4327],
           [ 0.6563,  0.2941,  1.5549]],
 
          [[ 0.2642, -1.9061,  1.6212],
           [-0.5276, -0.5608,  0.3824],
           [ 0.4452, -2.5152,  0.4490]],
 
          [[-0.1276,  0.7784,  0.7998],
           [-0.3030, -0.9776,  0.9681],
           [ 1.0225,  0.8946, -0.8084]]],
 
 
         [[[-0.5087, -0.8345, -1.4763],
           [-0.4938,  1.1979, -0.1335],
           [ 0.5010,  0.2865,  0.0728]],
 
          [[-0.3177, -0.6937, -1.0327],
           [ 0.8147, -1.7101, -1.8257],
           [-0.1593, -1.3855, -0.0885]],
 
          [[-0.4687, -1.6307,  1.5791],
           [-1.3030,  0.2004, -0.7055],
           [ 0.0674, -0.8772,  0.1586]]]], requires_grad=True),
 tensor([ 1.5349, -0.5608,  0.5182,  0.3328], requires_grad=True))
n, c, h, w = x.shape
d, c, k, j = weight.shape
n, c, h, w
(2, 3, 5, 5)
d, c, k, j
(4, 3, 3, 3)
# 补零
x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
x_pad
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
x_pad.shape
torch.Size([2, 3, 9, 9])
x_pad[:, :, pad:-pad, pad:-pad] = x
x_pad
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.4888,  1.0257,  0.0312, -0.9026, -0.9060,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.2071, -0.4962, -0.1658,  1.0919,  0.3785,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.4654,  1.5442,  0.6005,  0.3594, -2.6207,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.5830,  0.0533,  0.5719,  1.5413,  0.5949,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.9152, -0.2114, -0.4888, -0.0065, -0.9767,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.4706, -0.1108, -0.1563, -1.7946, -0.8533,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.2119,  0.3165, -2.2668, -0.8956,  1.0617,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.7809, -0.2120, -0.8592, -0.5057,  0.7954,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.8820, -0.6888,  0.4450, -0.3586, -0.9477,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.6244,  0.4303,  1.4739,  0.2740,  1.6605,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.1501,  0.6234, -1.6086,  0.1693,  0.4932,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0611, -1.0938,  0.1695,  1.0193,  0.4263,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.4681, -0.1552, -0.0667, -0.7293,  1.0816,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.8972,  1.1683, -1.4757,  0.4421, -0.0355,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.1331,  1.4847,  0.1378, -1.6907, -0.1350,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.3853,  1.6396,  0.3436,  0.3841,  0.2355,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.2206, -0.5087, -1.6956,  1.3205,  0.7058,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0993,  0.3533, -0.2086,  0.2969,  0.2627,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.3752,  0.0304,  1.2487,  1.3963, -0.0063,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.3758,  0.5088, -1.3849,  1.3050,  0.4150,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.2824, -2.8634, -0.1016, -0.1627,  1.7081,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.1406,  0.2220, -0.6005,  0.2997, -0.1846,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.6700,  0.5787,  0.6561, -0.0236,  1.7743,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  2.1429, -0.2838, -0.0527,  0.3504, -0.3444,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.9409, -0.4734, -0.4060, -0.5088, -1.8518,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.2152,  0.2104, -0.3302,  0.2036, -0.9443,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.6576, -0.4455,  0.5117, -2.0058, -1.3985,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.5688,  1.2338, -0.1832,  0.1760,  0.4506,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.6563,  0.4021, -1.6210,  0.5582, -0.9238,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.0506, -0.9638,  0.7453, -0.3535, -0.3536,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]]]], grad_fn=<CopySlices>)
# 卷积
x_pad = x_pad.unfold(2, k, stride)
x_pad.shape
torch.Size([2, 3, 4, 9, 3])
x_pad = x_pad.unfold(3, j, stride)
x_pad.shape
torch.Size([2, 3, 4, 4, 3, 3])
out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
out.shape
torch.Size([2, 4, 4, 4])
bias.view(1, -1, 1, 1).shape
torch.Size([1, 4, 1, 1])
# 偏置
out = out + bias.view(1, -1, 1, 1)
out
tensor([[[[ 0.6573, -0.3444,  1.5693, -0.1906],
          [ 2.5483,  5.1142, -2.3528, -3.6162],
          [ 2.9913, -6.1289,  6.8200,  0.9229],
          [ 0.4849,  0.1813,  3.2616,  1.5637]],

         [[-0.1524, -1.2003, -0.3415,  0.0318],
          [-1.7830,  2.5286, -1.6660,  3.1253],
          [ 1.3314, -8.2623, -5.0055,  5.7671],
          [-1.0563,  5.2751, -0.4214,  2.8473]],

         [[ 0.0908,  2.6704,  1.0336,  0.0481],
          [ 0.2077,  2.0459,  1.8095, -0.7039],
          [ 0.9519, -4.5551,  3.7108,  0.7446],
          [ 0.6689,  3.9448,  2.3968,  0.6958]],

         [[ 0.2318, -0.3356,  2.4320,  0.0480],
          [ 0.2101, -1.7177,  6.3956, -0.4108],
          [ 8.2352, -5.8456, 12.9459, -0.8763],
          [-2.3292, -1.5263,  2.1349,  0.3653]]],


        [[[-0.0869,  1.0713,  0.1655,  2.4414],
          [-1.3623, -3.0759,  0.2430,  2.3259],
          [-0.9281,  2.1402,  7.1618,  4.1895],
          [ 0.9273,  1.1176,  0.7792,  0.9265]],

         [[ 1.6465, -1.7187, -0.7251, -0.8871],
          [-1.6260, -0.8628, -1.0122,  3.2737],
          [ 0.5831,  2.1665, -0.5353, -2.0468],
          [-2.3738, -0.1232, -0.0771, -1.8642]],

         [[ 0.2819,  6.0978,  2.9618,  0.4676],
          [ 1.3592,  6.7231,  3.8100,  3.6118],
          [ 0.9885, -5.7760,  5.4375,  0.5480],
          [-0.5778,  1.4657, -2.8315,  0.1923]],

         [[-0.1444,  3.6788,  0.3721,  0.1150],
          [-1.4057,  0.1613, -2.5436,  1.3156],
          [-6.1195,  1.8325,  3.1565,  0.8296],
          [ 1.6766,  6.9403,  1.3986,  0.8758]]]], grad_fn=<AddBackward0>)

2.3 Verifique a exatidão da escrita da função de convolução

import torch.nn.functional as F
x = torch.randn(2, 3, 5, 5, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
torch_out = F.conv2d(x, w, b, stride, pad)
my_out = conv2d(x, w, b, stride, pad)
torch_out == my_out
tensor([[[[ True,  True,  True,  True],
          [ True, False, False,  True],
          [ True,  True, False,  True],
          [ True,  True, False,  True]],

         [[ True,  True,  True,  True],
          [ True, False,  True,  True],
          [False,  True,  True,  True],
          [ True,  True,  True,  True]],

         [[ True, False, False,  True],
          [False, False,  True,  True],
          [ True, False, False,  True],
          [ True,  True, False,  True]],

         [[ True,  True, False,  True],
          [False, False, False,  True],
          [ True, False, False, False],
          [ True,  True,  True,  True]]],


        [[[ True,  True,  True,  True],
          [False,  True, False,  True],
          [ True,  True, False,  True],
          [ True, False,  True,  True]],

         [[ True,  True, False,  True],
          [ True, False, False, False],
          [ True, False, False, False],
          [ True, False,  True,  True]],

         [[ True,  True, False,  True],
          [False, False, False,  True],
          [ True, False, False, False],
          [ True, False,  True,  True]],

         [[ True, False,  True,  True],
          [ True, False, False,  True],
          [False, False, False, False],
          [ True,  True, False,  True]]]])
torch.allclose(torch_out, my_out, atol=1e-5)
True
  • torch.allclose é uma função usada para verificar se os valores entre dois tensores são iguais.

  • Ao usá-lo, você precisa passar o primeiro tensor como o primeiro parâmetro (ou seja, torch_out), o segundo tensor como o segundo parâmetro (ou seja, my_out) e o erro absoluto permitido (atol) Passado como o terceiro argumento (o padrão é 1e-8).

  • A função retornará um valor booleano indicando se os dois tensores possuem valores semelhantes. Se True for retornado, significa que os dois tensores possuem valores numéricos semelhantes, caso contrário significa que há uma diferença numérica entre eles.

grad_out = torch.randn(*torch_out.shape)
grad_x = torch.autograd.grad(torch_out, x, grad_out, retain_graph=True)
my_grad_x = torch.autograd.grad(my_out, x, grad_out, retain_graph=True)
torch.allclose(grad_x[0], my_grad_x[0], atol=1e-5)
True
grad_w = torch.autograd.grad(torch_out, w, grad_out, retain_graph=True)
my_grad_w = torch.autograd.grad(my_out, w, grad_out, retain_graph=True)
torch.allclose(grad_w[0], my_grad_w[0], atol=1e-5)
True
grad_b = torch.autograd.grad(torch_out, b, grad_out, retain_graph=True)
my_grad_b = torch.autograd.grad(my_out, b, grad_out, retain_graph=True)
torch.allclose(grad_b[0], my_grad_b[0], atol=1e-5)
True

Todos são verdadeiros, indicando que a função de convolução escrita tem resultados semelhantes à função Conv2d integrada do PyTorch dentro de um determinado intervalo, indicando a correção da implementação.

Anexo: série de artigos

número de série Diretório de artigos link direto
1 Prática de aplicação 1 do PyTorch: Implementando a operação de convolução https://want595.blog.csdn.net/article/details/132575530
2 Prática de aplicação 2 do PyTorch: implementação de rede neural convolucional para classificação de imagens https://want595.blog.csdn.net/article/details/132575702
3 Prática de aplicação 3 do PyTorch: Construindo uma rede neural https://want595.blog.csdn.net/article/details/132575758
4 Prática quatro do aplicativo PyTorch: construção de aplicativos complexos baseados em PyTorch https://want595.blog.csdn.net/article/details/132625270
5 Prática de aplicação cinco do PyTorch: implementação de redes neurais binárias https://want595.blog.csdn.net/article/details/132625348
6 Prática de aplicação 6 do PyTorch: Usando LSTM para implementar classificação de emoções de texto https://want595.blog.csdn.net/article/details/132625382

Acho que você gosta

Origin blog.csdn.net/m0_68111267/article/details/132575530
Recomendado
Clasificación