El código pytorch implementa el mecanismo de atención ShuffleAttention

ShuffleAttention mecanismo de atención

El mecanismo de atención en la CNN actual incluye principalmente: atención de canal y atención espacial.Algunos métodos actuales (GCNet, CBAM, etc.) suelen integrar los dos, lo que es propenso a problemas de dificultad convergente y carga de cálculo pesada. Aunque ECANet y SGE propusieron algunos esquemas de optimización, no aprovecharon al máximo la relación entre el canal y el espacio. Por lo tanto, el autor se hace una pregunta "¿Se pueden fusionar diferentes módulos de atención de una manera más ligera pero más eficiente?"

Para resolver este problema, el autor propuso la atención aleatoria, el marco general se muestra en la siguiente figura. Se puede ver que las características de entrada se dividen primero en g grupos, y luego las características de cada grupo se dividen y se dividen en dos ramas para calcular la atención del canal y la atención espacial, respectivamente. Ambas atenciones se calculan utilizando el método de conexión completa + sigmoide. . Luego, los resultados de las dos ramas se empalman y luego se fusionan para obtener un mapa de características con el mismo tamaño que la entrada. Finalmente, se utiliza una capa aleatoria para el procesamiento.
esquema estructural

el código se muestra a continuación:

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter

# https://arxiv.org/pdf/2102.00240.pdf
class ShuffleAttention(nn.Module):

    def __init__(self, channel=512,reduction=16,G=8):
        super().__init__()
        self.G=G
        self.channel=channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid=nn.Sigmoid()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        #group into subfeatures
        x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w

        #channel_split
        x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w

        #channel attention
        x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
        x_channel=x_0*self.sigmoid(x_channel)

        #spatial attention
        x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
        x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
        x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,w
        out=out.contiguous().view(b,-1,h,w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out

Supongo que te gusta

Origin blog.csdn.net/DM_zx/article/details/132302039
Recomendado
Clasificación