CBAM attention mechanism - PyTorch implementation

Paper link: CBAM: Convolutional Block Attention Module

Purpose of CBAM:

Add an attention mechanism to the network: CBAM reweights intermediate feature maps along both the channel and spatial dimensions, and it can be inserted into existing CNNs as a lightweight, plug-and-play module.

The structure of CBAM:

①Channel attention module: the input feature map goes through global max pooling and global average pooling in parallel; both pooled results are passed through a weight-sharing MLP, the two outputs are summed element-wise, and a sigmoid activation finally produces the channel attention weight Mc;
②Spatial attention module: the input feature map is max-pooled and average-pooled along the channel dimension to obtain a (2, H, W) feature map; a 7×7 convolution then outputs a single-channel feature map, and a sigmoid activation finally produces the spatial attention weight Ms;
[Figure: channel attention module and spatial attention module]
③Serial connection: the author connects the two modules in series, with the channel attention module first and the spatial attention module second (see the formulas collected below).
[Figure: CBAM structure]
In the paper's ablation experiments, the author found that connecting the two modules in series works better than connecting them in parallel, and that applying channel attention before spatial attention works better than the reverse order.
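For reference, the three formulas from the paper (they also appear in the code comments below) can be written as:

Mc(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F))) = σ(W1(W0(Fc_avg)) + W1(W0(Fc_max)))    (paper Eq. 2)
Ms(F) = σ(f7x7([AvgPool(F); MaxPool(F)])) = σ(f7x7([Fs_avg; Fs_max]))                 (paper Eq. 3)
F'  = Mc(F) ⊗ F
F'' = Ms(F') ⊗ F'                                                                     (paper Eq. 1)

where σ is the sigmoid function, ⊗ is element-wise multiplication (broadcast over the feature map), W0 and W1 are the weights of the shared MLP, and f7x7 is a 7×7 convolution.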

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):  # Channel attention module
    def __init__(self, channels, ratio=16):  # r: reduction ratio=16
        super(ChannelAttention, self).__init__()

        hidden_channels = channels // ratio
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # global avg pool
        self.maxpool = nn.AdaptiveMaxPool2d(1)  # global max pool
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden_channels, 1, 1, 0, bias=False),  # 1x1 conv in place of a fully-connected layer; no bias term, following the formula in the paper
            nn.ReLU(inplace=True),  # relu
            nn.Conv2d(hidden_channels, channels, 1, 1, 0, bias=False)  # 1x1 conv in place of a fully-connected layer; no bias term, following the formula in the paper
        )
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):
        x_avg = self.avgpool(x)
        x_max = self.maxpool(x)
        return self.sigmoid(
            self.mlp(x_avg) + self.mlp(x_max)
        )  # Mc(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F))) = σ(W1(W0(Fc_avg)) + W1(W0(Fc_max))), Eq. (2) in the paper


class SpatialAttention(nn.Module):  # Spatial attention module
    def __init__(self):
        super(SpatialAttention, self).__init__()

        self.conv = nn.Conv2d(2, 1, 7, 1, 3, bias=False)  # 7x7 conv, stride 1, padding 3 (keeps the spatial size)
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):
        x_avg = torch.mean(x, dim=1, keepdim=True)  # average-pool along the channel dimension, (B,C,H,W) -> (B,1,H,W)
        x_max = torch.max(x, dim=1, keepdim=True)[0]  # max-pool along the channel dimension, (B,C,H,W) -> (B,1,H,W)
        return self.sigmoid(
            self.conv(torch.cat([x_avg, x_max], dim=1))
        )  # Ms(F) = σ(f7x7([AvgPool(F); MaxPool(F)])) = σ(f7x7([Fs_avg; Fs_max])), Eq. (3) in the paper


class CBAM(nn.Module):  # Convolutional Block Attention Module
    def __init__(self, channels, ratio=16):
        super(CBAM, self).__init__()

        self.channel_attention = ChannelAttention(channels, ratio)  # Channel attention module
        self.spatial_attention = SpatialAttention()  # Spatial attention module

    def forward(self, x):
        f1 = self.channel_attention(x) * x  # F' = Mc(F) ⊗ F, Eq. (1) in the paper
        f2 = self.spatial_attention(f1) * f1  # F'' = Ms(F') ⊗ F', Eq. (1) in the paper
        return f2
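A minimal usage sketch (not part of the original code): the input tensor below is a made-up example, and the check only confirms that CBAM preserves the input shape, since the module merely reweights the features. In the paper, CBAM is applied to the convolution output of each block of the backbone (e.g., inside every residual block of a ResNet).

if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)  # example input: batch of 2, 64 channels, 32x32 feature map
    cbam = CBAM(channels=64, ratio=16)
    y = cbam(x)
    print(y.shape)  # torch.Size([2, 64, 32, 32]) -- same shape as the input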


Original post: https://blog.csdn.net/Peach_____/article/details/128723630