自然语言处理(十八):Transformer多头自注意力机制

自然语言处理笔记总目录


Transformer介绍

注意力机制

def attention(query, key, value, mask=None, dropout=None):
    # 词嵌入的维度
    d_k = query.size(-1)
    # 首先得到注意力得分张量scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 判断是否使用掩码
    if mask is not None:
        # 如果掩码张量某处为0,那么就用 -inf 代替scores某处对应数值
        scores = scores.masked_fill(mask==0, -1e9)
        
    p_attn = F.softmax(scores, dim=-1)
    
    # 判断是否使用dropout
    if dropout is not None:
        p_attn = dropout(p_attn)
        
    attn = torch.matmul(p_attn, value)
    return attn, p_attn
# 造一些假数据
q = k = v = out_pe	# 位置嵌入层的输出
mask = torch.zeros(8,4,4)
attn, p_attn = attention(q, k, v, mask = mask)

print(attn)
print(attn.shape)

print(p_attn)
print(p_attn.shape)
tensor([[[-3.4330, 17.0039,  7.6916,  ..., -1.1810, 13.4835,  6.5877],
         [-3.4330, 17.0039,  7.6916,  ..., -1.1810, 13.4835,  6.5877],
         [-3.4330, 17.0039,  7.6916,  ..., -1.1810, 13.4835,  6.5877],
         [-3.4330, 17.0039,  7.6916,  ..., -1.1810, 13.4835,  6.5877]],

        [[-5.7291,  9.6484, -3.8307,  ..., -6.4001,  2.8157, -2.3674],
         [-5.7291,  9.6484, -3.8307,  ..., -6.4001,  2.8157, -2.3674],
         [-5.7291,  9.6484, -3.8307,  ..., -6.4001,  2.8157, -2.3674],
         [-5.7291,  9.6484, -3.8307,  ..., -6.4001,  2.8157, -2.3674]]],
       grad_fn=<UnsafeViewBackward>)
torch.Size([2, 4, 512])

tensor([[[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]]], grad_fn=<SoftmaxBackward>)
torch.Size([2, 4, 4])

mask 演示

有两种mask,详情见Transformer介绍,编码器与解码器的代码是一样的,只是传入的mask矩阵是不一样的

input = torch.randn(2,5,5)
print(input)
tensor([[[ 1.6237, -0.4926,  0.2511, -0.0766,  1.2336],
         [ 0.6095,  0.3129, -1.3681,  1.4665,  2.7871],
         [ 0.5896, -0.3104, -1.5489,  0.7066, -0.5313],
         [-0.0101, -0.4480,  1.3695, -0.5241, -1.6751],
         [-1.0940, -1.4501,  0.1156,  0.7294, -0.7895]],

        [[-0.3445, -1.2161, -0.4054, -0.7804, -0.5310],
         [-1.7213, -0.9197, -1.7822, -0.0254,  1.1709],
         [-0.2137, -1.0617, -0.8737,  0.6546, -1.8320],
         [-1.9422,  0.4181, -0.5073,  0.2615,  0.0958],
         [ 0.9671, -0.9516, -0.0827, -0.1647, -0.7664]]])
mask = torch.tensor([[[1,1,1,1,1], 
                     [1,1,1,0,1], 
                     [1,1,0,0,1], 
                     [1,0,0,0,0], 
                     [0,0,0,0,0]]])
mask.shape	# torch.Size([1, 5, 5])
print(mask)

tensor([[[1, 1, 1, 1, 1],
         [1, 1, 1, 0, 1],
         [1, 1, 0, 0, 1],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]])
input.masked_fill(mask==0,-1e9)

tensor([[[ 1.6237e+00, -4.9256e-01,  2.5107e-01, -7.6616e-02,  1.2336e+00],
         [ 6.0951e-01,  3.1294e-01, -1.3681e+00, -1.0000e+09,  2.7871e+00],
         [ 5.8960e-01, -3.1039e-01, -1.0000e+09, -1.0000e+09, -5.3130e-01],
         [-1.0084e-02, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[-3.4447e-01, -1.2161e+00, -4.0536e-01, -7.8037e-01, -5.3098e-01],
         [-1.7213e+00, -9.1969e-01, -1.7822e+00, -1.0000e+09,  1.1709e+00],
         [-2.1375e-01, -1.0617e+00, -1.0000e+09, -1.0000e+09, -1.8320e+00],
         [-1.9422e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]]])

多头注意力机制

# 克隆函数:生成相同的网络层,N代表克隆多少份,存放在nn.ModuleList类型的列表中
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
# 多头注意力机制
class MultiHeadedAttention(nn.Module):
    def __init__(self, head, embedding_dim, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        
        # 首先判断head能否被词嵌入维度整除
        assert embedding_dim % head == 0
        
        # 头数
        self.head = head
        # 词嵌入维度
        self.embedding_dim = embedding_dim
        # 获得分割后的维度
        self.d_k = self.embedding_dim // self.head
        
        
        # 获得四个embedding_dim x embedding_dim线性层
        # 分别是Q、K、V,以及最后的拼接矩阵
        self.linears = clones(nn.Linear(self.embedding_dim, self.embedding_dim), 4)
        
        # 初始化最后得到的注意力张量
        self.attn = None
        
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # 扩展维度
            mask = mask.unsqueeze(0)
            
        batch_size = query.size(0)
        
        # 对QKV进行多头分割
        # 同时,将代表句子长度的维度与头数维度互换,使之与词向量维度相邻
        query, key, value = \
            [model(x).view(batch_size, -1, self.head, self.d_k).transpose(1, 2)   
             for model, x in zip(self.linears, (query, key, value))]
        
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        
        # 重塑为与输入的形状相同
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.embedding_dim)
        
        # 使用线性层列表的最后一个线性层处理输出,结构如上面的介绍图
        return self.linears[-1](x)
head = 8
embedding_dim = 512
dropout = 0.2

q = k = v = out_pe
mask = torch.zeros(8,4,4)
mha = MultiHeadedAttention(head, embedding_dim, dropout)
out_mha = mha(q, k, v, mask)

print(out_mha)
print(out_mha.shape)
tensor([[[ -2.8880,  -8.5742,   1.3168,  ...,  -9.3297, -10.9577,  -8.8687],
         [  1.1822,  -7.6474,  -1.2463,  ...,  -8.7212,  -8.4979,  -4.2104],
         [ -0.4473,  -8.8571,   2.7700,  ...,  -5.2058, -11.9773,  -8.1035],
         [  0.1940,  -7.0324,  -0.5626,  ...,  -8.7001,  -5.0893,  -7.9655]],

        [[  6.0457,  -4.1362,  -8.6744,  ...,  -4.0828,   2.2592,   5.6931],
         [  3.2074,  -2.5000,  -6.2406,  ...,  -5.8451,  10.2600,   2.8066],
         [  4.5255,  -1.2257,  -7.2707,  ...,  -8.0517,   5.8174,  -2.4052],
         [  3.9835,  -0.6454,  -5.9249,  ...,  -7.9849,   5.4700,   2.4230]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])

contiguous演示

使用transpose之后,必须使用contiguous才能进行view
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_45707277/article/details/122635202
今日推荐