A2-Nets: Double Attention Networks

1. Paper

A²-Nets: Double Attention Networks (Yunpeng Chen, Yannis Kalantidis, Jianshu Li, Shuicheng Yan, Jiashi Feng; NeurIPS 2018)

2. Overview

[Figure: overview of the double attention (A²) block — image not preserved in this copy]

3. Code
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F

class DoubleAttention(nn.Module):
    """Double attention (A^2) block from "A^2-Nets: Double Attention Networks".

    The block works in two steps on an input of shape (b, in_channels, h, w):

    1. Feature gathering: ``c_n`` attention maps, each softmax-normalized over
       the ``h*w`` spatial positions, pool the ``c_m``-channel features into a
       (b, c_m, c_n) matrix of global descriptors.
    2. Feature distribution: at every spatial position, an attention vector
       softmax-normalized over the ``c_n`` descriptors mixes the global
       descriptors back into a (b, c_m, h, w) feature map.

    Each sample in a batch is processed independently of the others.

    Args:
        in_channels: number of channels of the input tensor.
        c_m: number of feature channels produced by ``convA``.
        c_n: number of attention maps / attention vectors produced by
            ``convB`` and ``convV``.
        reconstruct: if True, a 1x1 conv projects the ``c_m``-channel result
            back to ``in_channels`` so the output has the input's shape.
    """

    def __init__(self, in_channels, c_m, c_n, reconstruct=True):
        super().__init__()
        self.in_channels = in_channels
        self.reconstruct = reconstruct
        self.c_m = c_m
        self.c_n = c_n
        # 1x1 projections: features (A), gathering maps (B),
        # distribution vectors (V).
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)
        if self.reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)
        self.init_weights()

    def init_weights(self):
        """Initialize parameters of all submodules in place.

        Conv2d layers get Kaiming-normal weights (``mode='fan_out'``) and zero
        bias. The BatchNorm2d and Linear branches do not match any layer
        created in ``__init__``; they only apply if such modules are added.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply double attention to ``x``.

        Args:
            x: tensor of shape (b, in_channels, h, w).

        Returns:
            Tensor of shape (b, in_channels, h, w) when ``reconstruct`` is
            True, otherwise (b, c_m, h, w).
        """
        b, c, h, w = x.shape
        assert c == self.in_channels, (
            f"expected {self.in_channels} input channels, got {c}"
        )
        n_pos = h * w
        feats = self.convA(x).reshape(b, self.c_m, n_pos)  # b, c_m, h*w
        # The softmax dim must be explicit: for a 3-D tensor F.softmax's
        # implicit dim is 0, which would normalize across the batch and make
        # each sample's output depend on the rest of the batch.
        # Gathering maps: each of the c_n maps sums to 1 over the h*w positions.
        attention_maps = F.softmax(
            self.convB(x).reshape(b, self.c_n, n_pos), dim=-1
        )  # b, c_n, h*w
        # Distribution vectors: at each position the c_n weights sum to 1.
        attention_vectors = F.softmax(
            self.convV(x).reshape(b, self.c_n, n_pos), dim=1
        )  # b, c_n, h*w
        # Step 1 - feature gathering.
        global_descriptors = torch.bmm(
            feats, attention_maps.transpose(1, 2)
        )  # b, c_m, c_n
        # Step 2 - feature distribution.
        out = torch.bmm(global_descriptors, attention_vectors)  # b, c_m, h*w
        out = out.reshape(b, self.c_m, h, w)
        if self.reconstruct:
            out = self.conv_reconstruct(out)
        return out

if __name__ == '__main__':
    # Smoke test: with reconstruct=True the block preserves the input shape.
    # Inference only, so no autograd graph is needed.
    x = torch.randn(100, 512, 6, 6)
    a2 = DoubleAttention(512, 128, 128, True)
    with torch.no_grad():
        output = a2(x)
    print(output.shape)
    # torch.Size([100, 512, 6, 6])
4. Source
  • 清华大学

猜你喜欢

转载自blog.csdn.net/weixin_46398647/article/details/124002202