2021 《Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks》 Pytorch实现

import torch
from torch import nn
from torch.nn import init


# External Attention
# 外部注意力
# 方法出处 2021 arxiv 《Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks》
class ExternalAttention(nn.Module):
    # 网络层的初始化
    def __init__(self, d_model, S=64):
        # 所有继承于nn.Module的模型都要写这句话
        super(ExternalAttention, self).__init__()
        # 外部记忆单元1
        self.mk = nn.Linear(d_model, S, bias=False)
        # 外部记忆单元2
        self.mv = nn.Linear(S, d_model, bias=False)
        # softmax层
        self.softmax = nn.Softmax(dim=1)
        # 网络层权重初始化
        self.init_weights()

    def init_weights(self):
        # 遍历当前模型所有的层
        for m in self.modules():
            # 如果是卷积层
            if isinstance(m, nn.Conv2d):
                # kaiming初始化
                init.kaiming_normal_(m.weight, mode='fan_out')
                # 偏置初始化为0
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            # 如果是正则化层
            elif isinstance(m, nn.BatchNorm2d):
                # 权重为1,偏置为0
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            # 如果是线性层
            elif isinstance(m, nn.Linear):
                # 正态分布初始化
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
        attn = self.mk(queries)  # bs,n,S
        # 沿着第一个维度n进行softmax
        # 对于每一个切片矩阵n*s
        # 的每一列进行softmax
        # 相当于捕获不同样本的相似性
        attn = self.softmax(attn)  # bs,n,S
        # 这步相当于正则化
        # toch.sum(attn,dim=2,keepdim=True)
        # 是对于attn沿着第二维度相加
        # 输出结果维度是[bs,n,1]
        # 通过广播机制去除attn
        # 相当于对于attn的每一行进行softmax
        attn = attn / torch.sum(attn, dim=2, keepdim=True)  # bs,n,S
        out = self.mv(attn)  # bs,n,d_model

        return out


if __name__ == '__main__':
    input = torch.randn(50, 49, 512)
    ea = ExternalAttention(d_model=512, S=8)
    output = ea(input)
    print(output.shape)

猜你喜欢

转载自blog.csdn.net/Talantfuck/article/details/124557430
今日推荐