PyTorch implementation of the SGE attention mechanism

SGE attention mechanism

SGE (Spatial Group-wise Enhance) is a lightweight attention module. Its main appeal is that it delivers solid gains in classification and detection performance while adding almost no parameters or computation. Compared with other attention modules, it is also the first to use the similarity between local and global features as the attention mask, which gives it a strong semantic grounding and makes its behavior easier to interpret.
SGE splits the channels into groups and generates a spatial attention factor within each group, so the importance of every sub-feature can be estimated and each group can selectively enhance informative regions while suppressing noise. Because this attention factor is determined only by the similarity between the global and local features within each group, SGE is very lightweight. After training, SGE turns out to be particularly effective at highlighting high-level semantic regions.

Paper address: https://arxiv.org/pdf/1905.09646.pdf
![Structural schematic diagram](https://img-blog.csdnimg.cn/cb33e483e7134516a417195a51eaf6f5.png)
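For reference, the per-group computation can be written out as follows (notation is mine, following the paper's description): $x_i$ is the local feature at spatial position $i$ within a group, $g$ is the group's global feature obtained by average pooling over the $m$ positions, $\mu_c$ and $\sigma_c$ are the mean and standard deviation of the similarity map, $\gamma$ and $\beta$ are the learnable per-group scale and bias, and $\sigma(\cdot)$ is the sigmoid:

$$
g = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad c_i = g \cdot x_i
$$

$$
\hat{c}_i = \frac{c_i - \mu_c}{\sigma_c + \epsilon}, \qquad a_i = \gamma\,\hat{c}_i + \beta, \qquad \hat{x}_i = x_i \cdot \sigma(a_i)
$$

This lines up with the forward pass below: the dot product $c_i$ corresponds to the `xn.sum(dim=1)` step, the normalization corresponds to the mean/std step, and `self.weight` / `self.bias` play the role of $\gamma$ and $\beta$.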

The code is shown below:

import numpy as np
import torch
from torch import nn
from torch.nn import init

class SpatialGroupEnhance(nn.Module):
    def __init__(self, groups=8):
        super().__init__()
        self.groups=groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight=nn.Parameter(torch.zeros(1,groups,1,1))
        self.bias=nn.Parameter(torch.zeros(1,groups,1,1))
        self.sig=nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h,w=x.shape
        x=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,w
        xn=x*self.avg_pool(x) #bs*g,dim//g,h,w
        xn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,w
        t=xn.view(b*self.groups,-1) #bs*g,h*w

        t=t-t.mean(dim=1,keepdim=True) #bs*g,h*w
        std=t.std(dim=1,keepdim=True)+1e-5
        t=t/std #bs*g,h*w
        t=t.view(b,self.groups,h,w) #bs,g,h,w

        t=t*self.weight+self.bias #bs,g,h,w
        t=t.view(b*self.groups,1,h,w) #bs*g,1,h,w
        x=x*self.sig(t)
        x=x.view(b,c,h,w)
        return x 

if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    sge = SpatialGroupEnhance(groups=8)
    output=sge(input)
    print(output.shape)
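As a usage sketch (not part of the original post), SGE preserves the input shape and only requires that the channel count be divisible by `groups`, so it can be dropped in after any convolution stage. The hypothetical `ConvSGEBlock` below, which assumes the `SpatialGroupEnhance` class defined above, wires SGE after a 3x3 convolution:

```python
import torch
from torch import nn

# Hypothetical example block (not from the original post): a 3x3 conv
# followed by SGE. The only constraint is out_channels % groups == 0.
class ConvSGEBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=8):
        super().__init__()
        assert out_channels % groups == 0, "channels must be divisible by groups"
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.sge = SpatialGroupEnhance(groups=groups)  # class defined above

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.sge(x)  # shape unchanged: (B, out_channels, H, W)

block = ConvSGEBlock(64, 128, groups=8)
print(block(torch.randn(2, 64, 56, 56)).shape)  # torch.Size([2, 128, 56, 56])
```

Since each SGE module only learns two parameters per group (`weight` and `bias`), inserting it after every block of a backbone leaves the parameter count essentially unchanged.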


Original post: blog.csdn.net/DM_zx/article/details/132418385