The attention mechanism discussed here is soft attention.
There is also an attention mechanism that weights each feature-map channel; for that, see the SENet discussion in Tensorflow's Image Operation (4). Here we mainly discuss Self-Attention.
The figure above shows the basic structure of the Self-Attention mechanism. The leftmost feature maps come from the downsampled output of the convolutional layers, usually 1/8 of the size of the original input image. The feature map is then passed through three parallel 1*1 convolutions: the first, f, and the second, g, reduce the number of channels to 1/8 of the original, while the third, h, keeps the number of channels unchanged.
Here, the output of f is transposed and matrix-multiplied with the output of g to compute pairwise similarities, which serve as the attention weights, and are then normalized with a softmax. The normalized weights are used to take a weighted sum over the output of h, which gives the final attention output.
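As a minimal shape walkthrough of this computation (a sketch with made-up sizes: an assumed channel count C=512 and an 8*8 feature map, so N=64; the actual layers below learn f, g, h as 1*1 convolutions):

import torch

B, C, H, W = 1, 512, 8, 8
N = H * W
fx = torch.rand(B, C // 8, N)   # flattened output of f: B x C/8 x N
gx = torch.rand(B, C // 8, N)   # flattened output of g: B x C/8 x N
hx = torch.rand(B, C, N)        # flattened output of h: B x C x N

energy = torch.bmm(fx.permute(0, 2, 1), gx)       # f^T * g -> B x N x N similarities
attention = torch.softmax(energy, dim=-1)         # each row sums to 1
out = torch.bmm(hx, attention.permute(0, 2, 1))   # weighted sum over h -> B x C x N
out = out.view(B, C, H, W)                        # back to the feature-map shape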
PyTorch implementation
import torch
import torch.nn as nn


class Self_Attn(nn.Module):
    """Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        # f and g reduce the channel count to in_dim // 8, h keeps it unchanged
        self.f = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.g = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.h = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable scale, registered as a Parameter so the optimizer updates it;
        # initialized to zero so the block starts as an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x H x W)
        returns :
            out : self attention value + input feature
            attention : B x N x N (N is Height*Width)
        """
        m_batchsize, C, height, width = x.size()
        N = height * width
        f1 = self.f(x).view(m_batchsize, -1, N).permute(0, 2, 1)   # B x N x C//8
        g1 = self.g(x).view(m_batchsize, -1, N)                    # B x C//8 x N
        energy = torch.bmm(f1, g1)                                 # B x N x N
        attention = self.softmax(energy)                           # B x N x N
        h1 = self.h(x).view(m_batchsize, -1, N)                    # B x C x N
        out = torch.bmm(h1, attention.permute(0, 2, 1))            # weighted sum of h
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x                                 # residual connection
        return out, attention


if __name__ == '__main__':
    a = torch.rand(1, 512, 64, 64)
    self_atten = Self_Attn(512)
    out, atten = self_atten(a)
    print(out)
    print(atten)
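Because gamma is initialized to zero, the layer initially returns its input unchanged and only gradually learns to mix in the attention output during training. A quick sanity check (reusing the class defined above):

layer = Self_Attn(64)
x = torch.rand(2, 64, 16, 16)
out, atten = layer(x)
print(torch.allclose(out, x))   # True: gamma starts at zero, so the block is an identity mapping
print(atten.sum(dim=-1))        # every row of the attention matrix sums to 1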
TensorFlow implementation
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers

import Scale


class Self_Attn(models.Model):
    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        # f and g reduce the channel count to in_dim // 8, h keeps it unchanged
        self.f = layers.Conv2D(in_dim // 8, (1, 1))
        self.g = layers.Conv2D(in_dim // 8, (1, 1))
        self.h = layers.Conv2D(in_dim, (1, 1))
        # learnable scale gamma, created once in __init__ so its weight persists and is trained
        self.scale = Scale.Scale()

    def call(self, x):
        # TensorFlow uses channels-last layout: B x H x W x C
        batchsize, height, width, channel = x.shape
        N = height * width
        f1 = tf.reshape(self.f(x), (batchsize, N, -1))    # B x N x C//8
        g1 = tf.reshape(self.g(x), (batchsize, N, -1))    # B x N x C//8
        h1 = tf.reshape(self.h(x), (batchsize, N, -1))    # B x N x C
        energy = tf.matmul(f1, g1, transpose_b=True)      # B x N x N
        atten = tf.nn.softmax(energy, axis=-1)
        out = tf.matmul(atten, h1)                        # weighted sum of h: B x N x C
        out = tf.reshape(out, (batchsize, height, width, channel))
        out = self.scale(out) + x                         # gamma * out + x
        return out, atten


if __name__ == '__main__':
    a = tf.constant(np.random.rand(1, 64, 64, 512), dtype=tf.float32)
    self_atten = Self_Attn(512)
    out, atten = self_atten(a)
    print(out)
    print(atten)
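A quick check of the TensorFlow version, analogous to the PyTorch check above (reusing the class and imports defined above): since Scale initializes gamma to zero, the output should initially equal the input, and each attention row should sum to 1.

layer = Self_Attn(64)
x = tf.random.uniform((2, 16, 16, 64))
out, atten = layer(x)
print(np.allclose(out.numpy(), x.numpy()))   # True: gamma in Scale starts at zero
print(tf.reduce_sum(atten, axis=-1))         # all ones: softmax rows sum to 1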
The Scale code is as follows
from tensorflow.keras import layers


class Scale(layers.Layer):
    def __init__(self, **kwargs):
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        # Keras passes input_shape to build(); it must be accepted even if unused.
        # gamma is a single learnable scalar, initialized to zero.
        self.gamma = self.add_weight(name='gamma', shape=(1,),
                                     initializer='zeros', trainable=True)

    def call(self, x, mask=None):
        return self.gamma * x
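A minimal usage sketch of the Scale layer on its own (reusing the class above): gamma starts at zero, so the output is all zeros until training updates the weight.

import tensorflow as tf

scale = Scale()
x = tf.ones((2, 4, 4, 8))
print(float(tf.reduce_sum(scale(x))))   # 0.0, since gamma is initialized to zero
print(scale.trainable_weights)          # contains the single gamma weight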