This article gives a brief introduction to, and distinguishes among, the concepts of Scaled Dot-Product Attention, Multi-head Attention, Self-attention, and the Transformer. Finally, a general Multi-head Attention module is implemented in code and applied.
1. Concepts:
1. Scaled Dot-Product Attention
In practice, the most commonly used attention mechanism is Scaled Dot-Product Attention, which measures the similarity between the query and the key with a dot product.
- Scaled means the Q·K similarity is rescaled, specifically divided by the square root of the key dimension d_k;
- Dot-Product means the similarity between Q and K is computed as a dot product;
- The optional Mask fills the padding positions with negative infinity so that, after the softmax, their attention weights become 0, preventing padding from affecting the result.
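The three ingredients above can be sketched as a small function (a minimal sketch; the function name and tensor shapes are illustrative, not taken from the implementation later in this article):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: [N, T_q, d_k], k: [N, T_k, d_k], v: [N, T_k, d_v]
    d_k = k.size(-1)
    # Dot-product similarity, scaled by sqrt(d_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)  # [N, T_q, T_k]
    if mask is not None:
        # mask: [N, T_q, T_k]; True marks padding positions
        scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)  # masked positions get weight 0
    return torch.matmul(weights, v), weights  # [N, T_q, d_v], [N, T_q, T_k]

q = torch.randn(2, 3, 8)
k = torch.randn(2, 5, 8)
v = torch.randn(2, 5, 8)
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 3, 8])
```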
2. Multi-head attention
Multi-head attention builds on Scaled Dot-Product Attention by splitting Q, K, and V into multiple heads that compute attention in parallel, so that different heads can attend to different aspects of the input.
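The head split can be shown in isolation; this is the same split-and-stack pattern used in the implementation below (the tensor sizes here are arbitrary examples):

```python
import torch

x = torch.randn(4, 10, 64)  # [N, T, num_units]
num_heads = 8
# Split the feature dimension into heads: [N, T, 64] -> [8, N, T, 8],
# i.e. [h, N, T, num_units/h]; each head then attends in parallel.
heads = torch.stack(torch.split(x, 64 // num_heads, dim=2), dim=0)
print(heads.shape)  # torch.Size([8, 4, 10, 8])
```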
3. Self-attention
Self-attention is an application of Scaled Dot-Product Attention (and Multi-head Attention) in which Q, K, and V all come from the same source: the sequence computes attention against itself. Like passing through a linear layer, the output has the same length as the input.
If Q, K, and V come from different sources, it is only attention, not self-attention. For example, in GST, K and V are a set of randomly initialized tokens, while Q is a frame of the mel spectrogram produced by the reference encoder. Conversely, Q can be randomly initialized while K and V come from the input; attention over a variable-length input of N frames then yields a single vector of length 1.
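A minimal sketch of both cases, using plain tensor operations and arbitrary dimensions (no learned projections; for illustration only):

```python
import torch
import torch.nn.functional as F

N, T, d = 2, 7, 16
x = torch.randn(N, T, d)

# Self-attention: Q, K, V all come from x; output length equals input length.
self_out = F.softmax(x @ x.transpose(1, 2) / d ** 0.5, dim=-1) @ x
print(self_out.shape)  # torch.Size([2, 7, 16])

# Cross-attention with a single randomly initialized query (GST-style pooling):
# a variable-length input of T frames is summarized into one vector.
q = torch.randn(N, 1, d)
pooled = F.softmax(q @ x.transpose(1, 2) / d ** 0.5, dim=-1) @ x
print(pooled.shape)  # torch.Size([2, 1, 16])
```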
4. Transformer
The Transformer is a general model architecture built on Scaled Dot-Product Attention, Multi-head Attention, and Self-attention; it also includes Positional Encoding, an Encoder, a Decoder, and so on. The Transformer is not the same thing as Self-attention.
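For reference, PyTorch ships these building blocks pre-assembled. A minimal sketch using nn.TransformerEncoderLayer (multi-head self-attention plus a feed-forward network, with residual connections and layer normalization; note it does not include positional encoding, which must be added separately, and batch_first assumes a reasonably recent PyTorch):

```python
import torch
import torch.nn as nn

# One encoder layer: multi-head self-attention + feed-forward sublayers.
layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
x = torch.randn(8, 10, 32)  # [N, T, d_model]
y = layer(x)
print(y.shape)  # torch.Size([8, 10, 32])
```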
2. Code implementation:
Since the attention operation is used so often, the code below implements Multi-head Attention as a reusable interface that can be called directly in the future; single-head attention is simply the special case num_heads=1.
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    '''
    input:
        query  --- [N, T_q, query_dim]
        key    --- [N, T_k, key_dim]
        mask   --- [N, T_k]
    output:
        out    --- [N, T_q, num_units]
        scores --- [h, N, T_q, T_k]
    '''
    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key, mask=None):
        # Project the inputs to a shared feature size
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)        # [N, T_k, num_units]
        values = self.W_value(key)    # [N, T_k, num_units]

        # Split the feature dimension into heads
        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)      # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        ## score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)

        ## mask: fill padding positions with -inf before the softmax
        if mask is not None:
            ## mask: [N, T_k] --> [h, N, T_q, T_k]
            mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads, 1, querys.shape[2], 1)
            scores = scores.masked_fill(mask, -np.inf)
        scores = F.softmax(scores, dim=3)

        ## out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        # Concatenate the heads back into a single feature dimension
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
        return out, scores
3. Practical application:
1. Interface call:
## instantiate the class
attention = MultiHeadAttention(3, 4, 5, 1)
## inputs
query = torch.randn(8, 2, 3)
key = torch.randn(8, 6, 4)
mask = torch.tensor([[False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True]])
## outputs
out, scores = attention(query, key, mask)
print('out:', out.shape)        ## torch.Size([8, 2, 5])
print('scores:', scores.shape)  ## torch.Size([1, 8, 2, 6])