Detailed Explanation of Self-Attention and Multi-Head Attention Mechanism

  The self-attention mechanism is one kind of attention mechanism. Like the traditional attention mechanism, self-attention lets the model focus on the key information in the input. Self-attention can be regarded as the special case of multi-head attention in which all of the inputs are the same sequence, so understanding the essence of self-attention really means understanding the multi-head attention structure.

One: Basic principles  

  A multi-head attention module accepts three sequences: query, key, and value. The key and value sequences must have the same length, while the query sequence may have a different length. The output sequence of multi-head attention has the same length as the query sequence. In this article, Tutu denotes the length of query as Lq and the length of key and value as Lk.

  Secondly, the feature dimensions of the input sequences query, key, and value (the dim of each element) can also differ; denote them Dq, Dk, and Dv respectively. Once these sequences enter multi-head attention, the internal sequences may use a dimension different from Dq, Dk, and Dv. We call it the embedding dimension and denote it De; the output sequence also has dim De.
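To make the notation concrete, here is a small sketch with arbitrarily chosen sizes (these particular numbers are not from the text; they only illustrate the shapes):

import torch

Lq, Lk = 5, 7            # query length; key/value length
Dq, Dk, Dv = 16, 12, 10  # feature dims of the query, key, and value elements
De = 8                   # embedding dimension chosen inside multi-head attention

q = torch.randn(Lq, Dq)  # query sequence as a (Lq, Dq) matrix
k = torch.randn(Lk, Dk)  # key sequence as a (Lk, Dk) matrix
v = torch.randn(Lk, Dv)  # value sequence as a (Lk, Dv) matrix
# for these inputs, multi-head attention would produce an output of shape (Lq, De)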

  Multi-head attention is composed of one or more parallel unit structures, each of which is called a head (a single head can, in fact, also be viewed as a layer). For convenience, Tutu temporarily names this unit structure one-head attention; in a broad sense, multi-head attention with a head number of 1 is exactly one-head attention. The one-head attention structure is a combination of scaled dot-product attention and three weight matrices (equivalently, three parallel fully connected layers). The structure is shown in the figure below.

Two: The specific structure of Scaled Dot-Product Attention

  In the figure above, we regard the input sequences q, k, and v as matrices of shape (Lq, Dq), (Lk, Dk), and (Lk, Dv), that is, matrices obtained by stacking the element vectors row by row. The parameters of the Linear layers are (Dq, De), (Dk, De), and (Dv, De), so after the fully connected layers the output matrices have shapes (Lq, De), (Lk, De), and (Lk, De). We call the matrices obtained from the fully connected layers Q, K, and V.

  The essence of a Linear layer is to multiply the input matrix by a weight matrix W (sometimes a bias is also added). In one-head attention, the weight matrices that produce Q, K, and V are W^Q, W^K, and W^V, with shapes (Dq, De), (Dk, De), and (Dv, De). Whether a bias is used has no effect on the rest of the structure; some deep learning frameworks add a bias by default, but the original formula in "Attention Is All You Need" contains only W and no bias, so Tutu will ignore the bias in what follows.
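Continuing with the same (arbitrary) sizes, a minimal sketch of the three projections without bias might look as follows; the names W_q, W_k, W_v are just illustrative stand-ins for W^Q, W^K, W^V:

import torch

Lq, Lk = 5, 7
Dq, Dk, Dv, De = 16, 12, 10, 8

q = torch.randn(Lq, Dq)
k = torch.randn(Lk, Dk)
v = torch.randn(Lk, Dv)

W_q = torch.randn(Dq, De)  # W^Q
W_k = torch.randn(Dk, De)  # W^K
W_v = torch.randn(Dv, De)  # W^V

Q = q @ W_q  # (Lq, De)
K = k @ W_k  # (Lk, De)
V = v @ W_v  # (Lk, De)
print(Q.shape, K.shape, V.shape)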

  After the inputs pass through these Linear layers to obtain the Q, K, and V matrices, we arrive at the scaled dot-product attention part itself.

  Scaled dot-product attention can be expressed by a concise formula, where d_k is the dimension of the key (and query) vectors after the Linear projection, that is, De in our notation:

Attention(Q,K,V)=\mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V

  The output of this formula is the output of one-head attention: a matrix of shape (Lq, De), representing an output sequence of length Lq and dimension De. In the formula, the term

\frac{QK^{T}}{\sqrt{d_k}}

has a name: the attention weights (strictly speaking, the attention scores before the softmax). It has shape (Lq, Lk) and can be roughly understood as the correlation between each element of the q sequence and each element of the k sequence. The picture is similar to searching with a keyword on a web page: which index keys are selected depends on the correlation between the query and the keys, and the corresponding values are then returned according to the selected keys.
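As a rough sketch of this formula (not PyTorch's built-in implementation; the function name and the optional mask argument are my own), the whole computation can be written as:

import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    '''Q: (Lq, De), K and V: (Lk, De); returns a (Lq, De) matrix.'''
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)     # (Lq, Lk) scaled scores
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))  # optional mask, see part Three
    weights = torch.softmax(scores, dim=-1)               # attention weights, each row sums to 1
    return weights @ V

Q, K, V = torch.randn(5, 8), torch.randn(7, 8), torch.randn(7, 8)
print(scaled_dot_product_attention(Q, K, V).shape)  # torch.Size([5, 8])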

  With that, the unit structure of multi-head attention has actually been introduced, but the process is worth understanding more deeply. The following figure shows the detailed structure of scaled dot-product attention when Lq and Lk are equal (they usually are, and in many cases q, k, and v even come from the same sequence, which is exactly self-attention; Tutu will come back to this later).

  The figure above shows a scaled dot-product attention structure that receives Q, K, and V of shape (3, De). We unpack Q, K, and V into sequences of length 3 and dimension De. For each q, the inner product with every k is computed to obtain a number a; these numbers pass through a softmax (taken over all of them together) to obtain new numbers a'. Each a' is multiplied by its corresponding v vector, and the resulting vectors are summed to obtain one output vector of length De. Repeating this for each q in turn gives the vectors b1, b2, and so on, and stacking these b vectors into a matrix gives the final output. If the sequences q, k, and v are represented as the matrices Q, K, and V, this process is exactly the formula Tutu gave above; the formula simply performs the computation in parallel in matrix form, which makes the whole calculation concise and fast.
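The agreement between this element-by-element picture and the matrix formula can be checked directly. The sketch below (with Lq = Lk = 3 and random made-up values) computes the b vectors both ways and compares them:

import math
import torch

De = 8
Q = torch.randn(3, De)  # three query vectors q1, q2, q3
K = torch.randn(3, De)  # three key vectors
V = torch.randn(3, De)  # three value vectors

outputs = []
for i in range(3):                             # one output vector b_i per query q_i
    a = Q[i] @ K.t() / math.sqrt(De)           # inner products of q_i with every k_j
    a_prime = torch.softmax(a, dim=0)          # softmax over all these numbers together
    b = (a_prime.unsqueeze(1) * V).sum(dim=0)  # weighted sum of the value vectors
    outputs.append(b)
loop_result = torch.stack(outputs)             # stack b1, b2, b3 into a (3, De) matrix

matrix_result = torch.softmax(Q @ K.t() / math.sqrt(De), dim=-1) @ V
print(torch.allclose(loop_result, matrix_result))  # True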

   Of course, Lq is often not equal to Lk. In that case the figure above would become very messy, so Tutu uses the following figure to represent the scaled dot-product attention process instead.

Three: The mask problem in Scaled Dot-Product Attention

  The mask is optional in scaled dot-product attention: in some situations a mask helps, and in others it is unnecessary. The mask acts on the attention score matrix before the softmax. As mentioned earlier, that matrix has shape (Lq, Lk); a mask is generally used in the self-attention case, where Lq = Lk and the matrix is square. The purpose of the mask is to set the upper triangle of this square matrix to negative infinity (or a very large negative number) while keeping the lower triangle, so that after the softmax the upper-triangular entries tend to 0. This handles situations that arise in practice: in translation tasks, for example, when reading a sentence sequence we want each position to attend only to the words already read, independent of the words that have not been read yet.

   In fact, a mask does not have to cover only the upper triangle; depending on the application, specific columns or arbitrary positions of the matrix can also be set to -inf to mask out the information at those positions.

  For multi-head attention, if a mask is used, every head generally uses the same mask; the model is then called masked multi-head attention.

import numpy as np
import torch
from torch import nn

# A toy demonstration of masking: fill the upper triangle of a random 5x5 matrix
# with a large negative number, so those positions collapse to ~0 afterwards.
weight=torch.randint(0,5,size=(5,5))
mask=torch.tensor(np.array([[False,True,True,True,True],
                            [False,False,True,True,True],
                            [False,False,False,True,True],
                            [False,False,False,False,True],
                            [False,False,False,False,False]]))
masked_weight=weight.masked_fill(mask,-1000)
out=nn.Sigmoid()(masked_weight.float())  # .float() because sigmoid is not defined for integer tensors
print(masked_weight)
print(out)
'''-------------------------------'''
>>>tensor([[    0, -1000, -1000, -1000, -1000],
        [    3,     4, -1000, -1000, -1000],
        [    3,     2,     0, -1000, -1000],
        [    4,     3,     1,     2, -1000],
        [    2,     3,     0,     2,     3]])
>>>tensor([[0.5000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9526, 0.9820, 0.0000, 0.0000, 0.0000],
        [0.9526, 0.8808, 0.5000, 0.0000, 0.0000],
        [0.9820, 0.9526, 0.7311, 0.8808, 0.0000],
        [0.8808, 0.9526, 0.5000, 0.8808, 0.9526]])
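Note that the sigmoid above is only used to show that the masked positions collapse to roughly 0; in real scaled dot-product attention the masked scores go through a row-wise softmax instead. A sketch of the same idea, building the causal mask with torch.triu and filling the upper triangle with -inf:

import torch

Lq = Lk = 5
scores = torch.randn(Lq, Lk)  # stand-in for Q.K^T/sqrt(d_k)
causal_mask = torch.triu(torch.ones(Lq, Lk, dtype=torch.bool), diagonal=1)  # True above the diagonal
masked_scores = scores.masked_fill(causal_mask, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)  # upper-triangular entries become exactly 0, each row sums to 1
print(weights)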

Four: Multi-Head Attention structure

  Multi-head attention consists of multiple one-head attentions. Suppose a multi-head attention has n heads, and the weights of the i-th head are W^Q_i, W^K_i, W^V_i; then:

head_i=\mathrm{Attention}(qW^Q_i,\; kW^K_i,\; vW^V_i) \\ \mathrm{MultiHead}(q,k,v)=\mathrm{Concat}(head_1,head_2,\dots,head_n)\,W^O

This process works as follows: the input q, k, and v matrices are fed into each one-head attention; the output matrices of all heads are concatenated along the feature (dim) dimension to obtain a new matrix, which is then multiplied by W^O to obtain the output (this multiplication can also be implemented as a fully connected Linear layer). The output shape is still (Lq, De).

  Regarding the parameter matrices W, there are actually two possible situations:

(1) W^Q_i, W^K_i, W^V_i have shapes (Dq, De), (Dk, De), (Dv, De). Then each head's output has shape (Lq, De), the matrix obtained after concatenation has shape (Lq, n×De), and W^O has shape (n×De, De).

(2) W^Q_i, W^K_i, W^V_i have shapes (Dq, De/n), (Dk, De/n), (Dv, De/n) (this requires that the embedding dimension De be divisible by the number of heads n). Then each head's output has shape (Lq, De/n), the matrix obtained after concatenation has shape (Lq, De), and W^O has shape (De, De).

Although the internal parameters differ between the two approaches, the shapes of the input and output data remain unchanged. MultiheadAttention in PyTorch uses approach (2).
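The shape bookkeeping of the two approaches can be verified with a quick sketch (the head outputs below are random placeholders rather than real attention outputs, and the sizes are arbitrary):

import torch

Lq = 5
De = 8
n = 4  # number of heads

# Approach (1): each head works in the full dimension De
heads = [torch.randn(Lq, De) for _ in range(n)]       # each head output: (Lq, De)
concat1 = torch.cat(heads, dim=-1)                    # (Lq, n*De)
W_o1 = torch.randn(n * De, De)
print((concat1 @ W_o1).shape)                         # torch.Size([5, 8])

# Approach (2): each head works in De/n dimensions (requires De % n == 0), as in PyTorch
heads = [torch.randn(Lq, De // n) for _ in range(n)]  # each head output: (Lq, De/n)
concat2 = torch.cat(heads, dim=-1)                    # (Lq, De)
W_o2 = torch.randn(De, De)
print((concat2 @ W_o2).shape)                         # torch.Size([5, 8])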

Five: Understanding of self-attention

  Self-attention is the case in which the three input sequences of multi-head attention all come from the same sequence. Let the input sequence be input; then q, k, and v are all input, so Lq = Lk and Dq = Dk = Dv. Since all inputs are the same sequence, it is easy to see why it is called self-attention.

Six: The understanding and source of query, key, and value

  The names query, key, and value come from the retrieval setting. They can be taken from the same sequence, or they can be different sequences with their own practical meaning. From the perspective of retrieval, the query is the content to search for, the key is the index, and the value is the content to be retrieved. The attention process computes the correlation between query and key to obtain the attention map, and then uses the attention map to extract features from the value. In self-attention, query, key, and value are the same sequence; more generally, query is one sequence while key and value are a second, shared sequence; and in the most general case, query, key, and value are three different sequences.

Seven: Application examples

1. Use PyTorch to build multi-head attention

import torch
from torch import nn

class attention(nn.Module):
    def __init__(self,embed_dim,num_heads):
        '''
        :param embed_dim: embedding dimension (De)
        :param num_heads: number of scaled dot-product attention heads
        '''
        super(attention, self).__init__()
        self.embed_dim=embed_dim
        self.num_heads=num_heads
        # ModuleList (rather than a plain list) so the per-head Linear layers are registered as parameters
        self.w_q=nn.ModuleList([nn.Linear(embed_dim,embed_dim) for i in range(num_heads)])
        self.w_k=nn.ModuleList([nn.Linear(embed_dim,embed_dim) for i in range(num_heads)])
        self.w_v=nn.ModuleList([nn.Linear(embed_dim,embed_dim) for i in range(num_heads)])
        self.w_o=nn.Linear(embed_dim*num_heads,embed_dim)
        self.softmax=nn.Softmax(dim=-1)
    def single_head(self,q,k,v,head_idx):
        '''scaled dot-product attention of a single head'''
        q=self.w_q[head_idx](q)
        k=self.w_k[head_idx](k)
        v=self.w_v[head_idx](v)
        attn=self.softmax(torch.matmul(q,k.permute(0,2,1))/self.embed_dim**0.5)  # (batch, Lq, Lk)
        out=torch.matmul(attn,v)                                                 # (batch, Lq, De)
        return out
    def forward(self,q,k,v):
        output=[]
        for i in range(self.num_heads):
            out=self.single_head(q,k,v,i)
            output.append(out)
        output=torch.cat(output,dim=2)  # concatenate the heads along the feature dimension
        output=self.w_o(output)
        return output

if __name__=='__main__':
    x=torch.randn(size=(3,2,8),dtype=torch.float32)  # (batch, seq, feature)
    q,k,v=x,x,x                                      # self-attention: q, k, v are the same sequence
    att=attention(embed_dim=8,num_heads=4)
    output=att(q,k,v)
    print(output.shape)  # torch.Size([3, 2, 8])

2. Use the nn.MultiheadAttention method in PyTorch

In PyTorch, the MultiheadAttention module has two required parameters:

  embed_dim: Embedding dimension, namely De.

  num_heads: number of heads

  Although it was mentioned earlier that Dq, Dk, Dv, and De can all differ, in PyTorch the query dimension Dq must equal embed_dim (De), and by default Dk and Dv are also assumed to equal De; if the feature dim of k or v is not equal to De, the kdim and vdim parameters must be set accordingly. For the input data, PyTorch's default layout is (seq, batch, feature): the first dimension is the sequence length, the second is the batch size, and the third is the feature dim. If we prefer the (batch, seq, feature) layout, we can set batch_first=True.

import torch
from torch import nn

q=torch.randint(0,10,size=(10,9,8),dtype=torch.float32)  # (batch_size, seq_length, dim): Lq=9, Dq=De=8
k=torch.randint(0,10,size=(10,7,4),dtype=torch.float32)  # Lk=7, Dk=4
v=torch.randint(0,10,size=(10,7,3),dtype=torch.float32)  # Lk=7, Dv=3
attention=nn.MultiheadAttention(embed_dim=8,num_heads=4,kdim=4,vdim=3,batch_first=True)
attn_output, attn_output_weights=attention(q,k,v)
print(attn_output.shape)           # torch.Size([10, 9, 8])
print(attn_output_weights.shape)   # torch.Size([10, 9, 7])

Of course, besides these parameters, PyTorch's MultiheadAttention has more parameters, such as the various bias flags that indicate whether to add a bias.
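As a final sketch tying back to part Three, a causal mask can be passed to nn.MultiheadAttention through its attn_mask argument; here is the self-attention case, with the same arbitrary shapes as above (a True entry means that position may not be attended to):

import torch
from torch import nn

x = torch.randn(10, 9, 8)  # batch_size, seq_length, dim
attention = nn.MultiheadAttention(embed_dim=8, num_heads=4, batch_first=True)
causal_mask = torch.triu(torch.ones(9, 9, dtype=torch.bool), diagonal=1)  # (Lq, Lk)
attn_output, attn_output_weights = attention(x, x, x, attn_mask=causal_mask)
print(attn_output.shape)           # torch.Size([10, 9, 8])
print(attn_output_weights.shape)   # torch.Size([10, 9, 9])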

Eight: Summary

  The self-attention mechanism is multi-head attention in which all inputs are the same sequence. The multi-head attention structure is a parallel combination of one or more one-head attention units, and each one-head attention consists of scaled dot-product attention plus three corresponding weight matrices. As one of the unit layer types of neural networks, multi-head attention has important applications in many neural network models and is one of the core structures of today's very popular Transformer model, so mastering this part is of great significance for understanding Transformers.
