[Serie de mecanismos de atención de aprendizaje profundo] - Mecanismo de atención de SENet (con implementación de pytorch)

El mecanismo de atención (Attention Mechanism) en el aprendizaje profundo es un método que imita el sistema visual y cognitivo humano, lo que permite que la red neuronal se centre en las partes relevantes al procesar los datos de entrada. Al introducir el mecanismo de atención, la red neuronal puede aprender automáticamente y enfocarse selectivamente en información importante en la entrada, mejorando el rendimiento y la capacidad de generalización del modelo.

El mecanismo de atención introducido por la red neuronal convolucional tiene principalmente los siguientes métodos:

  • Agregar mecanismo de atención en la dimensión espacial
  • Agregar mecanismo de atención en la dimensión del canal
  • Agregar mecanismo de atención en la dimensión mixta de ambos

Explicaremos varios mecanismos de atención en esta serie y usaremos pytorch para implementarlos.Hoy explicaremos el mecanismo de atención de SENet

El mecanismo de atención SENet (Squeeze-and-Excitation Networks) introduce el mecanismo de atención en la dimensión del canal. Su idea central es aprender los pesos de características a través de la red de acuerdo con la pérdida, de modo que el mapa de características efectivo tenga un gran peso y el mapa de características de efecto no válido o pequeño pesotiene La incorporación del bloque SE en algunas redes de clasificación originales inevitablemente aumenta algunos parámetros y cálculos, pero sigue siendo aceptable en términos de efectos. El bloque Sequeeze-and-Excitation (SE) no es una estructura de red completa, sino una subestructura que puede integrarse en otros modelos de clasificación o detección.

inserte la descripción de la imagen aquí

Lo anterior es un diagrama esquemático de la estructura de SENet, y sus operaciones clave son compresión y excitación.La importancia del mapa de características en cada canal se obtiene a través del aprendizaje automático, para asignar diferentes pesos a diferentes canales y mejorar la contribución de canales útiles.

Mecanismo de implementación:

  1. Squeeze: comprima la gran característica bidimensional (h*w) de cada canal en un número real a través de la capa de agrupación promedio de toda la obra, y la dimensión cambia: (C, H, W) -> (C, 1, 1)
  2. Excitación: asigne a cada canal un peso de característica y luego integre y extraiga información a través de dos capas completamente conectadas para construir la autocorrelación entre canales. El número de pesos de salida es consistente con el número de canales del mapa de características, y la dimensión cambia: (C, 1, 1) -> (C, 1, 1)
  3. Escala: Pesar los pesos normalizados en las características de cada canal. En el artículo, se utiliza ponderación por multiplicación y la dimensión cambia: (C, H, W) * (C, 1, 1) -> (C, H, W )

implementación de pytorch:

class SENet(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super(SENet, self).__init__()
        self.in_channels = in_channels
        self.fgp = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(self.in_channels, int(self.in_channels / ratio), bias=False)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(int(self.in_channels / ratio), self.in_channels, bias=False)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        output = self.fgp(x)
        output = output.view(b, c)
        output = self.fc1(output)
        output = self.act1(output)
        output = self.fc2(output)
        output = self.act2(output)
        output = output.view(b, c, 1, 1)
        return torch.multiply(x, output)

Supongo que te gusta

Origin blog.csdn.net/qq_43456016/article/details/132170807
Recomendado
Clasificación