リバース アテンション メカニズムは、「顕著なオブジェクト検出のためのリバース アテンション」という記事で提案されました。中心的なアイデアは、顕著なオブジェクト検出 (2 セグメンテーション) ネットワークでは、ネットワークの深い層でオブジェクトのおおよそのグローバル位置情報を取得できるため、デコーダーの浅い部分は、オブジェクトのローカルの詳細。具体的な方法は、デコーダのディープ出力を反転して、ネットワークが注目する位置がオブジェクト以外のエッジ部分になるようにして、最終結果の局所的な詳細がより良くなるようにすることです。
リバース アテンションの構造を次の図に示します。
コード (元のgithubウェアハウスから取得) は次のとおりです。
class RA(nn.Module):
def __init__(self, in_channel, out_channel):
super(RA, self).__init__()
self.convert = nn.Conv2d(in_channel, out_channel, 1)
self.convs = nn.Sequential(
nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
nn.Conv2d(out_channel, 1, 3, padding=1),
)
self.channel = out_channel
# x:待被施加空间注意力的浅层特征
# y:用于计算reverse attention map的深层特征
def forward(self, x, y):
a = torch.sigmoid(-y) # reverse并压缩至0~1区间内以用作空间注意力map
x = self.convert(x) # 统一x, y通道数
x = a.expand(-1, self.channel, -1, -1).mul(x) # x, y相乘,完成空间注意力
y = y + self.convs(x) # 残差连接(图中未画出)
return y