CoTAmecanismo de atenção
A rede CoTAttention é um modelo de rede neural para tarefas de Visual Question Answering (VQA) em cenários multimodais. É uma melhoria no Mecanismo de Atenção clássico, que pode alocar atenção de forma adaptativa para diferentes entradas visuais e de linguagem, de modo a concluir melhor a tarefa VQA.
O "CoT" na rede CoTAttention significa "Cross-modal Transformer", que é um transformador cross-modal. Nesta rede, as entradas visuais e linguísticas são codificadas separadamente em um conjunto de vetores de recursos, que são então interagidos e integrados por meio de um módulo Transformer cross-modal. Neste módulo Transformer cross-modal, o mecanismo de Co-Atenção é usado para calcular a atenção interativa entre os recursos visuais e de linguagem, de modo a obter uma melhor troca e integração de informações. Na tarefa VQA, onde a visão computacional e o processamento de linguagem natural estão intimamente combinados, a rede CoTAtention alcançou bons resultados.
Endereço do artigo: https://arxiv.org/pdf/2107.12292.pdf
código mostra como abaixo:
import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F
class CoTAttention(nn.Module):
def __init__(self, dim=512, kernel_size=3):
super().__init__()
self.dim = dim
self.kernel_size = kernel_size
self.key_embed = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
nn.BatchNorm2d(dim),
nn.ReLU()
)
self.value_embed = nn.Sequential(
nn.Conv2d(dim, dim, 1, bias=False),
nn.BatchNorm2d(dim)
)
factor = 4
self.attention_embed = nn.Sequential(
nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
nn.BatchNorm2d(2 * dim // factor),
nn.ReLU(),
nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
)
def forward(self, x):
bs, c, h, w = x.shape
k1 = self.key_embed(x) # bs,c,h,w
v = self.value_embed(x).view(bs, c, -1) # bs,c,h,w
y = torch.cat([k1, x], dim=1) # bs,2c,h,w
att = self.attention_embed(y) # bs,c*k*k,h,w
att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
att = att.mean(2, keepdim=False).view(bs, c, -1) # bs,c,h*w
k2 = F.softmax(att, dim=-1) * v
k2 = k2.view(bs, c, h, w)
return k1 + k2
if __name__ == '__main__':
input = torch.randn(50, 512, 7, 7)
cot = CoTAttention(dim=512, kernel_size=3)
output = cot(input)
print(output.shape)