le code pytorch implémente EffectiveSE du mécanisme d'attention

EfficaceSEAttention

Le nom complet du mécanisme d'attention EffectiveSE est Effective Squeeze and Extraction, un module d'attention plug-and-play, qui est amélioré sur la base de SE (Squeeze and Extraction). La différence avec SE est que le mécanisme d'attention EffectiveSE n'a qu'une seule couche entièrement connectée. L'auteur de "CenterMask: Real-Time Anchor-Free Instance Segmentation" a remarqué que le module SE a une lacune: la perte d'informations de canal due à la réduction de dimensions. Afin d'éviter la charge de calcul d'un modèle aussi grand, les deux couches entièrement connectées de se doivent réduire la dimension du canal. En particulier, lorsque la première couche entièrement connectée utilise r pour réduire le canal de caractéristique d'entrée et changer le nombre de canaux de c à c/r, la deuxième couche entièrement connectée doit étendre le nombre réduit de canaux au canal d'origine c. Dans ce processus, la réduction de la dimension du canal conduit à la perte d'informations de canal. Par conséquent, le mécanisme d'attention EffectiveSE utilise uniquement une couche entièrement connectée avec un numéro de canal c au lieu de deux couches entièrement connectées, évitant ainsi la perte d'informations de canal.

Adresse papier : https://arxiv.org/pdf/1911.06667.pdf

Schéma structurel

code afficher comme ci-dessous:

import torch
from torch import nn as nn
from timm.models.layers.create_act import create_act_layer

class EffectiveSEModule(nn.Module):
    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
        super(EffectiveSEModule, self).__init__()
        self.add_maxpool = add_maxpool
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
        x_se = self.fc(x_se)
        return x * self.gate(x_se)

if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    Ese = EffectiveSEModule(512)
    output=Ese(input)
    print(output.shape)

Je suppose que tu aimes

Origine blog.csdn.net/DM_zx/article/details/132321429
conseillé
Classement