Attention Mechanisms: SGE (Spatial Group-wise Enhance) Attention

Paper

Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks


Model structure

(Figure: overall structure of the SGE module, from the paper.)

The main content of the paper

Convolutional Neural Networks (CNNs) generate feature representations of complex objects by collecting semantic sub-features at different levels and parts. These sub-features are usually distributed in groups within the feature vectors of each layer, with each group representing a different semantic entity. However, the activations of these sub-features are often spatially contaminated by similar patterns and noisy backgrounds, leading to erroneous localization and recognition. The paper proposes a Spatial Group-wise Enhance (SGE) module that adjusts the importance of each sub-feature by generating an attention factor for every spatial location in each semantic group, so that each group can autonomously enhance its learned representation and suppress possible noise. The attention factors are guided only by the similarity between the global and local feature descriptors within each group, so the SGE module is extremely lightweight, adding almost no extra parameters or computation.
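
For a single group with spatial feature vectors $x_1, \dots, x_m$ (where $m = H \times W$), the computation can be summarized as follows (the notation is mine; it mirrors the forward pass of the code below):

$g = \frac{1}{m}\sum_{i=1}^{m} x_i$ (global descriptor via average pooling)

$c_i = g \cdot x_i$ (similarity between the global descriptor and each local feature)

$\hat{c}_i = \frac{c_i - \mu_c}{\sigma_c + \varepsilon}$ (normalization of $c_i$ over the spatial positions of the group)

$a_i = \gamma \hat{c}_i + \beta$ (learnable per-group scale and shift, i.e. self.weight and self.bias)

$\hat{x}_i = x_i \cdot \sigma(a_i)$ (sigmoid gating of the original features)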

import torch
from torch import nn
from torch.nn import init


class SpatialGroupEnhance(nn.Module):

    def __init__(self, groups):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)                    # spatial average pooling -> global descriptor per group
        self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))   # learnable per-group scale (gamma)
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))     # learnable per-group shift (beta)
        self.sig = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        # standard initialization for conv / bn / linear submodules
        # (this module itself contains none, so this is effectively a no-op here)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.view(b * self.groups, -1, h, w)   # (b*g, c//g, h, w): split channels into groups
        xn = x * self.avg_pool(x)               # (b*g, c//g, h, w): weight each position by the group's global average
        xn = xn.sum(dim=1, keepdim=True)        # (b*g, 1, h, w): dot product -> similarity map
        t = xn.view(b * self.groups, -1)        # (b*g, h*w)

        t = t - t.mean(dim=1, keepdim=True)     # (b*g, h*w): zero mean over spatial positions
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std                             # (b*g, h*w): unit std over spatial positions
        t = t.view(b, self.groups, h, w)        # (b, g, h, w)

        t = t * self.weight + self.bias         # (b, g, h, w): learnable per-group scale and shift
        t = t.view(b * self.groups, 1, h, w)    # (b*g, 1, h, w)
        x = x * self.sig(t)                     # sigmoid gating of the grouped features
        x = x.view(b, c, h, w)

        return x


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    sge = SpatialGroupEnhance(groups=8)
    output = sge(input)
    print(output.shape)  # torch.Size([50, 512, 7, 7])
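
As a usage sketch: SGE is typically inserted after a convolution inside a backbone block. The ConvBlockWithSGE below is a hypothetical example (not from the paper or the original post) that reuses the SpatialGroupEnhance class defined above:

import torch
from torch import nn

class ConvBlockWithSGE(nn.Module):
    # Hypothetical wrapper: conv -> BN -> ReLU, followed by the SGE module defined above.
    def __init__(self, in_channels, out_channels, groups=8):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.sge = SpatialGroupEnhance(groups=groups)  # out_channels must be divisible by groups

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.sge(x)

block = ConvBlockWithSGE(64, 128, groups=8)
print(block(torch.randn(2, 64, 56, 56)).shape)  # torch.Size([2, 128, 56, 56])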

Code analysis

This module enhances the expressive power of feature maps in a CNN at almost no extra cost, which can improve performance.
In the forward pass, b, c, h and w are first unpacked from the shape of the input x: the batch size, the number of channels, and the spatial height and width.

Then a reshape divides the channels into self.groups groups: x.view(b*self.groups, -1, h, w) turns x into a tensor of shape (b*groups, c//groups, h, w), so that each group is processed independently.
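
A quick shape check (my own illustration) with the same sizes as the test at the bottom of the script, b=50, c=512, groups=8, h=w=7:

import torch

x = torch.randn(50, 512, 7, 7)
groups = 8
grouped = x.view(50 * groups, -1, 7, 7)   # channels split across groups
print(grouped.shape)                      # torch.Size([400, 64, 7, 7]) = (b*groups, c//groups, h, w)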

Next, the grouped feature map is weighted by its own global descriptor: $x_n = x \cdot \frac{1}{HW}\sum x$, where $H \times W$ is the number of spatial positions in the feature map. The avg_pool operation produces a 1×1 global descriptor per channel of each group, which is broadcast back over the spatial dimensions, so the weighted output keeps the same shape as the input.
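
As a small sanity check (my own illustration, not part of the original code), AdaptiveAvgPool2d(1) is simply the mean over the spatial dimensions:

import torch
from torch import nn

x = torch.randn(400, 64, 7, 7)                    # grouped features (b*groups, c//groups, h, w)
pooled = nn.AdaptiveAvgPool2d(1)(x)               # (400, 64, 1, 1) global descriptor
manual = x.mean(dim=(2, 3), keepdim=True)         # the same thing, written explicitly
print(torch.allclose(pooled, manual, atol=1e-6))  # True
xn = x * pooled                                   # weighted feature map, still (400, 64, 7, 7)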

Next, the weighted map xn is summed over the channel dimension (xn.sum(dim=1, keepdim=True)). Together with the previous multiplication, this computes the dot product between each local feature vector and its group's global descriptor, producing a single-channel similarity map of shape (b*groups, 1, h, w), which is then flattened to (b*groups, h*w) and stored in t.

After that, the similarity tensor t is standardized: the mean over the spatial positions is subtracted from every element, and the result is divided by the standard deviation (with a small constant, 1e-5, added for numerical stability).

This normalization is done independently for each group: every value in a group is reduced by that group's spatial mean and divided by that group's spatial standard deviation. The normalized result is stored back in t, which is then reshaped to (b, groups, h, w).
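
The standardization step in isolation (shapes taken from the example above; this just restates the corresponding lines of the forward pass):

import torch

t = torch.randn(400, 49)                               # (b*groups, h*w) similarity values
t = t - t.mean(dim=1, keepdim=True)                    # zero mean per group
t = t / (t.std(dim=1, keepdim=True) + 1e-5)            # (approximately) unit std per group
print(t.mean(dim=1).abs().max(), t.std(dim=1).mean())  # ~0 and ~1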

Then a learnable per-group affine transform is applied: t = t * self.weight + self.bias, where self.weight and self.bias have shape (1, groups, 1, 1) and are the only learnable parameters of the module. Both are initialized to zero, so at the start of training the attention gate is sigmoid(0) = 0.5 everywhere.
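
A quick look at the broadcasting and the zero initialization (my own illustration):

import torch

t = torch.randn(50, 8, 7, 7)        # (b, groups, h, w) normalized similarity
weight = torch.zeros(1, 8, 1, 1)    # per-group scale, one value per group
bias = torch.zeros(1, 8, 1, 1)      # per-group shift
a = t * weight + bias               # broadcasts over the batch and spatial dimensions
print(torch.sigmoid(a).unique())    # tensor([0.5000]) at initialization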

Then t is reshaped to (b*groups, 1, h, w), passed through the sigmoid to produce the attention map, and multiplied with the grouped features: x = x * self.sig(t).

Because the attention map has a single channel per group, it broadcasts across all channels of that group during the multiplication. The gated tensor is finally reshaped back to (b, c, h, w), so the output has exactly the same shape as the input.
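
The final gating and reshape with concrete shapes (again just an illustration of the last lines of forward):

import torch

x = torch.randn(400, 64, 7, 7)   # grouped features (b*groups, c//groups, h, w)
t = torch.randn(400, 1, 7, 7)    # one attention value per group and spatial position
out = x * torch.sigmoid(t)       # the gate broadcasts across the 64 channels of each group
out = out.view(50, 512, 7, 7)    # back to the original (b, c, h, w) layout
print(out.shape)                 # torch.Size([50, 512, 7, 7])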

Finally, the __main__ block tests the module with a random input of shape (50, 512, 7, 7) and prints the output shape, which matches the input: torch.Size([50, 512, 7, 7]).


Origin: blog.csdn.net/qq_38915354/article/details/130552516