Comprensión simple del mecanismo de atención inversa (Atención inversa)

El mecanismo de atención inversa fue propuesto por el artículo "Atención inversa para la detección de objetos salientes". La idea central es que en la red de detección de objetos salientes (dos segmentos), la información de posición global aproximada del objeto se puede obtener en la capa profunda de la red, por lo que la parte superficial del decodificador solo necesita prestar atención a la detalles locales del objeto. El método específico es invertir la salida profunda del decodificador, de modo que la posición a la que presta atención la red sea la parte del borde que no sea el objeto, de modo que los detalles locales del resultado final sean mejores.

La estructura de Atención Inversa se muestra en la siguiente figura:
inserte la descripción de la imagen aquí

El código (tomado del almacén original de github ) es el siguiente:

class RA(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(RA, self).__init__()
        self.convert = nn.Conv2d(in_channel, out_channel, 1)
        self.convs = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, 1, 3, padding=1),
        )
        self.channel = out_channel
	
	# x:待被施加空间注意力的浅层特征
	# y:用于计算reverse attention map的深层特征
    def forward(self, x, y):
        a = torch.sigmoid(-y)	# reverse并压缩至0~1区间内以用作空间注意力map
        x = self.convert(x)		# 统一x, y通道数
        x = a.expand(-1, self.channel, -1, -1).mul(x)	# x, y相乘,完成空间注意力
        y = y + self.convs(x)	# 残差连接(图中未画出)
        return y

Supongo que te gusta

Origin blog.csdn.net/qq_40714949/article/details/129014949
Recomendado
Clasificación