Paper Portal: CBAM: Convolutional Block Attention Module
Purpose of CBAM:
Add an attention mechanism to the network.
The structure of CBAM:
① Channel attention module: The input features are passed through global max pooling and global average pooling separately; both pooled results are fed through a weight-sharing MLP, the two MLP outputs are added, and the sum is passed through a sigmoid activation function to obtain the channel attention weight $M_c$;
② Spatial attention module: The input features are max-pooled and average-pooled along the channel dimension, and the two results are concatenated into a (2, H, W) feature map; a 7x7 convolution then produces a single-channel feature map, which is passed through a sigmoid activation function to obtain the spatial attention weight $M_s$;
③ Serial connection: the authors chain the two modules in series, with the channel attention module first and the spatial attention module second.
After experiments, the author found that the effect of building in series is better than building in parallel, and the effect of channel attention first is better than spatial attention first.
import torch
import torch.nn as nn
class ChannelAttention(nn.Module):
    """Channel attention module of CBAM.

    Computes Mc(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F))), i.e.
    sigmoid(W1(W0(F_avg)) + W1(W0(F_max))), where the same MLP (shared
    weights) is applied to both pooled descriptors.

    Args:
        channels: Number of channels of the input feature map.
        ratio: Reduction ratio of the MLP bottleneck. The hidden width is
            ``channels // ratio``, clamped to at least 1 so that inputs with
            fewer channels than ``ratio`` still get a usable bottleneck.
    """

    def __init__(self, channels: int, ratio: int = 16) -> None:
        super(ChannelAttention, self).__init__()
        # Without the clamp, channels < ratio would produce a zero-width
        # hidden layer and therefore a degenerate (empty) MLP.
        hidden_channels = max(channels // ratio, 1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.maxpool = nn.AdaptiveMaxPool2d(1)  # global max pooling
        # 1x1 convolutions stand in for fully connected layers; they have no
        # bias term, following the formula of the original paper.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden_channels, 1, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, channels, 1, 1, 0, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the channel attention weights for ``x``.

        Args:
            x: Input feature map, expected as (B, C, H, W).

        Returns:
            Per-channel weights produced by the sigmoid; both pooling layers
            reduce the spatial size to 1x1 before the shared MLP is applied.
        """
        x_avg = self.avgpool(x)
        x_max = self.maxpool(x)
        # The shared MLP is applied to each descriptor and the results are
        # summed before the sigmoid.
        return self.sigmoid(self.mlp(x_avg) + self.mlp(x_max))
class SpatialAttention(nn.Module):
    """Spatial attention module of CBAM.

    Computes Ms(F) = sigmoid(f_kxk([AvgPool(F); MaxPool(F)])), where both
    poolings run along the channel dimension and f_kxk is a k x k
    convolution (7x7 by default).

    Args:
        kernel_size: Side length of the square convolution kernel. It must
            be a positive odd integer so that a padding of
            ``kernel_size // 2`` keeps the spatial size of the input.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 7) -> None:
        super(SpatialAttention, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # An even kernel with padding kernel_size // 2 would change H x W,
            # so the attention map could no longer be multiplied with the
            # features element-wise.
            raise ValueError(
                "kernel_size must be a positive odd integer, got {}".format(kernel_size)
            )
        # 2 input channels (avg + max maps), 1 output channel, stride 1,
        # "same" padding, no bias.
        self.conv = nn.Conv2d(2, 1, kernel_size, 1, kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the spatial attention weights for ``x``.

        Args:
            x: Input feature map, expected as (B, C, H, W).

        Returns:
            Single-channel attention map; for a (B, C, H, W) input it is
            (B, 1, H, W).
        """
        # Average pooling over the channel dimension: (B, C, H, W) -> (B, 1, H, W)
        x_avg = torch.mean(x, dim=1, keepdim=True)
        # Max pooling over the channel dimension: (B, C, H, W) -> (B, 1, H, W)
        x_max = torch.max(x, dim=1, keepdim=True).values
        pooled = torch.cat([x_avg, x_max], dim=1)
        return self.sigmoid(self.conv(pooled))
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Applies channel attention first and spatial attention second, in series:
        F'  = Mc(F)  * F
        F'' = Ms(F') * F'

    Args:
        channels: Number of channels of the input feature map, forwarded to
            the channel attention module.
        ratio: Reduction ratio forwarded to the channel attention module.
    """

    def __init__(self, channels, ratio=16):
        super().__init__()
        # Registration order is kept: channel attention, then spatial attention.
        self.channel_attention = ChannelAttention(channels, ratio)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        """Return ``x`` re-weighted by channel and then spatial attention."""
        # F' = Mc(F) * F
        channel_refined = x * self.channel_attention(x)
        # F'' = Ms(F') * F'
        spatial_weights = self.spatial_attention(channel_refined)
        return channel_refined * spatial_weights