Multi-Head Attention
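
Each head computes scaled dot-product attention, Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V; multi-head attention runs several such heads in parallel on different learned projections of the input and concatenates their outputs. As a warm-up, here is a minimal single-head sketch on toy tensors (the shapes are chosen purely for illustration):

import torch

B, N, d_k = 2, 4, 8                               # toy batch size, token count, per-head dimension
q = torch.randn(B, N, d_k)
k = torch.randn(B, N, d_k)
v = torch.randn(B, N, d_k)

scores = (q @ k.transpose(-2, -1)) * d_k ** -0.5  # (B, N, N) scaled token-to-token similarities
weights = scores.softmax(dim=-1)                  # each row sums to 1
out = weights @ v                                 # (B, N, d_k) weighted sum of value vectors
print(out.shape)                                  # torch.Size([2, 4, 8])

The module below does the same thing with num_heads heads computed in parallel, plus learned q/k/v and output projections.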

import torch
import torch.nn as nn

class Multi_Head_Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0.5, proj_drop=0.5):
        super(Multi_Head_Attention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.qkv = nn.Linear(dim, dim * 3)  # a single projection producing q, k and v concatenated

        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads  # d_k, the dimension of each head
        self.scale = self.head_dim ** -0.5  # scaling factor 1/sqrt(d_k); keeps dot products from saturating the softmax
        self.softmax = nn.Softmax(dim=-1)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)  # output projection applied after the heads are recombined
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):  # x: (B, N, C), e.g. (64, 10, 512); N tokens, each a vector of length C
        B, N, C = x.shape  # C = dim, the embedding size of each token
        # Project to q, k, v in one matmul, split each of them into num_heads heads,
        # and permute so the qkv axis (size 3) is in front and can be unpacked below.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # q, k, v each have shape (B, num_heads, N, C // num_heads)
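        # Shape trace with the example at the bottom (B=64, N=10, C=512, num_heads=8, head_dim=64):
        #   self.qkv(x)               -> (64, 10, 1536)
        #   .reshape(B, N, 3, 8, 64)  -> (64, 10, 3, 8, 64)
        #   .permute(2, 0, 3, 1, 4)   -> (3, 64, 8, 10, 64)
        #   q, k, v                   -> each (64, 8, 10, 64)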
        # @ (torch.matmul) performs matrix multiplication over the last two dimensions
        # (leading dimensions are treated as batch dimensions); transpose(-2, -1) swaps k's last two dims.
        attn = self.softmax((q @ k.transpose(-2, -1)) * self.scale)  # (B, num_heads, N, N) attention weights
        attn = self.attn_drop(attn)
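        # attn[b, h, i, j] is the weight with which token i attends to token j in head h;
        # attn_drop randomly zeroes some of these weights during training.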
        # Recombine the heads: (B, num_heads, N, head_dim) -> (B, N, num_heads, head_dim) -> (B, N, C)
        res = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # Output projection mixes information across heads, followed by dropout.
        res = self.proj_drop(self.proj(res))

        return res
x = torch.randn(size=(64, 10, 512))
att = Multi_Head_Attention(dim=512)
print(att(x).shape)
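
For reference, PyTorch ships a built-in nn.MultiheadAttention with the same interface shapes. A minimal usage sketch (its weights, biases and internal layout differ from the hand-rolled module above, so the outputs will not match numerically; batch_first requires a reasonably recent PyTorch):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(64, 10, 512)
out, attn_weights = mha(x, x, x)  # self-attention: query, key and value are all x
print(out.shape)                  # torch.Size([64, 10, 512])
print(attn_weights.shape)         # torch.Size([64, 10, 10]), averaged over heads by default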

Reposted from blog.csdn.net/weixin_54338498/article/details/133689509