El código pytorch implementa GAM del mecanismo de atención.

Mecanismo de atención GAM

Mecanismo de atención GAM: un mecanismo de atención "global" que abarca dimensiones de canales espaciales y retiene información para amplificar las interacciones interdimensionales "globales".

Dirección del artículo: Mecanismo de atención global: retener información para mejorar las interacciones espaciales-canal
Diagrama esquemático de la estructura GAM

Submódulo de atención del canal

El submódulo de atención del canal utiliza una disposición tridimensional para preservar la información en tres dimensiones. Luego, utiliza un MLP (perceptrón multicapa) de dos capas para amplificar la dependencia interdimensional del espacio-canal. (MLP es una estructura codificador-decodificador, igual que BAM, y su relación de compresión es r); el submódulo de atención del canal se muestra en la figura:
Diagrama de estructura del submódulo de atención del canal

Submódulo de atención espacial

En el submódulo de atención espacial, para centrarse en la información espacial, se utilizan dos capas convolucionales para la fusión de información espacial. La misma relación de reducción r que BAM también se utiliza desde el submódulo de atención del canal. Al mismo tiempo, la operación de agrupación máxima tiene un impacto negativo porque reduce el uso de información. La operación de agrupación se elimina aquí para preservar aún más el mapeo de características. Por lo tanto, el módulo de atención espacial a veces aumenta significativamente el número de parámetros. Para evitar que los parámetros aumenten significativamente, en ResNet50 se adopta la convolución de grupo con Channel Shuffle. El submódulo de atención espacial sin convolución de grupo se muestra en la figura:
Diagrama estructural del submódulo de atención espacial.

Código:

import torch.nn as nn
import torch
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
 
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
 
        x = x * x_channel_att
 
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())

Supongo que te gusta

Origin blog.csdn.net/DM_zx/article/details/132707926
Recomendado
Clasificación