简单理解反向注意力(Reverse Attention)机制

反向注意力(Reverse Attention)机制由《Reverse Attention for Salient Object Detection》一文提出。其核心思想为,在显著目标检测(二分割)网络中,对象的大致全局位置信息在网络的深层便可以获得,因此Decoder的浅层部分只需要关注对象的局部细节即可。具体做法则是,将decoder深层的输出给取反,那么网络关注的位置即为对象以外的边缘部分,从而使得最终结果局部细节更加出色。

Reverse Attention的结构如下图所示:
在这里插入图片描述

代码(取自原文github仓库)如下:

class RA(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(RA, self).__init__()
        self.convert = nn.Conv2d(in_channel, out_channel, 1)
        self.convs = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, 1, 3, padding=1),
        )
        self.channel = out_channel
	
	# x:待被施加空间注意力的浅层特征
	# y:用于计算reverse attention map的深层特征
    def forward(self, x, y):
        a = torch.sigmoid(-y)	# reverse并压缩至0~1区间内以用作空间注意力map
        x = self.convert(x)		# 统一x, y通道数
        x = a.expand(-1, self.channel, -1, -1).mul(x)	# x, y相乘,完成空间注意力
        y = y + self.convs(x)	# 残差连接(图中未画出)
        return y

猜你喜欢

转载自blog.csdn.net/qq_40714949/article/details/129014949