Tecnología de gancho de Pytorch: obtenga la salida de la capa intermedia específica del modelo previamente entrenado/entrenado

Para comprender más profundamente el modelo de red neuronal, a veces necesitamos observar información como el núcleo de convolución, el mapa de características o el gradiente obtenido a través de su entrenamiento, que a menudo se utiliza en la investigación de visualización de CNN. Entre ellos, el núcleo de convolución es el más fácil de obtener y se puede obtener guardando los parámetros del modelo; el mapa de características es una variable intermedia y el sistema borrará la imagen correspondiente después del procesamiento; de lo contrario, ocupará mucha memoria; el gradiente es similar al mapa de características, excepto que la estructura de la hoja fuera del punto, los gradientes de otras variables intermedias se liberan de la memoria, por lo que no se pueden obtener directamente.
La forma más sencilla de obtenerlo es cambiar la estructura del modelo. Al final del avance, no solo se devuelve la salida predicha del modelo, sino también el mapa de características requerido y otra información.

Entonces, ¿cómo obtener mapas de características, gradientes y otra información sin cambiar la estructura del modelo (como en el caso de modelos previamente entrenados o modelos entrenados)?
La programación de gancho de Pytorch puede obtener y cambiar de manera efectiva información, como variables intermedias del modelo y gradientes, sin cambiar la estructura de la red.

Hook puede extraer o cambiar el gradiente de Tensor y también puede obtener la salida y el gradiente de nn.Module (no se puede cambiar aquí).
Por lo tanto, se utilizan 3 funciones de enlace para realizar las funciones anteriores:
(1) Objeto tensor :
Tensor.register_hook (gancho)

(2) Objeto del módulo :
nn.Module.register_forward_hook(hook);
nn.Module.register_backward_hook(hook).



A continuación se presenta su uso uno por uno:

Objeto tensorial

La introducción oficial a los objetos Tensor : https://pytorch.org/docs/stable/tensors.html
tiene el siguiente método Register_hook (hook) , que registra un gancho hacia atrás para que Tensor obtenga el gradiente de las variables. El gancho debe seguir el siguiente formato: gancho(grad) -> Tensor o Ninguno , donde grad es el gradiente adquirido.

Función: registre una función de enlace de retropropagación para registrar automáticamente el gradiente de Tensor .
PyTorch liberará automáticamente el gradiente de las variables intermedias y los nodos que no son hoja después de la ejecución para reducir el uso de memoria. ¿Qué son las variables intermedias? ¿Qué es un nodo no hoja?
inserte la descripción de la imagen aquí
En la figura anterior, a, b y d son nodos hoja, y c, e y o son nodos que no son hoja y variables intermedias.

In [1]: a = torch.Tensor([1,2]).requires_grad_() 
    ...: b = torch.Tensor([3,4]).requires_grad_() 
    ...: d = torch.Tensor([2]).requires_grad_() 
    ...: c = a + b 
    ...: e = c * d 
    ...: o = e.sum()     

In [2]: o.backward()

In [3]: print(a.grad)
Out[3]: tensor([2., 2.])

In [4]: print(b.grad)
Out[4]: tensor([2., 2.])

In [5]: print(c.grad)
Out[5]: None

In [6]: print(d.grad)
Out[6]: tensor([10.])

In [7]: print(e.grad)
Out[7]: None

In [8]: print(o.grad)
Out[8]: None

Se puede ver en la salida del programa que a, byd son nodos hoja, y los valores de gradiente aún se conservan después de la propagación hacia atrás, mientras que los gradientes de otros nodos que no son hoja se han liberado automáticamente. valores de gradiente, necesitas usar ganchos .

Primero personalizamos una función de enlace para registrar la operación del gradiente de Tensor , y luego usamos Tensor.register_hook(hook) para registrar el tensor del nodo no hoja para obtener el gradiente , y luego propagamos hacia atrás nuevamente:

In [9]: def hook(grad):
    ...:	print(grad)
    ...:

In [10]: e.register_hook(hook)
Out[10]: <torch.utils.hooks.RemovableHandle at 0x1d139cf0a88>

In [11]: o.backward()

In [12]: print(e.grad)
Out[12]: tensor([1., 1.])

En este momento, el gradiente de e se genera automáticamente.

El nombre de la función de enlace personalizado se puede elegir arbitrariamente y su parámetro es grad , que indica el gradiente de Tensor . Esta función personalizada se utiliza principalmente para describir la operación en el valor de gradiente de Tensor . En el ejemplo anterior, generamos directamente el gradiente, por lo que se imprime (grad). También podemos poner el degradado en una lista o diccionario, e incluso modificar el degradado, de modo que si el degradado es pequeño, se pueda agrandar para evitar que desaparezca:

In [13]: a = torch.Tensor([1,2]).requires_grad_() 
    ...: b = torch.Tensor([3,4]).requires_grad_() 
    ...: d = torch.Tensor([2]).requires_grad_() 
    ...: c = a + b 
    ...: e = c * d 
    ...: o = e.sum()                                                            

In [14]: grad_list = []                                                         

In [15]: def hook(grad): 
    ...:     grad_list.append(grad)    # 将梯度装在列表里
    ...:     return 2 * grad           # 将梯度放大两倍
    ...:                                                                        

In [16]: c.register_hook(hook)                                                  
Out[16]: <torch.utils.hooks.RemovableHandle at 0x7f009b713208>

In [17]: o.backward()                                                           

In [18]: print(grad_list)                                                              
Out[18]: [tensor([2., 2.])]

In [19]: print(a.grad)                                                                 
Out[19]: tensor([4., 4.])

In [20]: print(b.grad)                                                                 
Out[20]: tensor([4., 4.])

En el ejemplo anterior, la función de enlace que definimos realiza dos operaciones: una es cargar el gradiente en la lista grad_list y la otra es duplicar el gradiente. En el resultado, podemos ver que después de realizar la retropropagación, el gradiente de nuestro nodo c no hoja registrado se guarda en la lista grad_list y los gradientes de a y b se duplican. Cabe señalar aquí que si desea almacenar el valor del gradiente en una lista o diccionario, primero debe definir una lista o diccionario de variables globales con el mismo nombre, incluso si es una variable local, debe estar fuera de la costumbre. función de gancho. Otro punto a tener en cuenta es que si desea cambiar el valor del gradiente, la función de enlace debe devolver un valor para devolver el gradiente modificado.

Para resumir aquí, si queremos obtener el valor de gradiente del nodo no hoja Tensor , necesitamos : 1) Personalizar una función de enlace para describir la operación en el gradiente, el nombre de la función es de creación propia y el parámetro es solo grad, que indica el gradiente del Tensor 2) Registre el Tensor para obtener el gradiente con el método Tensor.register_hook(hook) . 3) Realizar retropropagación.


Objeto de módulo

Hay dos métodos, Register_forward_hook(hook) y Register_backward_hook(hook) , que corresponden a las funciones de enlace de propagación hacia adelante y hacia atrás, respectivamente .
Los objetos de operación de estos dos son la clase nn.Module, como la capa convolucional (nn.Conv2d), la capa completamente conectada (nn.Linear), la capa de agrupación (nn.MaxPool2d, nn.AvgPool2d) en la red neuronal y la activación. Pequeños módulos definidos por capas (nn.ReLU) o nn.Sequential, etc.

Para el módulo intermedio del modelo, también se puede considerar como un nodo intermedio (nodo no hoja). Su salida es un mapa de características o un valor de activación. El valor del gradiente de retropropagación será liberado automáticamente por el sistema. Si lo desea para obtenerlos, debes usar la función de gancho .

Como puede ver en el nombre, Register_forward_hook es para obtener la salida de la propagación hacia adelante, es decir, el mapa de características o el valor de activación; Register_backward_hook es para obtener la salida de la propagación hacia atrás, es decir, el valor del gradiente. Su uso es similar al de Register_hook presentado anteriormente . La función de gancho debe eliminarse a tiempo después de su uso para evitar ejecutar el gancho cada vez para aumentar la carga de funcionamiento.

(1) Para Register_forward_hook(hook) , la función de enlace se define de la siguiente manera:

# 这里有3个参数,分别表示:模块,模块的输入,模块的输出。
# 函数用于描述对这些参数的操作,一般我们都是为了获取特征图,即只描述对output的操作即可。
def forward_hook(module, input, output):
    operations

Hook puede modificar la entrada y la salida, pero no afectará el resultado del reenvío. El escenario más comúnmente utilizado es extraer las características de salida de una determinada capa (no la última capa) del modelo, pero no desea modificar su archivo de definición de modelo original, entonces puede usar la función forward_hook .

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,6,3,1,1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,9,3,1,1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(8*8*9, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120,10)

    def forward(self, x):
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)

        return out

fmap_block = dict()  # 装feature map

def forward_hook(module, input, output):
    fmap_block['input'] = input
    fmap_block['output'] = output

net = Net()
net.load_state_dict(torch.load('checkpoint.pth')) # 载入已训练好的模型
x = torch.randn(1, 3, 32, 32).requires_grad_() # 随机生成一副图像作为输入
handle = net.conv2.register_forward_hook(hook) # 注册hook
y = net(x)

# 展示输入图像和特定中间层的特征
plt.subplot(121)
plt.imshow(fmap_block['input'][0][0,0,:,:].cpu().detach().numpy())
plt.subplot(122)
plt.imshow(fmap_block['output'][0][0,0,:,:].cpu().detach().numpy())
plt.show()
print((fmap_block['input'][0].shape))
print((fmap_block['output'][0].shape))

handle.remove()

(2) Para Register_backward_hook(hook) , la función de enlace se define de la siguiente manera:

# 这里有3个参数,分别表示:模块,模块输入端的梯度,模块输出端的梯度。
def backward_hook(module, grad_in, grad_out):
    operations

Lo que necesita atención especial aquí es que la entrada y la salida aquí son la entrada y la salida durante la propagación directa, es decir, el gradiente de la salida anterior corresponde al grad_out aquí. Por ejemplo, módulo lineal: o=W*x+b, su terminal de entrada es W, x y b, y su terminal de salida es o.

grad_in y grad_out pueden ser de tipo tupla si el módulo tiene múltiples entradas o salidas . Para el módulo lineal: o=W*x+b, su entrada incluye W, x y b, por lo que grad_input es una tupla que contiene tres elementos.

Note aquí la diferencia con el gancho delantero:

  1. En el gancho directo, la entrada es x, excluyendo W y b.
  2. Al devolver Tensor o Ninguno, la función de enlace inverso no puede cambiar directamente sus variables de entrada, pero puede devolver un nuevo
    grad_in y propagarse hacia atrás a su módulo anterior.
import torch
import torch.nn as nn
import numpy as np 
import torchvision.transforms as transforms

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,6,3,1,1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,9,3,1,1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(8*8*9, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120,10)

    def forward(self, x):
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)

        return out

fmap_block = dict()  # 装feature map
grad_block = dict()  # 装梯度
def forward_hook(module, input, output):
    fmap_block['input'] = input
    fmap_block['output'] = output
    
def backward_hook(module, grad_in, grad_out):
    grad_block['grad_in'] = grad_in
    grad_block['grad_out'] = grad_out

loss_func = nn.CrossEntropyLoss()

label = torch.empty(1, dtype=torch.long).random_(3)  # 生成一个假标签
input_img = torch.randn(1,3,32,32).requires_grad_()  # 生成一副假图像作为输入 

net = Net()

# 注册hook
handle_forward = net.conv2.register_forward_hook(farward_hook)
handle_backward = net.conv2.register_backward_hook(backward_hook)

outs = net(input_img)
loss = loss_func(outs, label)
loss.backward()

handle_forward.remove()
handle_backward.remove()
print('End.')

En el programa anterior, primero definimos un modelo de red neuronal convolucional simple: registramos el gancho para el módulo convolucional de segunda capa, obtenemos su entrada y salida, obtenemos el gradiente de entrada y salida, y los instalamos en el diccionario. Para lograr el efecto de verificación, generamos aleatoriamente una imagen falsa cuyo tamaño es el mismo que el tamaño de la imagen del conjunto de datos cifar-10, definimos una etiqueta de categoría para la imagen falsa y usamos la función de pérdida de retropropagación para simular el Proceso de entrenamiento de la red neuronal.

Después de ejecutar el programa, el mapa de características y el gradiente correspondientes aparecerán en las dos listas fmap_block y grad_block . Veamos las dimensiones de su entrada y salida:

In [21]: print(len(fmap_block['input']))                                               
Out[21]: 1

In [22]: print(len(fmap_block['output']))                                              
Out[22]: 1

In [23]: print(len(grad_block['grad_in']))                                             
Out[23]: 3

In [24]: print(len(grad_block['grad_out']))                                            
Out[24]: 1

Se puede ver que solo hay una entrada y salida del módulo convolucional de la segunda capa, es decir, el mapa de características correspondiente. Hay tres valores de gradiente en el extremo de entrada, que son el gradiente de peso, el gradiente de desviación y el gradiente del mapa de características de entrada. Solo hay un valor de gradiente en la salida, es decir, el gradiente del mapa de características de salida. Como se enfatizó anteriormente, incluso si hay W, X y b en el extremo de entrada, solo X es la entrada para la propagación anterior, y para la propagación hacia atrás, las tres son entradas. ¿Cuál es el orden en que se organizan los valores de gradiente de los tres elementos en el extremo de salida? Echemos un vistazo a las dimensiones específicas de los tres gradientes:

In [25]: print(grad_block['grad_in'][0].shape)                                         
Out[25]: torch.Size([1, 6, 16, 16])

In [26]: print(grad_block['grad_in'][1].shape)                                        
Out[26]: torch.Size([9, 6, 3, 3])

In [27]: print(grad_block['grad_in'][2].shape)                                         
Out[27]: torch.Size([9])

A juzgar por la dimensión del gradiente en el extremo de salida, el primero es obviamente el gradiente del mapa de características, el segundo es el gradiente del peso (núcleo de convolución / filtro) y el tercero es el gradiente del sesgo. Para verificar que el gradiente tiene las mismas dimensiones que estos parámetros, veamos las dimensiones de la propagación hacia adelante de estos tres valores:

In [28]: print(fmap_block['input'][0].shape)                                           
Out[28]: torch.Size([1, 6, 16, 16])

In [29]: print(net.conv2.weight.shape)         
Out[29]: torch.Size([9, 6, 3, 3])

In [30]: print(net.conv2.bias.shape)                                                   
Out[30]: torch.Size([9])

Lo último a tener en cuenta es que si necesita obtener el gradiente de la imagen de entrada, debe establecer el atributo require_grad del tensor de entrada en Verdadero.

Blog de referencia:
Pytorch obtiene información de la capa intermedia: función de enlace
GANCHO de PyTorch: una herramienta eficaz para obtener características y gradientes de redes neuronales

Supongo que te gusta

Origin blog.csdn.net/Joker00007/article/details/128943862
Recomendado
Clasificación