pytorch コードはアテンションメカニズムの SimAM を実装します

SimAM アテンション メカニズム

SimAM (類似性ベースのアテンション メカニズム) は、シーケンス データの類似性を計算するためのアテンション メカニズムです。
論文アドレス: http://proceedings.mlr.press/v139/yang21o/yang21o.pdf
SimAM では、クエリ シーケンス Q とキーと値のペア シーケンス K が与えられると、アテンション メカニズムはクエリ シーケンスとキーの間の関係を計算します。シーケンスの類似性を利用して注目の重みを決定します。具体的には、SimAM はコサイン類似度を使用して、クエリ シーケンスとキー シーケンス間の類似性を測定します。
SimAM の計算プロセスは次のとおりです。
(1) クエリ シーケンス Q の各要素 q について、キー シーケンス K の各要素 k とのコサイン類似度を計算します。
(2) 各クエリ配列要素 q について、キー配列要素 k との類似度を正規化して注目重みを取得します。
(3) 値シーケンス V は、アテンション重みを使用して重み付けされ、合計され、最終的なアテンション表現が得られます。
構造図

コードは以下のように表示されます:

import torch
import torch.nn as nn


class SimAM(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()

        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)


if __name__ == '__main__':
    input = torch.randn(3, 64, 7, 7)
    model = SimAM()
    outputs = model(input)
    print(outputs.shape)

おすすめ

転載: blog.csdn.net/DM_zx/article/details/132302606