Pratique d'application PyTorch 1 : implémentation de l'opération de convolution

environnement de laboratoire

python3.6 + pytorch1.8.0

import torch
print(torch.__version__)
1.8.0

0.Définition de convolution

L'opération de convolution fait référence à une opération mathématique entre deux fonctions f et g. Elle est largement utilisée dans le traitement du signal, le traitement d'images, l'apprentissage automatique et d'autres domaines. Dans le cas discret, l’opération de convolution peut s’exprimer comme suit :

( f ∗ g ) [ n ] = ∑ m = − ∞ ∞ f [ m ] g [ n − m ] (f * g)[n] = \sum_{m=-\infty}^{\infty}f[ m]g[nm]( f*g ) [ n ]=m = f [ m ] g [ nm ]

Parmi eux, fff etggg est une fonction discrète,∗ * représente l'opération de convolution,nnn est une variable discrète. L'opération de convolution peut être considérée comme la conversion de la fonctionggje suis avec toinnRetournement de l'axe n , puis translation, à chaque fois et fonction fff est multiplié et additionné, et finalement nous obtenons une nouvelle fonction. Cette opération peut réaliser des fonctions telles que le filtrage du signal et l'extraction de caractéristiques, et constitue une opération de base très importante dans le traitement du signal numérique.

1. Utiliser des opérations tensorielles pour implémenter la convolution

1.1 fonction de dépliage

Les fonctions PyTorch unfoldsont utilisées pour développer les tenseurs. torch.unfold()Cela peut être compris comme l’opération consistant à étendre un tenseur de grande dimension en une matrice bidimensionnelle. Autrement dit, le tenseur d'origine est développé en une matrice bidimensionnelle le long des dimensions spécifiées, où la première dimension correspond à la dimension du tenseur d'origine et la deuxième dimension correspond à la position développée.

Le prototype de la fonction est le suivant :

torch.unfold(input, dimension, size, step)

Description du paramètre :

  • input (Tensor) – le tenseur à déballer
  • dimension (int) – La dimension le long de laquelle développer
  • size (int) – taille de la fenêtre développée
  • step (int) – la taille du pas entre deux fenêtres adjacentes

1.2 Partage de tenseur

import torch
a = torch.arange(16).view(4, 4)
a
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
b = a.unfold(0, 3, 1)
b
tensor([[[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[ 4,  8, 12],
         [ 5,  9, 13],
         [ 6, 10, 14],
         [ 7, 11, 15]]])
b.shape
torch.Size([2, 4, 3])
c = b.unfold(1, 3, 1)
c
tensor([[[[ 0,  1,  2],
          [ 4,  5,  6],
          [ 8,  9, 10]],

         [[ 1,  2,  3],
          [ 5,  6,  7],
          [ 9, 10, 11]]],


        [[[ 4,  5,  6],
          [ 8,  9, 10],
          [12, 13, 14]],

         [[ 5,  6,  7],
          [ 9, 10, 11],
          [13, 14, 15]]]])
c.shape
torch.Size([2, 2, 3, 3])

programme complet

import torch
a = torch.arange(16).view(4, 4)
b = a.unfold(0, 3, 1)
c = b.unfold(1, 3, 1)
c.shape
torch.Size([2, 2, 3, 3])

Ce code définit trois variables. Supposons que nous les nommions respectivement a, bet c, alors :

  • La variable aest un tenseur 4x4 contenant des valeurs entières de 0 à 15, qui est implémenté par torch.arange(16)et .view(4, 4)deux appels de fonction.
  • La variable est un tenseur obtenu ben pliant la variable . Plus précisément, il est obtenu en développant la variable le long de la 0ème dimension (c'est-à-dire la ligne) et en prenant un sous-tenseur avec une taille de fenêtre de 3 et un pas de 1. résultat. Donc si on imprime le tenseur , on obtient :aab
tensor([[[ 0,  1,  2],
         [ 4,  5,  6],
         [ 8,  9, 10],
         [12, 13, 14]],

        [[ 1,  2,  3],
         [ 5,  6,  7],
         [ 9, 10, 11],
         [13, 14, 15]]])

Parmi eux, la valeur du premier sous-tenseur est [[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]], et la valeur du deuxième sous-tenseur est [[1, 2, 3], [5, 6, 7], [9, 10, 11], [13, 14, 15]]. Notez que la forme de ce tenseur est (2, 4, 3), c'est-à-dire qu'il contient 2 sous-tenseurs dont chacun a une forme (4, 3).

  • La variable cest bobtenue en effectuant une opération similaire sur la variable, mais elle est développée sur la 1ère dimension (c'est-à-dire la colonne) et prend un sous-tenseur. Plus précisément, c'est le résultat de l'expansion de la variable ble long de la 1ère dimension (c'est-à-dire la colonne) et de la prise d'un sous-tenseur avec une taille de fenêtre de 3 et une foulée de 1. Donc si on imprime le tenseur c, on obtient :
tensor([[[[ 0,  1,  2],
          [ 4,  5,  6],
          [ 8,  9, 10]],

         [[ 1,  2,  3],
          [ 5,  6,  7],
          [ 9, 10, 11]]],


        [[[ 4,  5,  6],
          [ 8,  9, 10],
          [12, 13, 14]],

         [[ 5,  6,  7],
          [ 9, 10, 11],
          [13, 14, 15]]]])

Parmi eux, la valeur du premier sous-tenseur est [[[0, 1, 2], [4, 5, 6], [8, 9, 10]], [[1, 2, 3], [5, 6, 7], [9, 10, 11]]], et la valeur du deuxième sous-tenseur est [[[4, 5, 6], [8, 9, 10], [12, 13, 14]], [[5, 6, 7], [9, 10, 11], [13, 14, 15]]]. Notez que la forme de ce tenseur est (2, 2, 3, 3), c'est-à-dire qu'il contient 2 sous-tenseurs dont chacun a une forme (2, 3, 3).

2. Implémenter l'opération de convolution

2.1 Écrire la fonction de convolution

programme complet

import torch
def conv2d(x, weight, bias, stride, pad):
    n, c, h, w = x.shape
    d, c, k, j = weight.shape
    
    x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
    x_pad[:, :, pad:-pad, pad:-pad] = x
    
    x_pad = x_pad.unfold(2, k, stride)
    x_pad = x_pad.unfold(3, j, stride)
    
    out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
    out = out + bias.view(1, -1, 1, 1)
    return out

Cette fonction implémente une opération de convolution bidimensionnelle. La fonction est analysée en détail ci-dessous :

  1. Paramètres d'entrée:
  • x : Tenseur d'entrée avec dimensions (batch_size, in_channels, input_height, input_width).
  • poids : tenseur du noyau de convolution avec des dimensions (out_channels, in_channels, kernel_height, kernel_width).
  • biais : tenseur de terme de biais avec des dimensions (out_channels,).
  • Stride : la taille du pas du mouvement du noyau de convolution, qui peut être un nombre ou un tuple de longueur 2, indiquant la taille du pas dans les directions horizontale et verticale respectivement.
  • pad : Le nombre de zéros à compléter autour du tenseur d'entrée.
  1. Remplissage partiel :
  • Avant d'effectuer l'opération de convolution, le tenseur d'entrée doit être complété en fonction du pad donné pour éviter que le noyau de convolution ne sorte de la plage au bord du tenseur.
  • Utilisez x_pad dans la fonction pour représenter le tenseur d'entrée rembourré.
  • Implémentation spécifique : Diviser le tenseur d'entrée x en plusieurs tenseurs de formes de (kernel_height, kernel_width) respectivement dans les 2ème et 3ème dimensions (dimensions hauteur et largeur). La longueur de saut entre chaque tenseur est donnée par La foulée est déterminée puis développée en les 2ème et 3ème dimensions respectivement. De cette façon, chaque tenseur développé peut être considéré comme le résultat de convolution locale d'un noyau de convolution bidimensionnel agissant sur x. Ces résultats locaux sont réépissés selon les 2ème et 3ème dimensions pour obtenir un nouveau tenseur x_pad.
  1. Opération de convolution :
  • Utilisez la fonction einsum pour effectuer une opération de convolution sur le nouveau tenseur x_pad.

  • Le premier paramètre de einsum représente la règle de fonctionnement, où ndhw représente les dimensions du tenseur de sortie final (batch_size, out_channels, output_height, output_width), nchw et dckj représentent les dimensions des deux tenseurs d'entrée x_pad et poids, où ckj représente respectivement input_channels, kernel_height et kernel_width.

  • La forme finale du tenseur de sortie est (batch_size, out_channels, output_height, output_width), et un terme de biais est ajouté à chaque position.

  • torch: bibliothèque PyTorch

  • einsum: Notation de sommation d'Einstein, convention de sommation d'Einstein, une notation simple pour la sommation tensorielle.

  • 'nchwkj,dckj->ndhw': Notation de sommation d'Einstein, le tenseur de gauche est x_padet le tenseur de droite est weight. Dans le tenseur de gauche, n, c, h, wreprésentent respectivement la taille du lot, le nombre de canaux, la hauteur et la largeur. Dans le tenseur de droite, d, c, k, jreprésentent respectivement le nombre de canaux de sortie, le nombre de canaux d'entrée, la hauteur du noyau de convolution et la largeur du noyau de convolution. Le sens de cette formule est d' effectuer une opération de convolution x_padsur weightla somme et de générer le tenseur de résultat, dont la forme est (batch_size, output_channels, height, width).

  • x_pad: Tenseur d'entrée avec forme (batch_size, input_channels, input_height, input_width).

  • weight: Tenseur du noyau de convolution avec forme (output_channels, input_channels, kernel_height, kernel_width).

  1. Résultats de retour :
  • Renvoie le tenseur de sortie obtenu après l'opération de convolution.

2.2 Analyse d'exemples de fonctions de convolution écrites

# 设置测试数据
x = torch.randn(2, 3, 5, 5, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
x, weight, bias
(tensor([[[[-0.4888,  1.0257,  0.0312, -0.9026, -0.9060],
           [ 0.2071, -0.4962, -0.1658,  1.0919,  0.3785],
           [-0.4654,  1.5442,  0.6005,  0.3594, -2.6207],
           [ 0.5830,  0.0533,  0.5719,  1.5413,  0.5949],
           [-0.9152, -0.2114, -0.4888, -0.0065, -0.9767]],
 
          [[ 0.4706, -0.1108, -0.1563, -1.7946, -0.8533],
           [-0.2119,  0.3165, -2.2668, -0.8956,  1.0617],
           [-0.7809, -0.2120, -0.8592, -0.5057,  0.7954],
           [-2.8820, -0.6888,  0.4450, -0.3586, -0.9477],
           [ 0.6244,  0.4303,  1.4739,  0.2740,  1.6605]],
 
          [[-0.1501,  0.6234, -1.6086,  0.1693,  0.4932],
           [ 1.0611, -1.0938,  0.1695,  1.0193,  0.4263],
           [ 1.4681, -0.1552, -0.0667, -0.7293,  1.0816],
           [ 0.8972,  1.1683, -1.4757,  0.4421, -0.0355],
           [-2.1331,  1.4847,  0.1378, -1.6907, -0.1350]]],
 
 
         [[[-1.3853,  1.6396,  0.3436,  0.3841,  0.2355],
           [-0.2206, -0.5087, -1.6956,  1.3205,  0.7058],
           [ 0.0993,  0.3533, -0.2086,  0.2969,  0.2627],
           [ 0.3752,  0.0304,  1.2487,  1.3963, -0.0063],
           [-1.3758,  0.5088, -1.3849,  1.3050,  0.4150]],
 
          [[ 0.2824, -2.8634, -0.1016, -0.1627,  1.7081],
           [ 0.1406,  0.2220, -0.6005,  0.2997, -0.1846],
           [ 1.6700,  0.5787,  0.6561, -0.0236,  1.7743],
           [ 2.1429, -0.2838, -0.0527,  0.3504, -0.3444],
           [-0.9409, -0.4734, -0.4060, -0.5088, -1.8518]],
 
          [[-2.2152,  0.2104, -0.3302,  0.2036, -0.9443],
           [-0.6576, -0.4455,  0.5117, -2.0058, -1.3985],
           [-0.5688,  1.2338, -0.1832,  0.1760,  0.4506],
           [-0.6563,  0.4021, -1.6210,  0.5582, -0.9238],
           [-1.0506, -0.9638,  0.7453, -0.3535, -0.3536]]]], requires_grad=True),
 tensor([[[[ 0.3069,  0.2079, -0.2952],
           [ 1.7681,  1.1056, -1.0555],
           [ 1.5845,  0.8294,  0.6588]],
 
          [[ 0.2574,  0.5007,  0.2912],
           [-0.0210,  0.6593, -0.9691],
           [-0.2918,  0.5695, -1.1242]],
 
          [[ 0.7327, -0.3453,  0.7041],
           [-0.2236, -1.7762,  0.0190],
           [-1.0927, -2.9369,  0.1768]]],
 
 
         [[[-2.3830, -1.4807,  1.8573],
           [ 1.0097, -0.9640,  1.0361],
           [-0.5222, -1.0386, -0.4016]],
 
          [[ 0.5071,  1.1433, -0.1194],
           [-0.0133, -0.3878, -0.1853],
           [ 0.3456, -0.6502,  0.2221]],
 
          [[-1.7672, -0.0469, -0.5996],
           [-0.2080, -1.6209,  0.4120],
           [ 0.8404, -1.6748, -0.7170]]],
 
 
         [[[ 0.2850,  0.1691, -0.9228],
           [ 0.7234,  0.5582, -0.4327],
           [ 0.6563,  0.2941,  1.5549]],
 
          [[ 0.2642, -1.9061,  1.6212],
           [-0.5276, -0.5608,  0.3824],
           [ 0.4452, -2.5152,  0.4490]],
 
          [[-0.1276,  0.7784,  0.7998],
           [-0.3030, -0.9776,  0.9681],
           [ 1.0225,  0.8946, -0.8084]]],
 
 
         [[[-0.5087, -0.8345, -1.4763],
           [-0.4938,  1.1979, -0.1335],
           [ 0.5010,  0.2865,  0.0728]],
 
          [[-0.3177, -0.6937, -1.0327],
           [ 0.8147, -1.7101, -1.8257],
           [-0.1593, -1.3855, -0.0885]],
 
          [[-0.4687, -1.6307,  1.5791],
           [-1.3030,  0.2004, -0.7055],
           [ 0.0674, -0.8772,  0.1586]]]], requires_grad=True),
 tensor([ 1.5349, -0.5608,  0.5182,  0.3328], requires_grad=True))
n, c, h, w = x.shape
d, c, k, j = weight.shape
n, c, h, w
(2, 3, 5, 5)
d, c, k, j
(4, 3, 3, 3)
# 补零
x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
x_pad
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
x_pad.shape
torch.Size([2, 3, 9, 9])
x_pad[:, :, pad:-pad, pad:-pad] = x
x_pad
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.4888,  1.0257,  0.0312, -0.9026, -0.9060,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.2071, -0.4962, -0.1658,  1.0919,  0.3785,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.4654,  1.5442,  0.6005,  0.3594, -2.6207,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.5830,  0.0533,  0.5719,  1.5413,  0.5949,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.9152, -0.2114, -0.4888, -0.0065, -0.9767,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.4706, -0.1108, -0.1563, -1.7946, -0.8533,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.2119,  0.3165, -2.2668, -0.8956,  1.0617,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.7809, -0.2120, -0.8592, -0.5057,  0.7954,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.8820, -0.6888,  0.4450, -0.3586, -0.9477,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.6244,  0.4303,  1.4739,  0.2740,  1.6605,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.1501,  0.6234, -1.6086,  0.1693,  0.4932,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0611, -1.0938,  0.1695,  1.0193,  0.4263,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.4681, -0.1552, -0.0667, -0.7293,  1.0816,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.8972,  1.1683, -1.4757,  0.4421, -0.0355,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.1331,  1.4847,  0.1378, -1.6907, -0.1350,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.3853,  1.6396,  0.3436,  0.3841,  0.2355,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.2206, -0.5087, -1.6956,  1.3205,  0.7058,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0993,  0.3533, -0.2086,  0.2969,  0.2627,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.3752,  0.0304,  1.2487,  1.3963, -0.0063,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.3758,  0.5088, -1.3849,  1.3050,  0.4150,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.2824, -2.8634, -0.1016, -0.1627,  1.7081,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.1406,  0.2220, -0.6005,  0.2997, -0.1846,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.6700,  0.5787,  0.6561, -0.0236,  1.7743,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  2.1429, -0.2838, -0.0527,  0.3504, -0.3444,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.9409, -0.4734, -0.4060, -0.5088, -1.8518,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.2152,  0.2104, -0.3302,  0.2036, -0.9443,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.6576, -0.4455,  0.5117, -2.0058, -1.3985,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.5688,  1.2338, -0.1832,  0.1760,  0.4506,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.6563,  0.4021, -1.6210,  0.5582, -0.9238,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.0506, -0.9638,  0.7453, -0.3535, -0.3536,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]]]], grad_fn=<CopySlices>)
# 卷积
x_pad = x_pad.unfold(2, k, stride)
x_pad.shape
torch.Size([2, 3, 4, 9, 3])
x_pad = x_pad.unfold(3, j, stride)
x_pad.shape
torch.Size([2, 3, 4, 4, 3, 3])
out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
out.shape
torch.Size([2, 4, 4, 4])
bias.view(1, -1, 1, 1).shape
torch.Size([1, 4, 1, 1])
# 偏置
out = out + bias.view(1, -1, 1, 1)
out
tensor([[[[ 0.6573, -0.3444,  1.5693, -0.1906],
          [ 2.5483,  5.1142, -2.3528, -3.6162],
          [ 2.9913, -6.1289,  6.8200,  0.9229],
          [ 0.4849,  0.1813,  3.2616,  1.5637]],

         [[-0.1524, -1.2003, -0.3415,  0.0318],
          [-1.7830,  2.5286, -1.6660,  3.1253],
          [ 1.3314, -8.2623, -5.0055,  5.7671],
          [-1.0563,  5.2751, -0.4214,  2.8473]],

         [[ 0.0908,  2.6704,  1.0336,  0.0481],
          [ 0.2077,  2.0459,  1.8095, -0.7039],
          [ 0.9519, -4.5551,  3.7108,  0.7446],
          [ 0.6689,  3.9448,  2.3968,  0.6958]],

         [[ 0.2318, -0.3356,  2.4320,  0.0480],
          [ 0.2101, -1.7177,  6.3956, -0.4108],
          [ 8.2352, -5.8456, 12.9459, -0.8763],
          [-2.3292, -1.5263,  2.1349,  0.3653]]],


        [[[-0.0869,  1.0713,  0.1655,  2.4414],
          [-1.3623, -3.0759,  0.2430,  2.3259],
          [-0.9281,  2.1402,  7.1618,  4.1895],
          [ 0.9273,  1.1176,  0.7792,  0.9265]],

         [[ 1.6465, -1.7187, -0.7251, -0.8871],
          [-1.6260, -0.8628, -1.0122,  3.2737],
          [ 0.5831,  2.1665, -0.5353, -2.0468],
          [-2.3738, -0.1232, -0.0771, -1.8642]],

         [[ 0.2819,  6.0978,  2.9618,  0.4676],
          [ 1.3592,  6.7231,  3.8100,  3.6118],
          [ 0.9885, -5.7760,  5.4375,  0.5480],
          [-0.5778,  1.4657, -2.8315,  0.1923]],

         [[-0.1444,  3.6788,  0.3721,  0.1150],
          [-1.4057,  0.1613, -2.5436,  1.3156],
          [-6.1195,  1.8325,  3.1565,  0.8296],
          [ 1.6766,  6.9403,  1.3986,  0.8758]]]], grad_fn=<AddBackward0>)

2.3 Vérifier l'exactitude de l'écriture de la fonction de convolution

import torch.nn.functional as F
x = torch.randn(2, 3, 5, 5, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
torch_out = F.conv2d(x, w, b, stride, pad)
my_out = conv2d(x, w, b, stride, pad)
torch_out == my_out
tensor([[[[ True,  True,  True,  True],
          [ True, False, False,  True],
          [ True,  True, False,  True],
          [ True,  True, False,  True]],

         [[ True,  True,  True,  True],
          [ True, False,  True,  True],
          [False,  True,  True,  True],
          [ True,  True,  True,  True]],

         [[ True, False, False,  True],
          [False, False,  True,  True],
          [ True, False, False,  True],
          [ True,  True, False,  True]],

         [[ True,  True, False,  True],
          [False, False, False,  True],
          [ True, False, False, False],
          [ True,  True,  True,  True]]],


        [[[ True,  True,  True,  True],
          [False,  True, False,  True],
          [ True,  True, False,  True],
          [ True, False,  True,  True]],

         [[ True,  True, False,  True],
          [ True, False, False, False],
          [ True, False, False, False],
          [ True, False,  True,  True]],

         [[ True,  True, False,  True],
          [False, False, False,  True],
          [ True, False, False, False],
          [ True, False,  True,  True]],

         [[ True, False,  True,  True],
          [ True, False, False,  True],
          [False, False, False, False],
          [ True,  True, False,  True]]]])
torch.allclose(torch_out, my_out, atol=1e-5)
True
  • torch.allclose est une fonction utilisée pour vérifier si les valeurs entre deux tenseurs sont égales.

  • Lorsque vous l'utilisez, vous devez transmettre le premier tenseur comme premier paramètre (c'est-à-dire torch_out), le deuxième tenseur comme deuxième paramètre (c'est-à-dire my_out) et l'erreur absolue autorisée (atol). Transmis comme troisième argument (la valeur par défaut est 1e-8).

  • La fonction renverra une valeur booléenne indiquant si les deux tenseurs ont des valeurs similaires. Si True est renvoyé, cela signifie que les deux tenseurs ont des valeurs numériques similaires, sinon cela signifie qu'il existe une différence numérique entre eux.

grad_out = torch.randn(*torch_out.shape)
grad_x = torch.autograd.grad(torch_out, x, grad_out, retain_graph=True)
my_grad_x = torch.autograd.grad(my_out, x, grad_out, retain_graph=True)
torch.allclose(grad_x[0], my_grad_x[0], atol=1e-5)
True
grad_w = torch.autograd.grad(torch_out, w, grad_out, retain_graph=True)
my_grad_w = torch.autograd.grad(my_out, w, grad_out, retain_graph=True)
torch.allclose(grad_w[0], my_grad_w[0], atol=1e-5)
True
grad_b = torch.autograd.grad(torch_out, b, grad_out, retain_graph=True)
my_grad_b = torch.autograd.grad(my_out, b, grad_out, retain_graph=True)
torch.allclose(grad_b[0], my_grad_b[0], atol=1e-5)
True

Tous sont vrais, ce qui indique que la fonction de convolution écrite a des résultats similaires à ceux de la fonction Conv2d intégrée de PyTorch dans une certaine plage, indiquant l'exactitude de l'implémentation.

Pièce jointe : série d'articles

numéro de série Annuaire d'articles lien direct
1 Pratique d'application PyTorch 1 : implémentation de l'opération de convolution https://want595.blog.csdn.net/article/details/132575530
2 Pratique d'application PyTorch 2 : implémentation d'un réseau neuronal convolutif pour la classification d'images https://want595.blog.csdn.net/article/details/132575702
3 Pratique d'application PyTorch 3 : création d'un réseau neuronal https://want595.blog.csdn.net/article/details/132575758
4 Quatrième pratique d'application PyTorch : créer des applications complexes basées sur PyTorch https://want595.blog.csdn.net/article/details/132625270
5 Cinquième pratique d'application PyTorch : mise en œuvre de réseaux de neurones binaires https://want595.blog.csdn.net/article/details/132625348
6 Pratique d'application PyTorch 6 : Utilisation de LSTM pour implémenter la classification des émotions du texte https://want595.blog.csdn.net/article/details/132625382

Guess you like

Origin blog.csdn.net/m0_68111267/article/details/132575530