Dual Attention Network for Scene Segmentation的两个注意力模块

理论说明

其他方法的不足

  1. 之前的方法使用多层特征融合、LSTM、graph来获取特征依赖的方法效率低

  2. 如果嵌入的上下文是已经探索过的(就是经过了不少卷积层吧),重要的、显眼的物体特征影响会不显眼的物体的特征,从而影响识别

    在卷积,池化的过程中,不显眼的特征逐渐被显眼的特征取代,所以下手的话要在最开始的地方下手?

DANet优点

  1. 作者的方法能有选择性的融合不起眼的物体的相似特征,让这个特征更明显,以此来避免显眼的物体的影响

    不起眼的物体的单独的力量是渺小的,可能也不突出,但是经过全部特征加权求和之后,所有的不起眼的物体的特征都会变得显眼。
    但是,那些salient物体的特征也会更显眼啊?如此的话环境的影响似乎变淡 了,因为相似的物体少
    不同的特征都在不同的channel上,每个通道管一些特征,加注意力后这些特征会更明显?

  2. 检测不同的尺度需要的特征是一样的,作者的方法从全局的角度适应性的融合任何scale的相似特征

    作者的网络中没有类似多尺度的东西,只是用两个注意力模组就获得了很好的结果,到底为什么还不知道


方法说明

在这里插入图片描述

作者将这两个注意力模块加到了处理场景分割的网络中,网络流程为

  1. 图片先经过RseNet处理,在最后两个ResNet block中,移除下采样,使用空洞卷积,这样就能保留更多细节,并且不增加额外参数
  2. 然后将这些特征送到两个平行的注意力模组中处理,将两个注意力模组的内容融合,进一步提高特征的表示,来得到更精确的结果
  3. 融合的特征送入膨胀FCN中,整个网络结束。

两个注意力模块分别在位置通道上起作用,注意力模块的操作分三步进行:

  1. 生成一个通道(位置)注意力矩阵来给特征的任意两个通道(位置)的空间关系建模
  2. 然后让原来的特征和注意力矩阵做矩阵乘法
  3. 残差(加上原本的特征x)

通道可以看作特定类的响应。

注意力模组公式

设定输入到注意力模组的特征是 A R C × H × W A \in R^{C\times H \times W}

x j i = e x p ( A i A j ) i = 1 C e x p ( A i A j ) x_{ji} = \frac{exp(A_i \cdot A_j)}{\sum^C_{i=1} exp(A_i \cdot A_j)}

X R C × C X \in R^{C \times C} ,存放通道之间的相关系数

E j = β i = 1 C ( x j i A i ) + A j E_j = \beta \sum^C_{i=1}(x_{ji}A_i) + A_j

之后对两个attention做卷积,然后再逐元素相加。


代码实现

和non-local略有不同:

  1. non-local的 θ , ϕ , g \theta,\phi,g 都将通道减小到了一般,然后再通过 W W 来改变通道数,使得和输入的通道数相同
  2. DANet中 θ , ϕ \theta,\phi 的通道数都会减小,而 g g 中的通道数没有减小,这样不需要用 W W 就能和输入的通道是一样,同时这里也把 W W 换成了一个可学习的系数 γ \gamma
import torch
from torch import nn


class PAM(nn.Module):
    def __init__(self,input_dim):
        super(PAM,self).__init__()
        self.input_dim = input_dim
        self.inter_dim = input_dim // 8
        if self.inter_dim == 0:
            self.inter_dim = 1

        self.theta = nn.Conv2d(self.input_dim,self.inter_dim,kernel_size=1)
        self.phi = nn.Conv2d(self.input_dim,self.inter_dim,kernel_size=1)
        self.g = nn.Conv2d(self.input_dim,self.input_dim,kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        batch, channel, H, W = x.size()
        theta_x = self.theta(x).view(batch, H*W, -1)
        phi_x = self.phi(x).view(batch, -1, H*W)
        f = torch.bmm(theta_x, phi_x)
        g = self.g(x).view(batch, H*W, -1)
        attention = torch.softmax(f,dim=-1)
        y = torch.bmm(attention,g).view(batch,-1,H,W)
        z = self.gamma * y + x
        return z


class CAM(nn.Module):
    def __init__(self,input_dim):
        super(CAM,self).__init__()
        self.input_dim = input_dim
        self.inter_dim = input_dim // 8
        if self.inter_dim == 0:
            self.inter_dim = 1

        self.theta = nn.Conv2d(self.input_dim,self.inter_dim,kernel_size=1)
        self.phi = nn.Conv2d(self.input_dim,self.inter_dim,kernel_size=1)
        self.g = nn.Conv2d(self.input_dim,self.inter_dim,kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        batch, channel, H, W = x.size()
        theta_x = self.theta(x).view(batch, H*W, -1)
        phi_x = self.phi(x).view(batch, -1, H*W)
        f = torch.bmm(theta_x, phi_x)
        # 下面这一句不知道有什么用,加上了耗时更久
        f = torch.max(f,dim=-1,keepdim=True)[0].expand_as(f) - f
        g = self.g(x).view(batch, H*W, -1)
        attention = torch.softmax(f,dim=-1)
        y = torch.bmm(attention,g).view(batch,-1,H,W)
        z = self.gamma * y + x
        return z

class DANetHead(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DANetHead, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.inter_channel = in_channel // 4
        if self.inter_channel == 0:
            self.inter_channel = 1

        self.conv_res_p = nn.Sequential(nn.Conv2d(self.in_channel, self.inter_channel, kernel_size=3, padding=1, bias=False),
                                        nn.BatchNorm2d(self.inter_channel),
                                        nn.ReLU())
        self.conv_res_c = nn.Sequential(nn.Conv2d(self.in_channel, self.inter_channel, kernel_size=3, padding=1, bias=False),
                                        nn.BatchNorm2d(self.inter_channel),
                                        nn.ReLU())

        self.PA = PAM(self.inter_channel)
        self.CA = CAM(self.inter_channel)

        self.conv_before_sum_p = nn.Sequential(nn.Conv2d(self.inter_channel,self.inter_channel, kernel_size=3, padding=1, bias=False),
                                             nn.BatchNorm2d(self.inter_channel),
                                             nn.ReLU())
        self.conv_before_sum_c = nn.Sequential(nn.Conv2d(self.inter_channel,self.inter_channel, kernel_size=3, padding=1, bias=False),
                                             nn.BatchNorm2d(self.inter_channel),
                                             nn.ReLU())

        self.conv_after_sum_p = nn.Sequential(nn.Dropout2d(0.1, False),
                                              nn.Conv2d(self.inter_channel, self.out_channel, kernel_size=1, bias=False))
        self.conv_after_sum_c = nn.Sequential(nn.Dropout2d(0.1, False),
                                              nn.Conv2d(self.inter_channel, self.out_channel, kernel_size=1, bias=False))
        self.conv_sum_p_c = nn.Sequential(nn.Dropout2d(0.1, False),
                                              nn.Conv2d(self.inter_channel, self.out_channel, kernel_size=1, bias=False))

    def forward(self, x):
        batch, channel, H, W = x.size()

        feat_4p = self.conv_res_p(x)
        feat_4c = self.conv_res_c(x)

        feat_p = self.PA(feat_4p)
        feat_c = self.CA(feat_4c)

        feat_p_before_sum = self.conv_before_sum_p(feat_p)
        feat_c_before_sum = self.conv_before_sum_c(feat_c)

        feat_sum = feat_p_before_sum + feat_c_before_sum

        p_get_loss = self.conv_after_sum_p(feat_p_before_sum)
        c_get_loss = self.conv_after_sum_c(feat_c_before_sum)
        feat_result = self.conv_sum_p_c(feat_sum)

        output = []
        # 保留前两个算loss用
        output.append(p_get_loss)
        output.append(c_get_loss)
        output.append(feat_result)

        return output


if __name__ == "__main__":
    x = torch.randn((2,8,64,64))
    model = DANetHead(8,1)
    z = model(x)
    print('z size :',z[2].size())
发布了63 篇原创文章 · 获赞 2 · 访问量 8025

猜你喜欢

转载自blog.csdn.net/McEason/article/details/104168171