Attention mechanism - Convolutional Block Attention Module (CBAM)

Convolutional Block Attention Module (CBAM): CBAM is an attention module that combines channel attention and spatial attention to improve the representational power of a model.

The CBAM module contains two attention sub-modules: a channel attention module and a spatial attention module. The channel attention module computes the importance of each channel, so that features from different channels can be better distinguished. The spatial attention module computes the importance of each spatial position, so that the spatial structure of the image can be better captured.

The channel attention module applies global max pooling and global average pooling over the spatial dimensions of the input feature map, feeds the two pooled descriptors through a shared two-layer MLP, sums the results, and applies a sigmoid to obtain a channel attention weight vector. This vector re-weights each channel of the input feature map so that features from different channels are better distinguished.
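
As a quick, minimal sketch of the shapes involved (using plain PyTorch calls, independent of the implementation given later), the pooling collapses the spatial dimensions so that the channel weights end up with one value per channel:

import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 32, 32)            # (N, C, H, W) input feature map
avg_desc = F.adaptive_avg_pool2d(x, 1)    # (2, 64, 1, 1) per-channel average
max_desc = F.adaptive_max_pool2d(x, 1)    # (2, 64, 1, 1) per-channel maximum
# after the shared MLP and a sigmoid, the channel weights also have shape (N, C, 1, 1)
# and are broadcast-multiplied with x to re-weight each channel
print(avg_desc.shape, max_desc.shape)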

The spatial attention module applies average pooling and max pooling along the channel dimension of the input feature map, concatenates the two pooled maps, passes them through a convolutional layer (typically 7x7), and applies a sigmoid to obtain a spatial attention weight map. This map re-weights each spatial position so that the spatial structure of the image is better captured.
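
Similarly, a minimal shape sketch for the spatial attention path (again purely illustrative; the 7x7 kernel matches the implementation below):

import torch
import torch.nn as nn

x = torch.randn(2, 64, 32, 32)                  # (N, C, H, W) input feature map
avg_map = torch.mean(x, dim=1, keepdim=True)    # (2, 1, 32, 32) average over channels
max_map, _ = torch.max(x, dim=1, keepdim=True)  # (2, 1, 32, 32) maximum over channels
stacked = torch.cat([avg_map, max_map], dim=1)  # (2, 2, 32, 32)
conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
weights = torch.sigmoid(conv(stacked))          # (2, 1, 32, 32) per-position weights
print(weights.shape)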

The overall structure of CBAM connects the two sub-modules in series: channel attention is applied to the input feature map first, and spatial attention is then applied to the channel-refined result. The complete CBAM module obtained in this way can be inserted into a convolutional neural network to improve model performance.

Implementing CBAM in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        # global average pooling and global max pooling over the spatial dimensions
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # shared two-layer MLP, implemented with 1x1 convolutions
        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # both pooled descriptors pass through the same shared MLP
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        # element-wise sum, then sigmoid gives per-channel weights in (0, 1)
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        # 2 input channels (average map and max map) -> 1 output attention map
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # average pooling and max pooling along the channel dimension
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        # concatenate to a 2-channel map, then convolve down to a single-channel weight map
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_att = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        # apply channel attention first, then spatial attention to the refined result
        out = self.channel_att(x) * x
        out = self.spatial_att(out) * out
        return out

The above code implements the two sub-modules of CBAM, the channel attention module (ChannelAttention) and the spatial attention module (SpatialAttention), as well as the complete CBAM module (CBAM).

The channel attention module applies max pooling and average pooling over the spatial dimensions of the input feature map, feeds both pooled descriptors through a shared MLP, and outputs a channel attention weight vector. The spatial attention module applies average pooling and max pooling along the channel dimension, passes the concatenated result through a convolutional layer, and outputs a spatial attention weight map. The CBAM module connects these two sub-modules in series so that it can be inserted into a convolutional neural network to improve model performance.
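
A quick sanity check (a minimal usage sketch reusing the imports and the CBAM class defined above; the channel count 64 and the feature-map size are arbitrary) verifies that CBAM preserves the shape of its input, which is what allows it to be dropped between existing layers:

cbam = CBAM(in_channels=64)
features = torch.randn(8, 64, 32, 32)   # batch of 8 feature maps with 64 channels
refined = cbam(features)
print(refined.shape)                    # torch.Size([8, 64, 32, 32]), same as the input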

Use the CBAM module in your model:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.cbam1 = CBAM(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.cbam2 = CBAM(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.cbam3 = CBAM(256)
        
        # assumes 64x64 inputs: three 2x2 max-pools reduce 64 -> 8, so the flattened size is 256 * 8 * 8
        self.fc = nn.Linear(256 * 8 * 8, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.cbam1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = self.cbam2(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv3(x))
        x = self.cbam3(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x

This is a simple convolutional neural network with three convolutional layers, each followed by a CBAM module, and a final fully connected layer that maps the flattened feature map to class scores. In the forward pass, each convolutional layer and CBAM module extracts and refines features, each max pooling layer halves the spatial size of the feature map, and the fully connected layer produces the predicted class scores.
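
Because the fully connected layer expects a 256 x 8 x 8 feature map, the model as written assumes 64 x 64 inputs (three 2x2 max-pools reduce 64 to 8). A minimal usage sketch, reusing the classes and imports defined above:

model = MyModel()
images = torch.randn(4, 3, 64, 64)   # batch of 4 RGB images, 64x64 each
logits = model(images)
print(logits.shape)                  # torch.Size([4, 10]), one score per class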
