LLaMa principles and source code, dismantled (KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU)

Principles

The difference between Vanilla Transformer and LLaMa

Comparison of Vanilla Transformer and LLaMa: LLaMa differs from the ordinary Transformer architecture in several ways: it uses pre-normalization (Pre-normalization) with the RMSNorm normalizing function, rotary positional embedding (RoPE), replaces the ReLU activation function with SwiGLU, and improves self-attention into Grouped Query Attention with a KV-Cache. The overall Transformer architecture is similar to GPT-2.
[Figure: Vanilla Transformer vs. LLaMa architecture]

The evolution LLaMa -> Alpaca -> Vicuna:

  • LLaMa : Meta's open-source pre-trained model, with parameter counts of 7B, 13B, 33B, and 65B. LLaMa-13B surpasses Text-davinci-003 (i.e. GPT-3 175B) on most benchmark tests, although compared with ChatGPT or GPT-4 LLaMa may still have a gap in quality. Hugging Face has already integrated LLaMa's code implementation and the open-source weights, so both academia and industry can build on this foundation for learning and research.
    [Figure: LLaMa model variants]

  • Alpaca : a model from Stanford, obtained by supervised fine-tuning on top of LLaMa-7B. Stanford used the OpenAI Text-davinci-003 (i.e. GPT-3 175B) API together with the self-instruct technique to automatically generate a 52K prompt-response instruction dataset from 175 seed prompts. The model was fine-tuned from LLaMa-7B on 8 80G A100s; training took 3 hours.

  • Vicuna : a model obtained by supervised fine-tuning on top of LLaMa-13B. The dataset comes from user conversations shared on ShareGPT, 70K items in total. It was trained with PyTorch FSDP on 8 A100s for one day, and the sequence length was extended from 512 to 2048. Compared with Alpaca, Vicuna uses gradient checkpointing and flash attention to address memory pressure during training; it adjusts the training loss to account for multi-turn dialogue and fine-tunes only on the model's outputs. Scored and evaluated by GPT-4, Vicuna reaches about 90% of ChatGPT's quality.

  • LLaMa2 : adopts most of the pre-training settings and model architecture of LLaMa 1. The biggest differences from LLaMa 1 are the increased context length and the use of GQA (Grouped Query Attention) in the 34B and 70B models.
    [Figure: LLaMa 1 vs. LLaMa 2 configurations]

Embedding

Embedding process : word -> token_id -> embedding_vector. The first conversion is performed with the tokenizer's vocabulary, and the second with a learnable Embedding layer.

[Figure: word -> token_id -> embedding_vector]
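
A minimal sketch of the two-step lookup in PyTorch (the toy vocabulary and dimensions below are illustrative, not LLaMa's real tokenizer or sizes):

import torch
import torch.nn as nn

# Hypothetical toy vocabulary: word -> token_id (a real tokenizer uses BPE/SentencePiece)
vocab = {"<s>": 0, "I": 1, "love": 2, "llamas": 3}
token_ids = torch.tensor([[vocab[w] for w in ["<s>", "I", "love", "llamas"]]])  # (1, 4)

# Learnable embedding layer: token_id -> embedding_vector
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=8)
embedding_vectors = embedding(token_ids)  # (1, 4, 8)
print(embedding_vectors.shape)            # torch.Size([1, 4, 8])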

RMS Norm

Comparison of Batch Norm and Layer Norm : both subtract the mean (Mean) and divide by the standard deviation (the square root of the variance Var), normalizing the result toward N(0,1). They differ only in the dimension over which the mean and variance are computed (batch vs. feature). Subtracting the mean is re-centering (moving the mean to 0); dividing by the standard deviation is re-scaling (moving the variance to 1).
[Figure: Batch Norm vs. Layer Norm normalization dimensions]

RMS Norm (Root Mean Square Layer Norm): RMS Norm argues that the success of Layer Norm comes from its re-scaling invariance rather than from re-centering with the mean. Therefore RMS Norm drops the mean entirely and constructs a dedicated statistic, RMS, to replace the variance. Why use RMS Norm? (1) RMS Norm is cheaper to compute. (2) Its quality is as good as Layer Norm's.

The RMS Norm function calculation formula for the input vector a is as follows:

$\mathrm{RMS}(\boldsymbol{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2}, \qquad \bar{a}_i = \frac{a_i}{\mathrm{RMS}(\boldsymbol{a})}$

In addition, RMSNorm can introduce a learnable scaling factor g_i and an offset b_i, giving:

$\bar{a}_i = \frac{a_i}{\mathrm{RMS}(\boldsymbol{a})}\, g_i + b_i$

The code implementation of RMSNorm in the HuggingFace Transformer library is as follows:

import torch
import torch.nn as nn

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps  # eps keeps the denominator away from 0 when taking the reciprocal square root

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # mean of squares over the hidden dimension, computed in fp32 for numerical stability
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # weight is the trainable scale applied at the end, i.e. g_i
        return (self.weight * hidden_states).to(input_dtype)
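
A quick usage check of the module above (tensor shapes are arbitrary):

norm = LlamaRMSNorm(hidden_size=4096)
x = torch.randn(2, 16, 4096)           # (batch, seq_len, hidden_size)
y = norm(x)
print(y.shape)                          # torch.Size([2, 16, 4096])
# Each position is scaled so that its root mean square is ~1 (times the learned weight, initialized to 1)
print(y.pow(2).mean(-1).sqrt()[0, 0])   # ≈ 1.0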

To make training more stable, GPT-2 (compared with GPT) moved Layer Norm to the front: the first layer normalization is placed before the multi-head self-attention layer, the second layer normalization is placed before the fully connected (feed-forward) layer, and the residual connections are placed after the multi-head self-attention layer and after the fully connected layer. LLaMa follows this pre-norm layout, but uses RMSNorm as the layer normalization function.
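
A minimal sketch of this pre-norm residual layout, reusing the LlamaRMSNorm module above (PreNormBlock and the attn/ffn stand-ins are hypothetical names used only for illustration):

import torch.nn as nn

class PreNormBlock(nn.Module):
    """x -> x + Attn(Norm(x)) -> x + FFN(Norm(x)): normalization comes before each sub-layer."""
    def __init__(self, hidden_size, attn: nn.Module, ffn: nn.Module):
        super().__init__()
        self.attn_norm = LlamaRMSNorm(hidden_size)  # first norm, before multi-head self-attention
        self.ffn_norm = LlamaRMSNorm(hidden_size)   # second norm, before the fully connected layer
        self.attn = attn
        self.ffn = ffn

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))  # residual connection after the attention sub-layer
        x = x + self.ffn(self.ffn_norm(x))    # residual connection after the feed-forward sub-layer
        return x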

Rotary Positional Encoding

The usage of ordinary absolute positional encoding : word -> token_id -> embedding_vector + position_encoding -> Encoder_Input. The first conversion is performed with the tokenizer's vocabulary, and the second with a learnable Embedding layer; the resulting embedding_vector and position_encoding are added element-wise and then fed into the LLM encoder as input.

[Figure: absolute positional encoding added to token embeddings]
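
For reference, a sketch of the classic sinusoidal absolute positional encoding being added element-wise to the token embeddings (dimensions are illustrative):

import math
import torch

def sinusoidal_position_encoding(seq_len, dim):
    # PE[pos, 2i] = sin(pos / 10000^(2i/dim)), PE[pos, 2i+1] = cos(pos / 10000^(2i/dim))
    position = torch.arange(seq_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # (seq_len, dim)

embedding_vector = torch.randn(1, 4, 8)  # (batch, seq_len, dim), stand-in embeddings
encoder_input = embedding_vector + sinusoidal_position_encoding(4, 8)  # element-wise add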

Comparison of Absolute PE and Relative PE :

  • Absolute PE (absolute positional encoding): encodes one token at a time; the PEs of different tokens are unrelated to each other. It is a fixed set of vectors that reflects each token's absolute position in the sequence.
  • Relative PE (relative positional encoding): handles two tokens at a time and is only used when computing attention (added to the key during query@key), reflecting the relative relationship between the two tokens.

[Figure: absolute vs. relative positional encoding]

Rotary positional encoding (RoPE): RoPE relies on the idea of complex numbers; its starting point is to achieve relative position encoding by way of absolute position encoding. The goal is to add the absolute position information m and n to q and k through an operation f, obtaining q̃_m and k̃_n, and then compute q@k:

$\tilde{q}_m = f_q(x_m, m), \qquad \tilde{k}_n = f_k(x_n, n)$

In fact, using the idea of complex numbers, an operation g can be found that merges the operation f and the product q@k into one, so that we only need the tokens q and k together with their absolute positions m and n in the sequence:

$\langle f_q(x_m, m),\, f_k(x_n, n) \rangle = g(x_m, x_n, m - n)$
It can be seen that, unlike ordinary relative position encoding, rotary position encoding does not add a separate bias: the relative position between tokens enters directly through the computed attention_score = q@k.

Why is it called rotary position encoding? Because Euler's formula is used to construct a rotation matrix: q and k are rotated to positions in space determined by their absolute positions, so that position information is injected into the result of q@k.
[Figure: 2-dimensional rotation-matrix form of RoPE]
The above is the 2-dimensional case with just the two tokens x_m and x_n. In LLaMa the embeddings are d-dimensional, so the same rotation is applied pairwise across all dimensions of every token:
[Figure: general block-diagonal rotation matrix for d-dimensional RoPE]

Since the rotation matrix above is sparse, with a large number of elements that are 0, an element-wise product ⊗ can be used instead to further speed up the computation.

[Figure: element-wise RoPE computation using cos and sin]

The code implementation of RoPE in the HuggingFace Transformer library is as follows:

import torch

class LlamaRotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # theta_i = 1 / base^(2i/dim): one frequency per pair of dimensions
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # outer product: position m times theta_i
        # Different from the paper, but it uses a different permutation
        # in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`.
        # Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from the paper, but it uses a different permutation
            # in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


# rotate_half and apply_rotary_pos_emb are module-level helper functions, not class methods
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)       # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)       # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
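
A usage sketch of the helpers above (the head count and dimensions are arbitrary; rotate_half and apply_rotary_pos_emb are the module-level functions from the listing):

bs, n_heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)

rope = LlamaRotaryEmbedding(dim=head_dim)
cos, sin = rope(q, seq_len=seq_len)                # (1, 1, seq_len, head_dim) each
position_ids = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len): absolute positions m
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
# q_rot @ k_rot^T now depends only on the relative positions m - n
attn_scores = q_rot @ k_rot.transpose(-2, -1) / head_dim ** 0.5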

SwiGLU Function

The SwiGLU activation function was proposed by Shazeer and has been widely used in models such as PaLM, with good results: compared with ReLU it brings improvements on most evaluations. In LLaMa, the fully connected layers use a position-wise Feed-Forward Network (FFN) with the SwiGLU activation function, computed as follows:

$\mathrm{FFN}_{\mathrm{SwiGLU}}(x, W, V, W_2) = \left(\mathrm{Swish}_1(xW) \otimes xV\right) W_2, \qquad \mathrm{Swish}_\beta(x) = x\,\sigma(\beta x)$

Here σ(x) is the Sigmoid function. The figure below shows the shape of the Swish activation function for different values of the parameter β. As β approaches 0, Swish approaches the linear function y = x; as β approaches infinity, Swish approaches the ReLU function; at β = 1, Swish is smooth and non-monotonic.

[Figure: Swish activation for different values of β]
In HuggingFace's Transformers library, Swish_{β=1} is implemented with the SiLU function.
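
A sketch of the SwiGLU feed-forward layer in this style (the gate/up/down naming follows the HuggingFace LLaMa implementation; hidden_size and intermediate_size are illustrative values):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, hidden_size=4096, intermediate_size=11008):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)  # xW
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)    # xV
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)  # W2

    def forward(self, x):
        # FFN_SwiGLU(x) = (Swish_1(xW) ⊗ xV) W2, with Swish_1 = SiLU
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))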

KV-Cache

First, let's take a look at how LLaMa is trained (the next-word prediction task): seq2seq-style generation, but iterated T times with seq_len increasing step by step.
[Figure: next-word prediction training]

Self-Attention when predicting the next token:

  • At timestep=1, seq_len=1: given [SOS], the model predicts Love;
    [Figure: attention at timestep 1]
  • At timestep=2, seq_len=2: given [SOS] Love, it predicts that;
    [Figure: attention at timestep 2]
  • At timestep=4, seq_len=4: given [SOS] Love ... quickly, it predicts seize ...
    [Figure: attention at timestep 4]

At each timestep we only care about the last token generated, but because LLaMa is a seq2seq model, all previous tokens have to be recomputed every time. We would therefore like to cache what was computed at earlier timesteps, so that later timesteps do not have to repeat the work; this is exactly where KV-Cache comes in.

Let's analyze what we actually need in the self-attention of each timestep: since we only care about the attention_output of the last token, as shown in the figure below for timestep=4, we only need the 4th token's attention_output.

Therefore, we only need to multiply the last token of Q with all the tokens of K to get the last row of attention_score, and then take the weighted sum of all the tokens of V with that attention_score to get the last token's attention_output:
[Figure: computing only the last row of the attention matrix]
From this analysis: at each timestep, Q only needs the newly added token, while K and V must cache the tokens from previous timesteps so that they remain complete. The attention_output computed at each step is then only that of the newly added token, which saves a great deal of computation.

[Figures: Q, K, V and the attention computation with KV-Cache over successive timesteps]
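
A minimal single-head sketch of this decoding loop with a KV-Cache (shapes only; a real implementation keeps one cache per layer and per head, and the projections here are random stand-ins):

import torch

head_dim = 64
k_cache = torch.empty(0, head_dim)   # cached K rows of all previous tokens
v_cache = torch.empty(0, head_dim)   # cached V rows of all previous tokens

def decode_step(q_new, k_new, v_new):
    """q_new/k_new/v_new: (1, head_dim) projections of the newly generated token only."""
    global k_cache, v_cache
    k_cache = torch.cat([k_cache, k_new], dim=0)      # K grows by one row per timestep
    v_cache = torch.cat([v_cache, v_new], dim=0)      # V grows by one row per timestep
    scores = (q_new @ k_cache.T) / head_dim ** 0.5    # (1, cur_len): last row of attention_score
    weights = scores.softmax(dim=-1)
    return weights @ v_cache                          # (1, head_dim): attention_output of the new token

for _ in range(5):  # pretend we decode 5 tokens
    out = decode_step(torch.randn(1, head_dim), torch.randn(1, head_dim), torch.randn(1, head_dim))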

Grouped Multi-Query Attention

Looking back at the original Multi-Head Attention : the time bottleneck lies in the matrix computation itself.

[Figure: Multi-Head Attention, compute-bound]

When we use KV-Cache : the time bottleneck shifts to memory access.

[Figure: attention with KV-Cache, memory-bound]
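
To see why memory access dominates, here is a rough worked estimate of the KV-Cache size, assuming LLaMa-7B-like dimensions (32 layers, 32 heads of dimension 128, fp16 values, 2048-token context; these numbers are assumptions for illustration only):

$\text{KV-Cache} = \underbrace{2}_{K,\,V} \times n_{layers} \times n_{heads} \times d_{head} \times \text{seq\_len} \times 2\,\text{bytes} = 2 \times 32 \times 32 \times 128 \times 2048 \times 2\,\text{B} \approx 1\,\text{GB}$

per sequence, scaling linearly with batch size. Every decoding step has to read all of it back from memory to produce a single new token, so memory traffic rather than FLOPs sets the speed.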

Multi Query Attention

Multi-Query Attention (MQA) is a variant of multi-head attention. Its main difference is that the different attention heads share a single set of keys and values, and only the queries keep a separate copy of parameters per head. Concretely, the head dimension is removed from K and V and kept only for Q, which is why it is called Multi-Query Attention.

[Figure: Multi-Query Attention, with K and V shared across heads]

Therefore K and V each exist in only one copy (independent of the number of heads), which greatly reduces memory usage and improves efficiency. Since multi-query attention changes the structure of the attention mechanism, models usually need to support it from the beginning of training.

Research shows that multi-query attention support can also be added to an already-trained model by fine-tuning it, and only about 5% of the original training data volume is needed to achieve good results. Many models, including Falcon, SantaCoder, and StarCoder, use the multi-query attention mechanism.

[Figure: models using Multi-Query Attention]

Taking LLM Foundry as an example, the multi-query attention implementation is as follows; compared with the multi-head self-attention implemented in LLM Foundry, the only difference lies in how the Wqkv layer is built:

import torch
import torch.nn as nn
from typing import Optional

# scaled_multihead_dot_product_attention is LLM Foundry's attention kernel, defined elsewhere in the library

class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear(            # the only change vs. multi-head attention
            d_model,
            d_model + 2 * self.head_dim,  # per-head vectors are created only for the query (d_model wide),
            device=device,                # while key and value each get a single shared head_dim-wide vector
        )
        self.attn_fn = scaled_multihead_dot_product_attention
        self.out_proj = nn.Linear(
            self.d_model,
            self.d_model,
            device=device,
        )
        self.out_proj._is_residual = True  # type: ignore

    def forward(
        self,
        x,
    ):
        qkv = self.Wqkv(x)              # e.g. (1, 512, 960) with d_model=768, head_dim=96
        query, key, value = qkv.split(  # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],  # key   -> (1, 512, 96)
            dim=2,                                         # value -> (1, 512, 96)
        )
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            multiquery=True,
        )
        return self.out_proj(context), attn_weights, past_key_value

Grouped Multi-Query Attention

Grouped Query Attention builds on Multi-Query Attention by grouping: the query heads are divided into groups, and each group has its own shared K and V head. It therefore sits between Multi-Head Attention (one K/V per head) and Multi-Query Attention (a single K/V for all heads); a short sketch follows the figure below.

[Figure: Multi-Head vs. Grouped-Query vs. Multi-Query Attention]
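
A short sketch of the grouping (head counts are illustrative; repeating the few K/V heads mirrors the repeat_kv helper used in HuggingFace's LLaMa implementation, named here only as a reference point):

import torch

bs, seq_len, head_dim = 1, 16, 64
n_q_heads, n_kv_heads = 8, 2               # 8 query heads split into 2 groups of 4
group_size = n_q_heads // n_kv_heads

q = torch.randn(bs, n_q_heads, seq_len, head_dim)
k = torch.randn(bs, n_kv_heads, seq_len, head_dim)  # only n_kv_heads K/V heads are stored/cached
v = torch.randn(bs, n_kv_heads, seq_len, head_dim)

# Expand each K/V head so it serves its whole group of query heads
k = k.repeat_interleave(group_size, dim=1)  # (bs, n_q_heads, seq_len, head_dim)
v = v.repeat_interleave(group_size, dim=1)

scores = (q @ k.transpose(-2, -1)) / head_dim ** 0.5
out = scores.softmax(dim=-1) @ v            # (bs, n_q_heads, seq_len, head_dim)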

Source code

[LLMs Practice] 01 Overview of LLaMa, Alpaca, and Vicuna, and the LLaMa inference process
