El código pytorch implementa el mecanismo de atención EfectivoSE.

AtenciónEfectiva

El nombre completo del mecanismo de atención de EffectiveSE es Eficaz Squeeze and Extraction, un módulo de atención plug-and-play, que se mejora en base a SE (Squeeze and Extraction). La diferencia con SE es que el mecanismo de atención de EffectiveSE tiene solo una capa completamente conectada. El autor de "CenterMask: Real-Time Anchor-Free Instance Segmentation" notó que el módulo SE tiene una deficiencia: la pérdida de información del canal debido a la reducción de dimensiones Para evitar la carga computacional de un modelo tan grande, las dos capas completamente conectadas de se necesitan reducir la dimensión del canal. En particular, cuando la primera capa completamente conectada usa r para reducir el canal de características de entrada y cambia la cantidad de canales de c a c/r, la segunda capa completamente conectada necesita expandir la cantidad reducida de canales al canal original c. En este proceso, la reducción de la dimensión del canal conduce a la pérdida de información del canal. Por lo tanto, el mecanismo de atención de EffectiveSE solo utiliza una capa completamente conectada con un número de canal c en lugar de dos capas completamente conectadas, evitando la pérdida de información del canal.

Dirección en papel: https://arxiv.org/pdf/1911.06667.pdf

esquema estructural

el código se muestra a continuación:

import torch
from torch import nn as nn
from timm.models.layers.create_act import create_act_layer

class EffectiveSEModule(nn.Module):
    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
        super(EffectiveSEModule, self).__init__()
        self.add_maxpool = add_maxpool
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
        x_se = self.fc(x_se)
        return x * self.gate(x_se)

if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    Ese = EffectiveSEModule(512)
    output=Ese(input)
    print(output.shape)

Supongo que te gusta

Origin blog.csdn.net/DM_zx/article/details/132321429
Recomendado
Clasificación