Attention Mechanisms in Computer Vision

The attention mechanism discussed in this article refers to soft attention.

There is also an attention mechanism that weights each channel of the feature map; for that, see the SENet discussion in Tensorflow's Image Operation (4). Here we mainly discuss self-attention.
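As a quick illustration of what "weighting each channel" means, here is a minimal squeeze-and-excitation style sketch. It is not taken from the referenced article; the class name ChannelAttention and the reduction ratio of 16 are illustrative assumptions.

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    """Minimal SE-style channel attention: squeeze (global pooling) then excite (two FC layers)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):                         # x: B x C x H x W
        w = x.mean(dim=(2, 3))                    # squeeze: B x C
        w = self.fc(w).view(x.size(0), -1, 1, 1)  # excite: B x C x 1 x 1
        return x * w                              # re-weight each channel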

The figure above shows the basic structure of the self-attention mechanism. The leftmost feature map comes from the downsampled output of the convolutional layers and is usually 1/8 the size of the original input image. The feature map is then passed through three parallel 1×1 convolutions, f, g, and h: f and g reduce the number of channels to 1/8 of the original, while h keeps the number of channels unchanged.

The output of f is transposed and matrix-multiplied with the output of g to compute a similarity (energy) between every pair of spatial positions, and the result is normalized with a softmax to obtain the attention weights. These weights are then used to take a weighted sum over the output of h, giving the final attention output.
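Written out, following the formulation popularized by SAGAN (Zhang et al., Self-Attention Generative Adversarial Networks), with $x_i$ the feature vector at spatial position $i$ and $N = H \times W$ the number of positions:

$$
s_{ij} = f(x_i)^{\top} g(x_j), \qquad
\beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})}, \qquad
o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i), \qquad
y_j = \gamma\, o_j + x_j
$$

Here $\gamma$ is a learnable scalar initialized to 0, so the block starts out as an identity mapping and the attention branch is blended in during training.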

PyTorch implementation

import torch.nn as nn
import torch

class Self_Attn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim

        # f and g reduce the channel dimension to in_dim // 8, h keeps it unchanged
        self.f = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.g = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.h = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable scale, registered as a parameter and initialized to 0 so that
        # the block starts out as an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X H X W)
        returns :
            out : self attention value + input feature
            attention: B X N X N (N is Height*Width)
        """
        m_batchsize, C, height, width = x.size()
        f1 = self.f(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B X N X C//8
        g1 = self.g(x).view(m_batchsize, -1, height * width)  # B X C//8 X N
        energy = torch.bmm(f1, g1)  # B X N X N, similarity between every pair of positions
        attention = self.softmax(energy)  # B X N X N, normalized over the last dimension
        h1 = self.h(x).view(m_batchsize, -1, height * width)  # B X C X N

        out = torch.bmm(h1, attention.permute(0, 2, 1))  # weighted sum over positions: B X C X N
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x  # residual connection scaled by the learnable gamma

        return out, attention

if __name__ == '__main__':

    a = torch.rand(1, 512, 64, 64)
    self_atten = Self_Attn(512)
    out, atten = self_atten(a)
    print(out)
    print(atten)
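
    # Quick sanity check on the returned shapes (added for illustration, not in the original code):
    # out keeps the input shape, attention is N x N with N = 64 * 64 = 4096
    assert out.shape == (1, 512, 64, 64)
    assert atten.shape == (1, 64 * 64, 64 * 64)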

TensorFlow implementation

import tensorflow as tf
from tensorflow.keras import models, layers
import numpy as np
import Scale  # local module, defined below

class Self_Attn(models.Model):

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        # f and g reduce the channel dimension to in_dim // 8, h keeps it unchanged
        self.f = layers.Conv2D(in_dim // 8, (1, 1))
        self.g = layers.Conv2D(in_dim // 8, (1, 1))
        self.h = layers.Conv2D(in_dim, (1, 1))
        # build the scaling layer once here instead of creating a new one on every call
        self.scale = Scale.Scale()

    def call(self, x):
        # TensorFlow uses NHWC layout, so the channel axis comes last and the
        # spatial positions must be flattened before the channels
        batchsize, height, width, channel = x.shape
        n = height * width
        f1 = self.f(x)
        f1 = tf.reshape(f1, (batchsize, n, -1))       # B x N x C//8
        g1 = self.g(x)
        g1 = tf.reshape(g1, (batchsize, n, -1))       # B x N x C//8
        h1 = self.h(x)
        h1 = tf.reshape(h1, (batchsize, n, -1))       # B x N x C
        energy = tf.matmul(f1, g1, transpose_b=True)  # B x N x N
        atten = tf.nn.softmax(energy, axis=-1)
        out = tf.matmul(atten, h1)                    # weighted sum over positions: B x N x C
        out = tf.reshape(out, (batchsize, height, width, channel))
        out = self.scale(out)
        out = out + x
        return out, atten

if __name__ == '__main__':

    a = tf.constant(np.random.rand(1, 64, 64, 512), dtype=tf.float32)
    self_atten = Self_Attn(512)
    out, atten = self_atten(a)
    print(out)
    print(atten)

The Scale code is as follows:

from tensorflow.keras import layers

class Scale(layers.Layer):
    """Multiplies the input by a single learnable scalar gamma, initialized to 0."""

    def __init__(self, **kwargs):
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)
        super(Scale, self).build(input_shape)

    def call(self, x, mask=None):
        return self.gamma * x
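
Because gamma starts at 0, the Scale layer initially outputs zeros, so `out + x` reduces to the identity at the start of training. A quick check, assuming the class above is saved as Scale.py as the import in the TensorFlow implementation expects:

import tensorflow as tf
import Scale

x = tf.ones((1, 4, 4, 8))
scale = Scale.Scale()
print(scale(x))  # all zeros before training, since gamma is initialized to 0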