pytorch コードは、アテンション メカニズムの CoTAttendant を実装します。

CoTAttention アテンション メカニズム

CoTtention ネットワークは、マルチモーダル シナリオにおけるビジュアル質問応答 (VQA) タスク用のニューラル ネットワーク モデルです。これは、VQA タスクをより適切に完了するために、さまざまな視覚および言語入力に適応的に注意を割り当てることができる、古典的な注意メカニズム (注意メカニズム) の改良版です。

CoTtentionネットワークの「CoT」は「Cross-modal Transformer」の略で、クロスモーダルトランスフォーマーです。このネットワークでは、視覚入力と言語入力が一連の特徴ベクトルに個別にエンコードされ、クロスモーダル Transformer モジュールを通じて相互作用して統合されます。このクロスモーダル Transformer モジュールでは、Co-Attend メカニズムを使用して視覚機能と言語機能の間のインタラクティブな注意を計算し、より良い情報交換と統合を実現します。コンピューター ビジョンと自然言語処理が密接に組み合わされる VQA タスクでは、CoTAtention ネットワークが良好な結果を達成しました。

論文アドレス: https://arxiv.org/pdf/2107.12292.pdf

構造図

コードは以下のように表示されます:

import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F

class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    cot = CoTAttention(dim=512, kernel_size=3)
    output = cot(input)
    print(output.shape)

おすすめ

転載: blog.csdn.net/DM_zx/article/details/132321188