PyTorch的hook函数(register_hook、register_forward_hook、register_backward_hook、register_forward_pre_hook)

1. función de gancho

1.1 El concepto de función de gancho.

mecanismo de función de gancho: No cambia el cuerpo principal, implementa funciones adicionales, como un colgante, gancho → gancho.

Entonces, ¿por qué existe un mecanismo de función de gancho? Esto está relacionado con el mecanismo de ejecución de gráficos dinámicos de PyTorch. En el mecanismo de ejecución del gráfico dinámico, cuando se completa la operación, se liberarán algunas variables intermedias, como mapas de características y gradientes de nodos que no son hojas. Pero a veces queremos seguir prestando atención a estas variables intermedias, entonces podemos usar la función de enlace para extraer las variables intermedias en el código principal. El código principal es principalmente la propagación hacia adelante y hacia atrás del modelo.

En pocas palabras, la función de enlace no modifica el cuerpo principal, sino que implementa funciones adicionales. Correspondiente a PyTorch, el cuerpo principal es forward y backward, y la función adicional es operar las variables del modelo, como:

  1. "Extraer" mapas de características
  2. "Extraer" el gradiente de un tensor que no es hoja
  3. Modificar el gradiente del tensor

Dé un ejemplo para demostrar cómo el gancho extrae el gradiente de un tensor que no es de hoja:

import torch


# 定义钩子操作
def grad_hook(grad):
    y_grad.append(grad)
    
# 创建一个list来保存钩子获取的梯度
y_grad = list()

# 创建输入变量
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)  # requires_grad=True表明该节点为叶子节点
y = x + 1  # 没有明确requires_grad=True,所以是非叶子节点

# 为非叶子节点y注册钩子
y.register_hook(grad_hook)  # 这是传入的是函数而非函数的调用

# y.retain_grad()  # 如果想要将 y 设置为叶子节点,可以设置 y.retain_grad()

# 计算z节点(z 也是一个非叶子节点,因为它是通过对非叶子节点 y 进行操作而得到的)
z = torch.mean(y * y)

# 反向传播
z.backward()

print(f"y.type: {
      
      type(y)}")
print(f"y.grad: {
      
      y.grad}")
print(f"y_grad: {
      
      y_grad}")

El resultado es el siguiente:

y.type: <class 'torch.Tensor'>
y.grad: None
y_grad: [tensor([[1.0000, 1.5000],
        [2.0000, 2.5000]])]

Puedes ver que el valor de y. grad es None, esto se debe a que y es un nodo que no es hoja. tensor, en Una vez completado z. backward(), el gradiente de y se libera para ahorrar memoria, pero el gradiente de y se puede extraer mediante el método de clase Register_hook de la antorcha. Tensor.


Aquí PyTorch puede informar una advertencia:

/root/anaconda3/envs/wsss/lib/python3.9/site-packages/torch/_tensor.py:1083: 
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). 
If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. 
If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. 
See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:482.)

Traducido es:

/root/anaconda3/envs/wsss/lib/python3.9/site-packages/torch/_tensor.py:1083: UserWarning: 
正在访问不是叶张量的张量的 .grad 属性。 在 autograd.backward() 期间不会填充其 .grad 属性。 
如果我们确实希望为非叶张量填充 .grad 字段,请在非叶张量上使用 .retain_grad() 。 
如果我们错误地访问了非叶张量,请确保我们改为访问叶张量。 
有关更多信息,请参阅 github.com/pytorch/pytorch/pull/30531。 (在 aten/src/ATen/core/TensorBody.h:482 内部触发。)

Esta advertencia suele aparecer cuando intentamos acceder a la información de gradiente de un tensor que no es un nodo hoja. PyTorch nos recuerda que la información de gradiente no está disponible para nodos que no son hoja. Si necesitamos usar información de gradiente en nodos que no son hojas, podemos usar el método .retain_grad() para habilitar el registro de información de gradiente. De lo contrario, asegúrese de que nuestras operaciones se realicen en nodos hoja para permitir el acceso normal a la información de gradiente.


Para ignorar advertencias específicas en Python, puede utilizar el módulo warnings. En este caso, podemos ignorar la advertencia de usuario de PyTorch de la siguiente manera:

import warnings

# 忽略特定的警告
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

El código anterior filtrará las advertencias de la categoría UserWarning del módulo torch para que ya no se muestren. Tenga en cuenta que ignorar las advertencias es una operación global y, por lo tanto, puede afectar todo nuestro entorno Python. Asegúrese de elegir sabiamente al utilizar este método para evitar ocultar información de advertencia importante.

1.2 Función de gancho proporcionada por PyTorch

  1. torch.Tensor.register_hook (Python method, in torch.Tensor)
  2. torch.nn.Module.register_forward_hook (Python method, in torch.nn)
  3. torch.nn.Module.register_backward_hook (Python method, in torch.nn)
  4. torch.nn.Module.register_forward_pre_hook (Python method, in torch.nn)

Uno de estos 4 ganchos se aplica a tensor y los otros 3 son para nn.Module.

1.2.1 Tensor.register_hook

  • Función: registra una función de enlace de retropropagación. Debido a que Tensor solo realiza retropropagación, si no es un nodo hoja, se liberará su propio gradiente. Entonces esta función de gancho está especialmente diseñada para Tensor.

    Tensorregister_hookEl método de la clase register_hook:

  • Idioma

    def hook(grad):
        ...
    
    tensor.register_hook(hook)
    
  • Acción:

    • register_hookEl objetivo principal del método es permitir al usuario registrar una función personalizada en el cálculo del gradiente del tensor para operar en el gradiente o registrar información durante la retropropagación.

    • Esto es útil para tareas como el procesamiento de gradientes personalizados, el recorte de gradientes, la visualización de información de gradientes y la modificación de gradientes.

  • Responder

    • register_hookEl valor de retorno del método es un identificador de gancho que se puede usar para cancelar el gancho de gradiente. Al llamar al método remove(), podemos cancelar el gancho de gradiente registrado en cualquier momento para evitar pérdidas de memoria.
  • Ejemplos de escenarios de aplicación: en la función de enlace, la graduación del gradiente se puede realizar in situ para modificar el valor de graduación del tensor. Esta es una función interesante. Por ejemplo, cuando el gradiente de la capa poco profunda desaparece, el gradiente de la capa poco profunda se puede multiplicar por un cierto múltiplo para aumentar el gradiente; el gradiente también se puede truncar para limitar el gradiente a un cierto intervalo. para evitar Si el gradiente es demasiado grande, se modificarán los parámetros de peso.

  • ejemplo

    • Ejemplo 1: Obtener el gradiente de la variable intermedia y
    • Ejemplo 2: use la función de enlace para expandir el gradiente de la variable x 2 veces

Insertar descripción de la imagen aquí

"""例 1:获取中间变量 a 的梯度"""
import torch
import warnings


# 忽略特定的警告
warnings.filterwarnings("ignore", category=UserWarning, module="torch")


# 自定义hook操作: 梯度处理或记录操作
def grad_hook(grad):
    a_grad.append(grad)


if __name__ == "__main__":
    w = torch.tensor([1.], requires_grad=True)  # 定义叶子节点
    x = torch.tensor([2.], requires_grad=True)  # 定义叶子节点
    a = torch.add(w, x)  # 非叶子节点
    b = torch.add(w, 1)  # 非叶子节点
    y = torch.mul(a, b)  # 非叶子节点

    # 存放梯度
    a_grad = []
    
    # 注册梯度钩子
    handle = a.register_hook(grad_hook)
    
    # 反向传播
    y.backward()

    # 查看梯度
    print(f"w.grad: {
      
      w.grad}")  # tensor([5.])
    print(f"x.grad: {
      
      x.grad}")  # tensor([2.])
    print(f"a.grad: {
      
      a.grad}")  # None
    print(f"b.grad: {
      
      b.grad}")  # None
    print(f"y.grad: {
      
      y.grad}")  # None
    print(f"a_grad: {
      
      a_grad}")  # [tensor([2.])]
    print(f"a_grad[0]: {
      
      a_grad[0]}")  # tensor([2.])
    
    # 取消钩子,避免内存泄漏
    handle.remove()
w.grad: tensor([5.])
x.grad: tensor([2.])
a.grad: None
b.grad: None
y.grad: None
a_grad: [tensor([2.])]
a_grad[0]: tensor([2.])

En el ejemplo anterior, la función grad_hook se registra en el tensor tensor y se activa durante la retropropagación, almacenando su gradiente en . a_grad lista, preservando así su información de gradiente.

"""例 2:利用 hook 函数将变量 x 的梯度扩大 2 倍"""
import torch
import warnings


# 忽略特定的警告
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

# 定义钩子操作
def grad_hook(grad):
    grad *= 2
    return grad


# 创建输入变量
x = torch.tensor([2., 2., 2., 2.], requires_grad=True)  # requires_grad=True表明该节点为叶子节点
y = torch.pow(x, 2)  # 没有明确requires_grad=True,所以是非叶子节点
z = torch.mean(y)  # 对非叶子节点 y 进行操作,所以 z 也不是叶子节点

# 为非叶子节点 y 注册钩子, 返回值为 Handler
handler = x.register_hook(grad_hook)

# 反向传播
z.backward()

print(f"x.grad: {
      
      x.grad}")

# 取消梯度钩子
handler.remove()
x.grad: tensor([2., 2., 2., 2.]

El gradiente original de x es tensor([1., 1., 1., 1. ]). Después de la operación grad_hook, el gradiente es tensor([2., 2., 2., 2. ]).

En resumen, el método register_hook nos permite personalizar la lógica de procesamiento de gradiente en PyTorch, agregando control y funcionalidad adicionales a los cálculos de gradiente.

1.2.2 nn.Module.register_forward_hook

nn.Module.register_forward_hook es un método en PyTorch que se utiliza para registrar una función de devolución de llamada (gancho) durante la propagación hacia adelante del módulo de red neuronal (nn.Module). Esto nos permite capturar la entrada y salida del módulo para operaciones personalizadas o información de registro.

  • Función

    1. Monitorear la entrada y salida del módulo: Al registrar el gancho de propagación directa, podemos capturar la entrada y salida del módulo de red neuronal. Esto es útil para comprender cómo un módulo procesa datos y para monitorear estados intermedios.

    2. Registrar estados intermedios: Podemos capturar los estados intermedios del módulo durante la propagación directa, como la salida de la capa oculta. Esto es muy útil para visualizar características intermedias, depuración de modelos e interpretabilidad de modelos.

    3. Operaciones personalizadas: Podemos realizar operaciones personalizadas en el enlace de propagación hacia adelante, como modificar la salida del módulo, agregar ruido o realizar otro procesamiento personalizado.

    4. Aplicación de tareas específicas: En algunas tareas, el gancho de propagación hacia adelante se puede utilizar para implementar funciones específicas, como realizar un procesamiento específico en la salida de la función de activación, o Intermedio las características se pasan a otros módulos.

  • Idioma

    def hook(module, input, output) -> None:
        ...
    
    model/layer.register_forward_hook(hook)
    
    • module:Capa de red actual

    • input: Datos de entrada de la capa de red actual

    • output: Datos de salida de la capa de red actual

    Tenga en cuenta que la entrada y la salida no se pueden modificar

  • Ejemplos de escenarios de aplicación: utilizados para extraer mapas de características
    Supongamos que la red consta de capas convolucionalesconv1 y la capa de agrupación pool1 se compone de, ingrese uno 4 × 4 4\times4 4×La imagen de 4 ahora se usa para obtener forward_hook de module después de conv1 Mapas de características, el diagrama esquemático es el siguiente:

    Insertar descripción de la imagen aquí

    import torch
    import torch.nn as nn
    import warnings
    
    
    # 忽略特定的警告
    warnings.filterwarnings("ignore", category=UserWarning, module="torch")
    
    
    class CustomModel(nn.Module):
        def __init__(self):
            super(CustomModel, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3)
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x
    
    
    def forward_hook(module, input, output):
        outputs_fmap_list.append(output)
        inputs_fmap_list.append(input)
    
    
    if __name__ == "__main__":
        # 初始化网络
        model = CustomModel()
        model.conv1.weight[0].data.fill_(1)
        model.conv1.weight[1].data.fill_(2)
        model.conv1.bias.data.zero_()
    
        # 定义保存输入、输出feature maps的list
        outputs_fmap_list = list()
        inputs_fmap_list = list()
    
        # 注册hook
        model.conv1.register_forward_hook(forward_hook)
    
        # 模型前向推理
        dummy_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
        output = model(dummy_img)
    
        # 观察
        print(f"output shape: \n\t{
            
            output.shape}")
        print(f"output value: \n\t{
            
            output}")
        print("---" * 30)
        print(f"feature maps shape: \n\t{
            
            outputs_fmap_list[0].shape}")
        print(f"output value: \n\t{
            
            outputs_fmap_list[0]}")
        print("---" * 30)
        print(f"input shape: \n\t{
            
            inputs_fmap_list[0][0].shape}")
        print(f"input value: \n\t{
            
            inputs_fmap_list[0]}")
    
    output shape: 
            torch.Size([1, 2, 1, 1])
    output value: 
            tensor([[[[ 9.]],
    
             [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward0>)
    ------------------------------------------------------------------------------------------
    feature maps shape: 
            torch.Size([1, 2, 2, 2])
    output value: 
            tensor([[[[ 9.,  9.],
                      [ 9.,  9.]],
                      
                     [[18., 18.],
                      [18., 18.]]]], grad_fn=<ConvolutionBackward0>)
    ------------------------------------------------------------------------------------------
    input shape: 
            torch.Size([1, 1, 4, 4])
    input value: 
            (tensor([[[[1., 1., 1., 1.],
                       [1., 1., 1., 1.],
                 	   [1., 1., 1., 1.],
              		   [1., 1., 1., 1.]]]]),)
    

    Primero inicialice una red. La capa de convolución tiene dos núcleos de convolución, los pesos son todos 1 y 2, el sesgo se establece en 0 y la capa de agrupación adopta 2 × 2 2\veces2 2×Agrupación máxima de 2.
    Antes de continuar con forward, registre la función para el módulo conv1 y luego ejecute el Propagación hacia adelante (), una vez completada la propagación hacia adelante, el primer elemento en la lista es el mapa de características generado por capa. Tenga en cuenta aquí que la función tiene dos variables: y , y la característica map es es esta variable, y son los datos de entrada de la capa y la entrada de es una forma de tupla. forward_hookoutput = model(dummy_img)outputs_fmap_listconv1
    forward_hookinputoutputoutputinputconv1conv1


Analicemos cómo el módulo llama a la función de enlace.

  1. model es uno module 类、对 module 执行 module(input)output = model(dummy_img))es Para conferencia module.call

  2. Y module.__call__ el proceso de ejecución es el siguiente:

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
    ...
    
    • Primer juicio module(inmediatamente model)correcto o incorrecto forward_pre_hook(en el país forward 之前的hook);

    • luego ve forward

    • forwardLlega solo después de que haya terminado.forward_hook

      Pero nota aquí, lo que se está ejecutando ahora esmodel.call, el gancho que compusimos está en el módulo model.conv1,
      Entonces el segundo salto es en model.__call__result = self.forward(*input, **kwargs)

  3. model.forward

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x
    

    En model.forward, self.conv1(x) se ejecuta primero y conv1 es un nn.Conv2d (También una clase de módulo). Como se mencionó en el paso 1, al ejecutar module(input) en el módulo se llamará module.call

  4. nn.Conv2d.call

    El proceso en nn.Conv2d.__call__ es el mismo que en el paso 2:

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
    

    Aquí finalmente tenemos que ejecutar la forward_hook función que registramos, aquí mismo hook_result = hook(self, input, result). En este punto debemos prestar atención a dos puntos:

    1. hook_result = hook(self, input, result)input¡Los valores de y result en . Aunque estos datos se pueden manipular mediante ganchos, estos valores no se pueden modificar, de lo contrario se destruirá el cálculo del modelo. , la entrada son los datos de entrada de esta capa y el resultado es el mapa de características de salida después de la operación de la capa función, en ; el resultado corresponde a la salida en corresponde a la entrada en la función
      Aquí no se pueden modificar! inputforward_hookforward_hookconv1conv1

    2. La función de enlace registrada no puede generar un valor de retorno; de lo contrario, se generará una excepción.

      if hook_result is not None:
      	raise RuntimeError
      

Para resumir el proceso de llamada:

model(dummy_img) --> model.call:
    result = self.forward(*input, **kwargs)
--> model.forward: 
    x = self.conv1(x)
--> conv1.call:
    hook_result = hook(self, input, result)  # hook就是我们注册的forward_hook函数了

1.2.3 nn.Module.register_forward_pre_hook

nn.Module.register_forward_pre_hookEs un método en PyTorch que se utiliza para registrar una función de devolución de llamada (gancho) antes del proceso de propagación hacia adelante del módulo de red neuronal. Esto nos permite realizar operaciones personalizadas o registrar información antes de que comience el avance del módulo. Aquí están los detalles sobre el método:

  • Definición:register_forward_pre_hook El método se utiliza para registrar el preenganche directo, permitiéndonos hacerlo antes de la propagación hacia adelante del módulo. comportamiento.

  • Habilidad sobrenatural

    • El preenganche de paso directo nos permite monitorear la entrada de un módulo pero realizar operaciones antes de que el módulo realice el cálculo del paso directo.

    • Esto es útil para realizar modificaciones durante el paso hacia adelante, monitorear entradas o registrar información.

  • Idioma

    def hook(module, input) -> None:
        ...
    
    model/layer.register_forward_pre_hook(hook)
    

    hook es una función definida por el usuario que se ejecutará antes del avance del módulo. Esta función acepta dos parámetros: module y input, que representan el módulo de red neuronal y la entrada del módulo respectivamente.

Ejemplo: Aquí hay un ejemplo que muestra cómo utilizar el método register_forward_pre_hook:

import torch
import torch.nn as nn
import warnings


# 忽略特定的警告
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

# 定义前向传播预钩子函数
def forward_pre_hook(module, input):
    print(f"Module: {
      
      module.__class__.__name__}")
    print(f"Input: {
      
      input}")

# 创建一个神经网络模块
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)

# 创建模块实例
model = MyModule()

# 创建一个输入
input_data = torch.randn(1, 3, requires_grad=True)

# 注册前向传播预钩子
hook_handle = model.fc.register_forward_pre_hook(forward_pre_hook)

# 执行前向传播
output = model(input_data)
Module: Linear
Input: (tensor([[-0.2644,  1.9462,  2.2998]], requires_grad=True),)

En este ejemplo, capturamos la información de entrada del módulo en el preenganche del paso directo y realizamos una operación personalizada antes de que comience el proceso de paso directo. Esto nos permite realizar operaciones o registrar entradas antes de que se ejecute el módulo, y luego el módulo procede a realizar cálculos normales de paso directo.

1.2.4 nn.Module.register_backward_hook

nn.Moduleregister_backward_hookEl método en nos permite registrar una función de devolución de llamada personalizada durante el proceso de retropropagación del modelo PyTorch para realizar operaciones adicionales al calcular gradientes. Esto es útil para monitorear y modificar gradientes, realizar análisis de gradientes o realizar otras operaciones personalizadas.

  • Idioma

    def hook(module, grad_input, grad_output) -> Tensor or None:
    	...
        
    model/layer.register_backward_hook(hook)
    
  • Explicación numérica

    • module: Representa una capa o módulo en el modelo.
    • grad_input: una tupla que contiene el gradiente de entrada
    • grad_output: una tupla que contiene el gradiente de salida.

Las funciones de devolución de llamada se utilizan generalmente para realizar algunas operaciones específicas, como grabar, analizar o modificar gradientes. Podemos registrar diferentes funciones de devolución de llamada en diferentes capas del modelo para realizar diferentes operaciones para diferentes capas cuando sea necesario.

  • Ejemplo de escenario de aplicación: extraer el gradiente del mapa de características
    Utiliceregister_backward_hook para extraer el gradiente del mapa de características, y combinado con el método Grad-CAM (visualización de mapa de activación de clase basada en gradiente de clase), se visualiza el modo de aprendizaje de la red neuronal convolucional.

Aquí hay un ejemplo de cómo utilizar register_backward_hook:

import torch
import torch.nn as nn
import warnings


# 忽略特定的警告
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

# 自定义的回调函数
def hook(module, grad_input, grad_output):
    # 打印梯度信息
    print(f"Module: \n{
      
      module}\n")
    print(f"Input Gradient: \n{
      
      grad_input}\n")
    print(f"Output Gradient: \n{
      
      grad_output}\n")

# 创建模型
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

if __name__ == "__main__":
    model = CustomModel()
    loss_fn = nn.MSELoss()

    # 在模型的某一层上注册回调函数
    handler = model.conv1.register_backward_hook(hook)

    # 前向传播和反向传播
    dummy_img = torch.ones((1, 1, 4, 4))
    label = torch.randint(0, 10, size=[1, 2], dtype=torch.float)
    output = model(dummy_img)
    loss = loss_fn(output, label)  # 将输出转化为标量值
    loss.backward()  # 对损失进行反向传播


    # 注销回调函数
    handler.remove()
Module: Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))

Input Gradient: 
(None, tensor([[[[-2.7914, -2.7914, -2.7914],
          		 [-2.7914, -2.7914, -2.7914],
          		 [-2.7914, -2.7914, -2.7914]]],


        		[[[-1.2109, -1.2109, -1.2109],
          		  [-1.2109, -1.2109, -1.2109],
          		  [-1.2109, -1.2109, -1.2109]]]]), tensor([-2.7914, -1.2109]))

Output Gradient: 
(tensor([[[[-2.7914,  0.0000],
           [ 0.0000,  0.0000]],

          [[-1.2109,  0.0000],
           [ 0.0000,  0.0000]]]]),)

En este ejemplo, creamos una función de devolución de llamada personalizada custom_backward_hook y la registramos en una capa del modelo. Luego, hacemos propagación hacia adelante y hacia atrás e imprimimos la información del gradiente en la función de devolución de llamada. Finalmente, cancelamos el registro de la función de devolución de llamada para asegurarnos de que ya no se ejecute en retropropagaciones posteriores.

Al usar register_backward_hook, podemos monitorear y manipular gradientes, realizar análisis de gradientes o realizar otras operaciones personalizadas relacionadas con gradientes, lo cual es muy útil para la depuración y optimización de modelos.

2. Función de gancho y extracción de mapas de características.

"""
    使用hook函数可视化特征图
"""
import torch
import torch.nn as nn
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
import torchvision.utils as vutils
import matplotlib.pyplot as plt


def set_seed(seed):
    # 设置Python内置的随机数生成器的种子
    random.seed(seed)

    # 设置NumPy的随机数生成器的种子
    np.random.seed(seed)

    # 设置PyTorch的随机数生成器的种子(如果使用了PyTorch)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    set_seed(1)  # 设置随机数种子
    # 实例化 Tensorboard
    writer = SummaryWriter(comment="test_your_comment", filename_suffix="_test_your_filename_suffix")

    # 读取数据并进行预处理
    image_path = "lena.png"
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]

    normalization = transforms.Normalize(MEAN, STD)
    image_transforms = transforms.Compose([transforms.Resize(size=(224, 224)),
                                           transforms.ToTensor(),
                                           normalization])  # 标准化一定要在ToTensor之后进行

    # 使用 Pillow 读取图片
    image_pillow = Image.open(fp=image_path).convert('RGB')

    # 对图片进行预处理
    if image_transforms:
        image_tensor = image_transforms(image_pillow)

    # 添加 Batch 维度
    image_tensor = torch.unsqueeze(input=image_tensor, dim=0)  # [C, H, W] -> [B, C, H, W]

    # 创建模型
    alexnet = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)  # <=> pretrained=True

    # 注册hook
    fmap_dict = dict()  # 存放所有卷积层的特征图
    for name, sub_module in alexnet.named_modules():
        """
            name: features.0
            sub_module: Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        """
        if isinstance(sub_module, nn.Conv2d):  # 判断是否为卷积层,如果是则注册hook
            key_name = str(sub_module.weight.shape)  # torch.Size([64, 3, 11, 11])
            fmap_dict.setdefault(key_name, list())
            layer_name, index = name.split('.')  # 'features', 0

            def hook_func(m, i, o):  # m: module; i: input; o: output
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            # 给 nn.Conv2d层 添加hook函数
            # alexnet._modules[layer_name]._modules[index] -> Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            alexnet._modules[layer_name]._modules[index].register_forward_hook(hook_func)

    # forward(在执行模型时会自动执行hook函数从而往fmap_dict字典中存放输出特征图)
    output = alexnet(image_tensor)

    # 添加图像
    for layer_name, fmap_list in fmap_dict.items():
        # layer_name: torch.Size([1, 3, 224, 224])
        # len(fmap_list): 1 -> shape: [B, C, H, W]
        fmap = fmap_list[0]  # 去掉[]
        fmap.transpose_(0, 1)  # [B, C, H, W] -> [C, B, H, W]

        nrow = int(np.sqrt(fmap.shape[0]))  # 开根号获取行号
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)  # [3, 458, 458]

        # 将结果存放到 Tensorboard 中
        writer.add_image(tag=f"feature map in {
      
      layer_name}", img_tensor=fmap_grid, global_step=1)

        # 也可以将结果直接用 Matplotlib 读取
        # 创建一个图像窗口
        plt.figure(figsize=(8, 8))

        # 使用imshow函数显示 grid_image
        plt.imshow(vutils.make_grid(fmap_grid, normalize=True).permute(1, 2, 0))  # 注意permute的用法,将通道维度移到最后
        plt.axis('off')  # 不显示坐标轴

        # 显示图像
        plt.savefig(f"feature map_in_{
      
      layer_name}.png")

Resultado

Insertar descripción de la imagen aquí

característica map_in_torch.Size([64, 3, 11, 11]).png

Insertar descripción de la imagen aquí

característica map_in_torch.Size([192, 64, 5, 5]).png

Insertar descripción de la imagen aquí

característica map_in_torch.Size([256, 256, 3, 3]).png

Insertar descripción de la imagen aquí

característica map_in_torch.Size([256, 384, 3, 3]).png

Insertar descripción de la imagen aquí

característica map_in_torch.Size([384, 192, 3, 3]).png

【Conocimiento de desarrollo】1. dict.setdefault(key, default)efecto

dict.setdefault(key, default) es un método del diccionario de Python (dict) que establece el valor predeterminado de una clave en el diccionario. Si la clave especificada key existe en el diccionario, este método devuelve el valor asociado con la clave. Si la clave especificada key no existe en el diccionario, agrega la clave al diccionario, establece su valor en default y devuelve default. Este método nos permite establecer valores predeterminados para las claves en el diccionario para evitar generar una excepción KeyError al acceder a una clave inexistente.

【Conocimientos de desarrollo】2. alexnet._modulesalexnet.module ¿Eres diferente?

En PyTorch, alexnet._modules y alexnet.module representan cosas diferentes:

  1. alexnet._modules:

    • alexnet._moduleses un diccionario que contiene cada submódulo del modelo AlexNet. Cada submódulo se almacena por nombre y se puede acceder a él mediante una clave de diccionario. Estos submódulos incluyen capas convolucionales, capas completamente conectadas, capas de agrupación, etc. Podemos utilizar esta forma para ver y acceder a los diferentes componentes de AlexNet.
    • Por ejemplo:

      import torchvision.models as models
      alexnet = models.alexnet()
      print(alexnet._modules)
      

      Esto mostrará un diccionario que contiene los distintos submódulos de AlexNet:

      OrderedDict([('features', Sequential(
        (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        (1): ReLU(inplace=True)
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (4): ReLU(inplace=True)
        (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU(inplace=True)
        (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (9): ReLU(inplace=True)
        (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      )), ('avgpool', AdaptiveAvgPool2d(output_size=(6, 6))), ('classifier', Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=9216, out_features=4096, bias=True)
        (2): ReLU(inplace=True)
        (3): Dropout(p=0.5, inplace=False)
        (4): Linear(in_features=4096, out_features=4096, bias=True)
        (5): ReLU(inplace=True)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      ))])
      
  2. alexnet.module:

    • alexnet.module suele ser un atributo que se utiliza para hacer referencia a todo el modelo AlexNet, especialmente cuando se entrena en entornos distribuidos o con múltiples GPU. En este caso, alexnet.module es un contenedor de nivel superior para el modelo y el modelo AlexNet real se encuentra dentro de él. Este enfoque permite mover modelos a diferentes GPU para entrenamiento en paralelo.

    • Por ejemplo:

      import torchvision.models as models
      alexnet = models.alexnet()
      print(alexnet.modules)
      

      El resultado es el siguiente:

      <bound method Module.modules of AlexNet(
        (features): Sequential(
          (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
          (1): ReLU(inplace=True)
          (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
          (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          (4): ReLU(inplace=True)
          (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
          (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (7): ReLU(inplace=True)
          (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (9): ReLU(inplace=True)
          (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (11): ReLU(inplace=True)
          (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
        (classifier): Sequential(
          (0): Dropout(p=0.5, inplace=False)
          (1): Linear(in_features=9216, out_features=4096, bias=True)
          (2): ReLU(inplace=True)
          (3): Dropout(p=0.5, inplace=False)
          (4): Linear(in_features=4096, out_features=4096, bias=True)
          (5): ReLU(inplace=True)
          (6): Linear(in_features=4096, out_features=1000, bias=True)
        )
      )>
      

      Hay otro uso común:

      import torch.nn as nn
      import torch.nn.parallel
      import torchvision.models as models
      
      alexnet = models.alexnet()
      # 如果使用多GPU训练,通常有一个外部的包装模块
      alexnet = nn.DataParallel(alexnet)
      
      # 在这种情况下,实际的AlexNet模型在alexnet.module中
      actual_alexnet = alexnet.module
      

En resumen, alexnet._modules contiene varios submódulos de AlexNet, mientras que alexnet.module es un contenedor de modelos que generalmente se usa para el entrenamiento de múltiples GPU. Dependiendo de nuestro uso concreto podremos optar por utilizar uno u otro. Si solo queremos ver submódulos de AlexNet, alexnet._modules es una opción más apropiada. Si necesitamos realizar un entrenamiento con múltiples GPU, entonces alexnet.module podría ser más útil.

[Ampliar conocimientos] 3. torchvision.utils.make_gridfunciones, parámetros y valores de retorno

torchvision.utils.make_grides una función en la biblioteca torchvision de PyTorch que se utiliza para fusionar varias imágenes en una cuadrícula grande para facilitar la visualización y presentación. Esta función se utiliza normalmente para visualizar la salida o los datos de entrenamiento de un modelo de aprendizaje profundo.

Los siguientes son los parámetros y valores de retorno de la make_grid función:

  • parámetro:

    • tensor (Tensor): tensor de entrada que contiene las imágenes que se fusionarán en una cuadrícula. Normalmente, se trata de un tensor de forma (batch_size, channels, height, width) donde batch_size representa el número de imágenes del lote y channels representa el número de canales. height y width representan la altura y el ancho de cada imagen.
    • nrow(Opcional): especifique la cantidad de imágenes que se mostrarán por línea; el valor predeterminado es 8.

    • padding(opcional): número de píxeles de relleno entre cada imagen; el valor predeterminado es 2.

    • normalize(opcional): un valor booleano que indica si el tensor de entrada está normalizado para mostrarse en el rango de 0 a 1; el valor predeterminado es Falso. Si se establece en Verdadero, el tensor de entrada se normalizará a [0, 1] para facilitar la visualización en una cuadrícula.

    • range (opcional): una tupla de longitud 2 que especifica el rango de datos del tensor, por ejemplo (min_value, max_value). Si normalize se establece en Verdadero, este parámetro se puede utilizar para especificar el rango de normalización de los datos.

    • scale_each(Opcional): un valor booleano que indica si cada imagen se escala individualmente para ajustarse a su extensión. Si se establece en Verdadero, la extensión de cada imagen se calculará de forma independiente. El valor predeterminado es Falso.

    • pad_value(opcional): el color del valor de relleno, normalmente una tupla de color RGB de longitud 3, por defecto 0, que significa negro.

  • valor de retorno:

    • grid_image (Tensor): un tensor que fusiona todas las imágenes en el tensor de entrada, con forma (C, H, W), donde C es el número de canales, es el ancho de la cuadrícula. Normalmente, este tensor se puede utilizar directamente para visualización o guardar como un archivo de imagen. H es la altura de la cuadrícula, W

Uso de ejemplo:

import torchvision.utils as vutils
import torch

# 假设有一个存储图像的张量 images
grid_image = vutils.make_grid(images, nrow=4, padding=2, normalize=True)

En el ejemplo anterior, make_grid fusiona las imágenes en el tensor images en una cuadrícula que muestra 4 imágenes por fila, con 2 píxeles llenos, y normalizado para visualización. La imagen fusionada se almacena en la variable grid_image.

[Ampliar conocimientos] 4. Las funciones, parámetros y valores de retorno de writer.add_image

writer.add_image es un método en el objeto SummaryWriter de PyTorch que agrega una imagen a un registro de TensorBoard para visualizar, monitorear y analizar datos de imágenes en TensorBoard. Los siguientes son los parámetros y valores de retorno del método writer.add_image:

  • efecto:

    • writer.add_imageSe utiliza para registrar datos de imágenes para su visualización en TensorBoard. Esto es útil para visualizar datos de imágenes, resultados de modelos o pasos de procesamiento de datos en el aprendizaje profundo.
  • parámetro:

    • tag(Cadena): etiqueta o identificador utilizado para identificar la imagen grabada. Normalmente, se trata de una cadena que describe la imagen con el fin de identificar y organizar registros de imágenes en TensorBoard.

    • img_tensor (Tensor): Tensor que contiene los datos de la imagen a grabar. Normalmente, se trata de un tensor tridimensional de forma (C, H, W), donde C representa el número de canales, H representa la altura, Contiene los valores de píxeles de la imagen que se van a registrar. W representa el ancho. img_tensor

    • global_step(entero, opcional): representa el número de pasos o iteraciones globales para la grabación, que se utilizan para alinear imágenes de diferentes registros en TensorBoard. Si no se proporciona, TensorBoard utilizará pasos de incremento automático. ——En el entrenamiento de modelos, suele ser Época

  • valor de retorno:

    • None: El método writer.add_image no tiene valor de retorno. Se utiliza principalmente para registrar datos de imágenes en TensorBoard para su visualización y análisis.
  • Uso de ejemplo:

    from torch.utils.tensorboard import SummaryWriter
    import torch
    
    
    # 创建一个 SummaryWriter 对象,用于记录数据到 TensorBoard
    writer = SummaryWriter()
    
    # 假设我们已经创建了图像张量 img_tensor 和一个全局步骤 global_step
    # 将图像记录到 TensorBoard 中
    writer.add_image("Sample Image", img_tensor, global_step)
    
    # 关闭 SummaryWriter
    writer.close()
    

    En el ejemplo anterior, el método add_image agrega img_tensor al registro de TensorBoard e identifica la imagen con la etiqueta "Imagen de muestra". Podemos ver y analizar estos registros de imágenes en TensorBoard. El parámetro global_step se utiliza para especificar el paso global del registro (generalmente pensamos en global_step como Época o Iteración) para alinear la imagen con otros registros en TensorBoard.

3. CAM (Mapa de activación de clase, mapa de activación de clase) y Grad-CAM

se introdujo en 1.2.4 nn.module.register_backward_hook. De hecho, esta función de enlace se usa a menudo en CAM para obtener HeatMap.

Insertar descripción de la imagen aquí

CAM tiene una desventaja: la parte de salida final de la red debe tener GAP (Global Average Pooling, agrupación promedio global) para obtener los pesos de diferentes mapas de características y así obtener el HeatMap, por lo que para poder utilizar CAM se debe modificar la estructura de la red, por lo que su ámbito de aplicación no es tan amplio. Ante las deficiencias de CAM, se propuso un nuevo método: Grad-CAM.

Insertar descripción de la imagen aquí

Fig 2: Descripción general de Grad-CAM: Dada una imagen y una clase de interés (por ejemplo, “tiger-cat” o cualquier otro tipo de salida diferenciable) como entrada, pasamos por la parte CNN del modelo y luego se calcula la tarea. para obtener la puntuación bruta para esa categoría. El gradiente se establece en cero para todas las clases excepto la clase deseada (gato tigre), que se establece en 1 para la clase deseada. Luego, esta señal se propaga hacia atrás a los mapas de características convolucionales rectificados de interés, que se combinan para calcular una localización Grad-CAM aproximada (mapa de calor azul), que indica qué partes debe observar el modelo para tomar una decisión específica. Finalmente, multiplicamos el mapa de calor con retropropagación guiada para obtener una visualización Guided Grad-CAM de alta resolución y específica del concepto.

CAM tiene dos puntos clave: ① mapa de características; ② peso correspondiente al mapa de características, mientras que Grad-CAM usa gradiente como peso del mapa de características.


A continuación se utiliza un LeNet-5 para demostrar la aplicación de backward_hook en Grad-CAM. El flujo de código es el siguiente:

  1. 创网络net
  2. La función Registrar forward_hook se utiliza para extraer la última capa de mapas de características;
  3. Registrar backward_hook La función se utiliza para extraer el gradiente del vector de clase (one-hot) en el mapa de características;
  4. Promediar el gradiente del mapa de características y ponderar el mapa de características;
  5. Visualice el mapa de calor.
"""
通过实现 Grad-CAM 学习 module 中的 forward_hook 和 backward_hook 函数
"""
import cv2
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool1(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def img_transform(img_in, transform):
    """
    将img进行预处理,并转换成模型输入所需的形式—— B*C*H*W
    :param img_roi: np.array
    :return:
    """
    img = img_in.copy()
    img = Image.fromarray(np.uint8(img))
    img = transform(img)
    img = img.unsqueeze(0)    # C*H*W --> B*C*H*W
    return img


def img_preprocess(img_in):
    """
    读取图片,转为模型可读的形式
    :param img_in: ndarray, [H, W, C]
    :return: PIL.image
    """
    img = img_in.copy()
    img = cv2.resize(img, (32, 32))
    img = img[:, :, ::-1]   # BGR --> RGB
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4948052, 0.48568845, 0.44682974], [0.24580306, 0.24236229, 0.2603115])
    ])
    img_input = img_transform(img, transform)
    return img_input


def backward_hook(module, grad_in, grad_out):
    grad_block.append(grad_out[0].detach())


def farward_hook(m, i, o):
    fmap_block.append(o)


def show_cam_on_image(img, mask, out_dir):
    heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)

    path_cam_img = os.path.join(out_dir, "cam.jpg")
    path_raw_img = os.path.join(out_dir, "raw.jpg")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    cv2.imwrite(path_cam_img, np.uint8(255 * cam))
    cv2.imwrite(path_raw_img, np.uint8(255 * img))


def comp_class_vec(ouput_vec, index=None):
    """
    计算类向量
    :param ouput_vec: tensor
    :param index: int,指定类别
    :return: tensor
    """
    if not index:
        index = np.argmax(ouput_vec.cpu().data.numpy())
    else:
        index = np.array(index)
    index = index[np.newaxis, np.newaxis]
    index = torch.from_numpy(index)
    one_hot = torch.zeros(1, 10).scatter_(1, index, 1)
    one_hot.requires_grad = True
    class_vec = torch.sum(one_hot * output)  # one_hot = 11.8605

    return class_vec


def gen_cam(feature_map, grads):
    """
    依据梯度和特征图,生成cam
    :param feature_map: np.array, in [C, H, W]
    :param grads: np.array, in [C, H, W]
    :return: np.array, [H, W]
    """
    cam = np.zeros(feature_map.shape[1:], dtype=np.float32)  # cam shape (H, W)

    weights = np.mean(grads, axis=(1, 2))  #

    for i, w in enumerate(weights):
        cam += w * feature_map[i, :, :]

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (32, 32))
    cam -= np.min(cam)
    cam /= np.max(cam)

    return cam


if __name__ == '__main__':

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    path_img = os.path.join("cam_img", "test_img_8.png")
    path_net = os.path.join("net_params_72p.pkl")
    output_dir = os.path.join("Result", "backward_hook_cam")

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    fmap_block = list()
    grad_block = list()
    print(path_img)

    # 图片读取;网络加载
    img = cv2.imread(path_img, 1)  # H*W*C
    img_input = img_preprocess(img)
    net = Net()
    net.load_state_dict(torch.load(path_net))

    # 注册hook
    net.conv2.register_forward_hook(farward_hook)
    net.conv2.register_backward_hook(backward_hook)

    # forward
    output = net(img_input)
    idx = np.argmax(output.cpu().data.numpy())
    print("predict: {}".format(classes[idx]))

    # backward
    net.zero_grad()
    class_loss = comp_class_vec(output)
    class_loss.backward()

    # 生成cam
    grads_val = grad_block[0].cpu().data.numpy().squeeze()
    fmap = fmap_block[0].cpu().data.numpy().squeeze()
    cam = gen_cam(fmap, grads_val)

    # 保存cam图片
    img_show = np.float32(cv2.resize(img, (32, 32))) / 255
    show_cam_on_image(img_show, cam, output_dir)

Cabe señalar que en la función backward_hook, grad_out es un tipo de tupla. Para obtener el gradiente del mapa de características, debe hacer esto grad_block. append(grad_out[0]. detach())

Aquí observamos el HeatMap de 3 imágenes de aviones, como se muestra en la siguiente figura: la primera fila es la imagen original y la segunda fila es la imagen con el HeatMap superpuesto.

Insertar descripción de la imagen aquí

Aquí se descubrió un fenómeno interesante: el modelo juzgó la imagen como un avión basándose en el cielo azul, no como el avión (Figura 1 ~ Figura 3). Entonces, si le damos al modelo una imagen azul cielo puro, ¿qué juzgará el modelo? Como se muestra en la Figura 4, el modelo de descubrimiento determinó que la imagen era un avión.

A partir de aquí se descubrió que aunque el modelo podía clasificar correctamente la aeronave, lo que aprendía no eran las características de la misma. Esto conduce a una reducción significativa en el rendimiento de generalización del modelo. Aquí podemos considerar usar un truco para obligar al modelo a aprender el avión en lugar del cielo azul que a menudo aparece con el avión, o ajustar los datos.


Pregunta sobre la Figura 4: ¿El área azul de HeatMap no tiene ningún efecto en la imagen? ¿Se puede juzgar la imagen sólo por el área roja?

Insertar descripción de la imagen aquí

A continuación, se superpone una imagen de automóvil correctamente clasificada (Figura 5) en el área de respuesta azul de la Figura 4 (es decir, el área a la que el modelo no presta atención). Los resultados se muestran en la Figura 6. El valor de respuesta de la pieza del automóvil es muy pequeño, el modelo aún identifica la imagen como un avión a través del área azul cielo. Luego, el automóvil se superpone en el área de respuesta roja en la Figura 4 (esquina inferior derecha de la figura), y el resultado se muestra en la Figura 7. La imagen todavía está clasificada como un avión.
Lo interesante es que el automóvil está superpuesto en el área de respuesta roja en la Figura 7. El modelo determina que la imagen es un barco, y el área de respuesta roja es la parte inferior del área azul. lo cual concuerda con la posición del barco en el mar, muy cerca.

Aprenda el uso de backward_hook a través del código anterior y su aplicación en Grad-CAM, y use Grad-CAM para diagnosticar si el modelo ha aprendido características clave.

fuente de conocimiento

  1. 【05-05-Función de gancho y algoritmo CAM.mp4】
  2. Ganchos de PyTorch y su aplicación en Grad-CAM_PyTorch define ganchos para observar el blog de grad-CSDN

Supongo que te gusta

Origin blog.csdn.net/weixin_44878336/article/details/133859089
Recomendado
Clasificación