Code implementation and application of the multi-head attention mechanism (MultiHeadAttention) in PyTorch

This article gives a brief introduction to, and draws distinctions between, the concepts of Scaled Dot-Product Attention, Multi-head attention, Self-attention, and Transformer. Finally, a general Multi-head attention module is implemented in code and applied.

1. Concepts:

1. Scaled Dot-Product Attention

The Attention mechanism appears constantly in practice, and the most commonly used form is Scaled Dot-Product Attention, which measures the similarity between the query and the key with a dot product.

  • Scaled means that the similarity computed from Q and K is scaled, specifically divided by the square root of the key dimension (√d_k);
  • Dot-Product means that the similarity between Q and K is computed as a dot product;
  • The optional Mask fills the padding positions with negative infinity, so that after the softmax their attention weights become 0 and padding has no effect. A minimal sketch of the full operation is given after this list.
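
As a minimal sketch (the function and variable names here are my own, not part of the article's implementation), the whole operation can be written as:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    ## Q: [N, T_q, d_k], K: [N, T_k, d_k], V: [N, T_k, d_v]
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)   ## [N, T_q, T_k]
    if mask is not None:
        ## mask: [N, T_q, T_k], True marks the padding positions to ignore
        scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)        ## each row sums to 1 over T_k
    return torch.matmul(weights, V), weights   ## [N, T_q, d_v], [N, T_q, T_k]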

2. Multi-head attention

Multi-head attention builds on Scaled Dot-Product Attention by splitting Q, K, and V into multiple heads that compute attention in parallel, so that different heads can focus on similarity and weighting in different aspects of the input.
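
A shape-only sketch of the head split (the tensor names and sizes here are illustrative, assuming num_units is divisible by num_heads):

import torch

N, T, num_units, num_heads = 8, 6, 16, 4
x = torch.randn(N, T, num_units)     ## a projected Q, K, or V
heads = torch.stack(torch.split(x, num_units // num_heads, dim=2), dim=0)
print(heads.shape)                   ## torch.Size([4, 8, 6, 4]), i.e. [h, N, T, num_units/h]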

3. Self-attention

The self-attention mechanism is an application scenario of Scaled Dot-Product Attention and Multi-head attention in which Q, K, and V all come from the same input, i.e. the sequence computes attention with itself; like passing through a linear layer, the output has the same length as the input.

If Q, K, and V come from different sources, it cannot be called self-attention, only attention. For example, in GST, K and V are a set of randomly initialized tokens, while Q is a frame of the mel spectrogram produced by the reference encoder. Conversely, Q can be a single randomly initialized vector while K and V come from the input, so that attention over a variable-length input of N elements yields a vector of length 1.
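
A minimal sketch of both cases, using the MultiHeadAttention class implemented in the next section (the shapes and parameter values here are illustrative assumptions, not taken from GST itself):

import torch

attn = MultiHeadAttention(query_dim=80, key_dim=80, num_units=128, num_heads=4)

## Self-attention: Q, K, and V all come from the same tensor x
x = torch.randn(2, 50, 80)     ## [N, T, dim]
out, _ = attn(x, x)            ## out: [2, 50, 128], same length as the input

## Attention with different sources (GST-like): Q is one query vector, K/V are N tokens
q = torch.randn(2, 1, 80)      ## a single query per batch element
tokens = torch.randn(2, 10, 80)  ## 10 "tokens" acting as K and V
out, _ = attn(q, tokens)       ## out: [2, 1, 128], variable-length input reduced to length 1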

4. Transformer

Transformer refers to a general model framework built on Scaled Dot-Product Attention, Multi-head attention, and Self-attention, which also includes Positional Encoding, an Encoder, a Decoder, and so on. Transformer is not the same thing as Self-attention.

2. Code implementation:

Since the Attention operation is used so often, the Multi-head Attention code is organized and implemented below so that the interface can be called directly in the future; the single-head attention mechanism is just the special case num_heads = 1.

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    '''
    input:
        query --- [N, T_q, query_dim] 
        key --- [N, T_k, key_dim]
        mask --- [N, T_k]
    output:
        out --- [N, T_q, num_units]
        scores -- [h, N, T_q, T_k]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):

        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key, mask=None):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)  # [N, T_k, num_units]

        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        ## score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
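        ## Note: this implementation scales by key_dim, the full key input dimension;
        ## the standard Transformer instead scales by the per-head size d_k = num_units / num_heads.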

        ## mask
        if mask is not None:
            ## mask:  [N, T_k] --> [h, N, T_q, T_k]
            mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads,1,querys.shape[2],1)
            scores = scores.masked_fill(mask, -np.inf)
        scores = F.softmax(scores, dim=3)

        ## out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out, scores

3. Practical application:

1. Interface call:

## Instantiate the class
attention = MultiHeadAttention(3, 4, 5, 1)

## Inputs
query = torch.randn(8, 2, 3)
key = torch.randn(8, 6 ,4)
mask = torch.tensor([[False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],])

## Outputs
out, scores = attention(query, key, mask)
print('out:', out.shape)         ## torch.Size([8, 2, 5])
print('scores:', scores.shape)   ## torch.Size([1, 8, 2, 6])
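
For comparison, a sketch of the same call with more than one head (the parameter values here are my own illustrative choices):

attention = MultiHeadAttention(3, 4, 8, 2)   ## num_units=8 split across num_heads=2
out, scores = attention(query, key, mask)
print('out:', out.shape)         ## torch.Size([8, 2, 8])
print('scores:', scores.shape)   ## torch.Size([2, 8, 2, 6])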

2. The role of mask:

Printing the scores at three stages shows the effect of the mask:

  • Scores before the mask: the raw scaled dot products QK^T / √d_k;
  • Scores after the mask: the masked (padding) positions are filled with negative infinity;
  • Scores after softmax: the masked positions receive an attention weight of exactly 0, so padding contributes nothing to the output.
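
A quick way to check this from outside the class, continuing the example above (the returned scores are already the post-softmax weights):

print(scores[0, 0])   ## attention weights of the first head for the first batch element
## The columns at the masked positions (4 and 5 for this batch element) are exactly 0,
## and each row sums to 1 over the remaining positions.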


Origin blog.csdn.net/m0_46483236/article/details/124015298