import torch
from torch import nn
from torch.nn import init
# External Attention
# Method source: 2021 arXiv, "Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks"
class ExternalAttention(nn.Module):
    # Initialize the network layers
    def __init__(self, d_model, S=64):
        # Every model that inherits from nn.Module must make this call
        super(ExternalAttention, self).__init__()
        # External memory unit 1 (M_k)
        self.mk = nn.Linear(d_model, S, bias=False)
        # External memory unit 2 (M_v)
        self.mv = nn.Linear(S, d_model, bias=False)
        # Softmax layer
        self.softmax = nn.Softmax(dim=1)
        # Initialize the layer weights
        self.init_weights()

    def init_weights(self):
        # Iterate over all layers of the model
        for m in self.modules():
            # If it is a convolutional layer
            if isinstance(m, nn.Conv2d):
                # Kaiming initialization
                init.kaiming_normal_(m.weight, mode='fan_out')
                # Initialize the bias to 0
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            # If it is a batch-normalization layer
            elif isinstance(m, nn.BatchNorm2d):
                # Weight 1, bias 0
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            # If it is a linear layer
            elif isinstance(m, nn.Linear):
                # Normal-distribution initialization
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
        attn = self.mk(queries)  # bs, n, S
        # Softmax along dimension 1 (the n dimension):
        # for each n x S slice of the batch, softmax is applied to each column,
        # which captures the similarity between the different positions
        attn = self.softmax(attn)  # bs, n, S
        # This step is a normalization:
        # torch.sum(attn, dim=2, keepdim=True) sums attn along dimension 2,
        # producing a tensor of shape [bs, n, 1];
        # dividing attn by it via broadcasting
        # L1-normalizes each row of attn so that it sums to 1
        attn = attn / torch.sum(attn, dim=2, keepdim=True)  # bs, n, S
        out = self.mv(attn)  # bs, n, d_model
        return out
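
# For reference only (not part of the original post): the forward pass above can
# be rewritten without nn.Linear. This is a minimal sketch, with a function name
# and argument layout of my own choosing; it should match ExternalAttention.forward
# when given the layer weights, since both linear layers use bias=False.
def external_attention_ref(x, mk_weight, mv_weight):
    # x: (bs, n, d_model); mk_weight: (S, d_model); mv_weight: (d_model, S)
    attn = x @ mk_weight.t()                     # (bs, n, S): similarity to the S external memory slots
    attn = attn.softmax(dim=1)                   # softmax over the n positions
    attn = attn / attn.sum(dim=2, keepdim=True)  # L1-normalize each row over the S slots
    return attn @ mv_weight.t()                  # (bs, n, d_model)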


if __name__ == '__main__':
    input = torch.randn(50, 49, 512)
    ea = ExternalAttention(d_model=512, S=8)
    output = ea(input)
    print(output.shape)
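As a further usage sketch (the shapes and variable names here are illustrative assumptions, not from the original post), the module can be applied to a CNN feature map by flattening the spatial dimensions into the token dimension n:

feat = torch.randn(4, 512, 7, 7)                 # (B, C, H, W) feature map
tokens = feat.flatten(2).transpose(1, 2)         # (B, H*W, C) = (4, 49, 512)
ea = ExternalAttention(d_model=512, S=64)
out = ea(tokens)                                 # (4, 49, 512)
out = out.transpose(1, 2).reshape(4, 512, 7, 7)  # restore (B, C, H, W)
print(out.shape)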
PyTorch implementation of "Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks" (2021)
Reposted from blog.csdn.net/Talantfuck/article/details/124557430