002 Self-Attention

Table of contents

1. Environment

2. Self-attention principle

3. Complete code


1. Environment

This article uses the following environment:

  • Windows10
  • Python 3.9.17
  • torch 1.13.1+cu117
  • torchvision 0.14.1+cu117

2. Self-attention principle

Self-attention is the basic operation of Transformer-based machine translation models. It is used throughout both source-language encoding and target-language generation, where it models the dependency between any two words of the source or target sentence. Given an input sequence, each word is first represented by a vector $x_i \in \mathbb{R}^d$, obtained by adding its word embedding and its position encoding. To model contextual semantic dependencies on top of these representations, the self-attention mechanism introduces three elements for each word: a query $q_i$ (Query), a key $k_i$ (Key), and a value $v_i$ (Value). When the representation of each word in the input sequence is encoded, these three elements are used to compute a weight score for every context word. Intuitively, these weights reflect how much attention each part of the context deserves when encoding the current word. Specifically, as shown in the figure, each word representation $x_i$ in the input sequence is mapped to its $q_i$, $k_i$, and $v_i$ vectors through three linear transformations $W^Q$, $W^K$, and $W^V$.
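The input construction and the three projections described above can be sketched in PyTorch as follows. This is a minimal sketch under stated assumptions, not the article's own code:

- The text only says "position encoding", so the sinusoidal scheme used here is an assumption.
- The helper `sinusoidal_positions` is hypothetical.
- The projections are bias-free by assumption.
- The sizes (`vocab_size=1000`, `L=6`, `d=64`) are made up for illustration.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_positions(L, d):
    # Hypothetical helper: sinusoidal position encoding (an assumed scheme; the text does not specify one)
    pos = torch.arange(L, dtype=torch.float32).unsqueeze(1)  # (L, 1)
    freq = torch.exp(torch.arange(0, d, 2, dtype=torch.float32) * (-math.log(10000.0) / d))  # (d/2,)
    pe = torch.zeros(L, d)
    pe[:, 0::2] = torch.sin(pos * freq)  # even dimensions
    pe[:, 1::2] = torch.cos(pos * freq)  # odd dimensions
    return pe

torch.manual_seed(0)
vocab_size, L, d = 1000, 6, 64                     # hypothetical sizes
token_ids = torch.randint(0, vocab_size, (L,))
embed = nn.Embedding(vocab_size, d)
x = embed(token_ids) + sinusoidal_positions(L, d)  # x_i = word embedding + position encoding, shape (L, d)

# Three linear transformations W^Q, W^K, W^V (bias-free here by assumption)
W_Q = nn.Linear(d, d, bias=False)
W_K = nn.Linear(d, d, bias=False)
W_V = nn.Linear(d, d, bias=False)
q, k, v = W_Q(x), W_K(x), W_V(x)                   # row i of q, k, v holds q_i, k_i, v_i
print(q.shape, k.shape, v.shape)                   # torch.Size([6, 64]) for each
```

Each row of `x` is processed independently by the three projections, so all words get their query, key and value in one matrix multiplication per projection.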

To obtain the contextual information that matters when encoding word $x_i$, the query vector at position $i$ is dot-multiplied with the key vectors at all positions (including $i$ itself), giving the matching scores $q_i \cdot k_1, q_i \cdot k_2, \ldots, q_i \cdot k_L$. These dot products grow with the dimension $d$. Large scores would push the subsequent Softmax into saturated regions where the gradients are extremely small, which hurts convergence. The scores are therefore divided by the scaling factor $\sqrt{d}$ to stabilize optimization. Softmax then normalizes the scaled scores into probabilities, which weight the value vectors at all positions. This aggregates the context that should be attended to while suppressing irrelevant information. The above calculation process can be formally expressed as follows:

$$Z = \mathrm{Attention}(Q, K, V) = \mathrm{Softmax}\left(\frac{QK^{\top}}{\sqrt{d}}\right)V$$

Here $Q, K, V \in \mathbb{R}^{L \times d}$ are the matrices formed by stacking the q, k, and v vectors of all words in the input sequence, $L$ is the sequence length, and $Z$ is the output of the self-attention operation. To further strengthen the ability of self-attention to aggregate contextual information, the multi-head attention mechanism was proposed so that the model can attend to different aspects of the context. Specifically, the representation $x_i$ of each word is mapped into different representation subspaces by multiple sets of linear transformations $\{W_j^Q, W_j^K, W_j^V\}$. The attention formula is computed separately in each subspace, yielding different context-aware word sequence representations $\{Z_j\}$. Finally, a linear transformation $W^O$ combines the representations from all subspaces into the final output of the self-attention layer (the new representation of $x_i$).
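The attention formula and the multi-head split described above can be sketched as follows. This is a minimal illustration under assumptions:

- Random tensors stand in for real projected queries, keys and values.
- The sizes (L = 5, d = 64, d_model = 512, h = 8) are arbitrary.
- The helper `attention` is hypothetical.

What each step shows:

- **Step 2** illustrates empirically, not as a proof, why the scores are divided by √d.
- **Step 3** checks that one d_model × d_model projection reshaped into h heads, as the code in section 3 does, equals h separate per-head projections of size d_model × d_k.
- **Step 4** reuses one tensor as Q, K and V only to keep the sketch short.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def attention(Q, K, V):
    # Z = Softmax(Q K^T / sqrt(d)) V, where d is the last dimension of Q and K
    d = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)  # matching scores q_i . k_j / sqrt(d)
    weights = F.softmax(scores, dim=-1)              # each row is a probability distribution
    return weights @ V, weights

torch.manual_seed(0)

# 1) Single head: Q, K, V of shape (L, d)
L, d = 5, 64
Q, K, V = torch.randn(L, d), torch.randn(L, d), torch.randn(L, d)
Z, W = attention(Q, K, V)  # Z: (L, d), W: (L, L)

# 2) Why divide by sqrt(d): for q, k with i.i.d. N(0, 1) entries, q . k has variance d
q, k = torch.randn(10000, d), torch.randn(10000, d)
raw = (q * k).sum(dim=-1)
print(raw.var().item(), (raw / math.sqrt(d)).var().item())  # close to d and close to 1

# 3) Multi-head: one d_model x d_model projection reshaped into h heads equals
#    h separate per-head projections of size d_model x d_k
d_model, h = 512, 8
d_k = d_model // h
x = torch.randn(2, L, d_model)                            # (batch, L, d_model)
proj = nn.Linear(d_model, d_model)
fused = proj(x).view(2, L, h, d_k).transpose(1, 2)        # (batch, h, L, d_k)
for j in range(h):
    rows = slice(j * d_k, (j + 1) * d_k)                  # output features that belong to head j
    per_head = x @ proj.weight[rows].T + proj.bias[rows]  # (batch, L, d_k)
    assert torch.allclose(fused[:, j], per_head, atol=1e-5)

# 4) Attention runs per head (Q = K = V = fused only to keep the sketch short);
#    concatenating the heads is the same as torch.cat over the head outputs Z_j
Zs, _ = attention(fused, fused, fused)                    # (batch, h, L, d_k)
concat = Zs.transpose(1, 2).contiguous().view(2, L, d_model)
assert torch.equal(concat, torch.cat([Zs[:, j] for j in range(h)], dim=-1))
out = nn.Linear(d_model, d_model)(concat)                 # W^O mixes the h subspaces
```

Fusing the h per-head projections into one large matrix is a common design choice. It gives one big matrix multiplication instead of h small ones, while computing the same thing.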

3. Complete code

import torch.nn as nn
import torch
import math
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.d_model = d_model
        self.d_k = d_model // heads  # per-head dimension, e.g. 512 // 8 = 64
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)  # W^Q for all heads at once
        self.v_linear = nn.Linear(d_model, d_model)  # W^V for all heads at once
        self.k_linear = nn.Linear(d_model, d_model)  # W^K for all heads at once
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)       # W^O, combines the heads

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        # Scaled dot-product scores Q K^T / sqrt(d_k), shape (bs, h, L_q, L_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Mask out the positions added as padding (mask == 0) so they get zero weight after softmax.
        # unsqueeze(1) adds a head axis: a (bs, 1, L_k) or (bs, L_q, L_k) mask broadcasts over heads.
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)  # normalize scores into attention weights
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)    # weighted sum of the value vectors
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # Apply the linear projections and split into h heads: (bs, L, h, d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # Transpose to (bs, h, L, d_k) so that each head attends independently
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Compute attention in every head
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # Concatenate the heads back to (bs, L, d_model) and feed them to the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output

# Prepare the q, k, v tensors
d_model = 512
num_heads = 8
batch_size = 32
seq_len = 64

q = torch.randn(batch_size, seq_len, d_model)  # 32 x 64 x 512
k = torch.randn(batch_size, seq_len, d_model)  # 32 x 64 x 512
v = torch.randn(batch_size, seq_len, d_model)  # 32 x 64 x 512

sa = MultiHeadAttention(heads=num_heads, d_model=d_model)
print(sa(q, k, v).shape)  # torch.Size([32, 64, 512])
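The example above calls the module without a mask. The `attention` method calls `mask.unsqueeze(1)` before broadcasting against scores of shape (bs, h, L_q, L_k), which constrains the mask shapes:

- A padding mask should be passed as (bs, 1, L_k).
- A causal mask should be passed as (1, L_q, L_k) or (bs, L_q, L_k).

The sketch below checks both cases. To stay self-contained, it restates that method as a free function with dropout omitted. The sequence lengths 4 and 6 are made-up values.

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    # Same computation as MultiHeadAttention.attention above (dropout omitted)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (bs, h, L_q, L_k)
    if mask is not None:
        mask = mask.unsqueeze(1)  # add a head axis so the mask broadcasts over all heads
        scores = scores.masked_fill(mask == 0, -1e9)
    return torch.matmul(F.softmax(scores, dim=-1), v)

torch.manual_seed(0)
bs, h, L, d_k = 2, 8, 6, 64
q, k, v = (torch.randn(bs, h, L, d_k) for _ in range(3))

# Padding mask of shape (bs, 1, L): hypothetical true lengths 4 and 6, True = real token
lengths = torch.tensor([4, 6])
pad_mask = (torch.arange(L)[None, :] < lengths[:, None]).unsqueeze(1)

# Causal mask of shape (1, L, L): position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(L, L)).unsqueeze(0)

out = attention(q, k, v, pad_mask)
v_changed = v.clone()
v_changed[0, :, 4:] = 100.0  # overwrite the padded positions of sample 0
out_changed = attention(q, k, v_changed, pad_mask)
print(torch.allclose(out[0], out_changed[0]))  # True: padded values get zero weight

out_causal = attention(q, k, v, causal_mask)
print(torch.allclose(out_causal[:, :, 0], v[:, :, 0]))  # True: position 0 only sees itself
```

Passing a padding mask of shape (bs, L) directly would not broadcast correctly after `unsqueeze(1)`, which is why the extra singleton axis is added before the call.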

Reposted from: blog.csdn.net/m0_72734364/article/details/134874321