反向注意力(Reverse Attention)机制由《Reverse Attention for Salient Object Detection》一文提出。其核心思想为,在显著目标检测(二分割)网络中,对象的大致全局位置信息在网络的深层便可以获得,因此Decoder的浅层部分只需要关注对象的局部细节即可。具体做法则是,将decoder深层的输出给取反,那么网络关注的位置即为对象以外的边缘部分,从而使得最终结果局部细节更加出色。
Reverse Attention的结构如下图所示:
代码(取自原文github仓库)如下:
class RA(nn.Module):
def __init__(self, in_channel, out_channel):
super(RA, self).__init__()
self.convert = nn.Conv2d(in_channel, out_channel, 1)
self.convs = nn.Sequential(
nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
nn.Conv2d(out_channel, 1, 3, padding=1),
)
self.channel = out_channel
# x:待被施加空间注意力的浅层特征
# y:用于计算reverse attention map的深层特征
def forward(self, x, y):
a = torch.sigmoid(-y) # reverse并压缩至0~1区间内以用作空间注意力map
x = self.convert(x) # 统一x, y通道数
x = a.expand(-1, self.channel, -1, -1).mul(x) # x, y相乘,完成空间注意力
y = y + self.convs(x) # 残差连接(图中未画出)
return y