On the Length Extrapolation of LLMs

From: NLP Workstation


Foreword

As large-model applications continue to develop, knowledge plug-ins have become an important technique. But plug-in methods are often limited by the input length the model can accept and by the model's extrapolation ability. Today we bring you a brief discussion of the length extrapolation of LLMs, from @uuuuu (Zhihu).

https://zhuanlan.zhihu.com/p/645770522

It covers extrapolation strategies as of 2023-07-24: NBCE, linear interpolation, NTK-Aware Scaled RoPE, Dynamically Scaled RoPE, and a consistent variant of Dynamically Scaled RoPE.

From the second method onward, each one is essentially an optimization of the previous, and all of them apply to any language model that uses RoPE.

1. NBCE

NBCE: Extending the Context Length of LLMs with Naive Bayes
https://kexue.fm/archives/9617

This context-extension method for LLMs, first proposed by Su Jianlin (苏神), is based on a Naive-Bayes-inspired formula: assuming the contexts $S_1,\dots,S_n$ are independent given the target $T$,

$\log p(T\mid S_1,\dots,S_n) = (1+\beta)\,\overline{\log p(T\mid S)} - \beta\,\log p(T) + \text{const}$

where $\overline{\log p(T\mid S)}$ pools the per-context predictions (the code below takes the lowest-entropy one) and $p(T)$ is the context-free distribution. In actual tests it works really well for question answering, and reading comprehension over longer contexts is also decent.

Its limitation is order-insensitivity: it cannot recognize the input order of the contexts, so it may perform poorly in scenarios such as story continuation. It is also less effective at tasks that rely on every context jointly to generate the answer, such as extracting a document summary.

import torch

outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                use_cache=True,
                past_key_values=past_key_values)
past_key_values = outputs.past_key_values

# ===== core NBCE code begins =====
beta = 0.25
# next-token distribution of every row; row 0 is the unconditional one
probas = torch.nn.functional.softmax(outputs.logits[:, -1], dim=-1)
logits = probas.log()
# pick the context row (rows 1..n) whose distribution has the lowest entropy
k = (probas * logits).sum(dim=-1)[1:].argmax() + 1
logits_max = logits[k]
logits_uncond = logits[0]
logits = (1 + beta) * logits_max - beta * logits_uncond
# ===== core NBCE code ends =====

# build the distribution and sample
tau = 0.01  # tau = 1 is standard random sampling; tau -> 0 approaches greedy search
probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)

The code, figures, and text here are all taken from Scientific Spaces (kexue.fm).
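For orientation, here is a minimal sketch of the decoding loop the core step above sits in. It assumes a Hugging Face causal LM with a left-padded batch in which row 0 is the question alone (giving the unconditional distribution) and rows 1..n are each context plus the question; `model`, `tokenizer`, `input_ids`, and `attention_mask` are assumed to be prepared that way, and the loop only loosely follows Su's demo.

import torch

max_new_tokens = 256  # illustrative budget
past_key_values = None
generated = []
for _ in range(max_new_tokens):
    outputs = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                    use_cache=True,
                    past_key_values=past_key_values)
    past_key_values = outputs.past_key_values
    # --- core NBCE step from above, condensed ---
    beta, tau = 0.25, 0.01
    probas = torch.nn.functional.softmax(outputs.logits[:, -1], dim=-1)
    logits = probas.log()
    k = (probas * logits).sum(dim=-1)[1:].argmax() + 1
    logits = (1 + beta) * logits[k] - beta * logits[0]
    probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
    next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
    # --- end core step ---
    generated.append(next_tokens.item())
    if next_tokens.item() == tokenizer.eos_token_id:
        break
    # feed the same sampled token back into every row of the batch
    input_ids = next_tokens.unsqueeze(0).repeat(attention_mask.size(0), 1)
    attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids)], dim=1)
print(tokenizer.decode(generated))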

2. Linear Interpolation

https://kaiokendev.github.io/context
https://lmsys.org/blog/2023-06-29-longchat/
https://arxiv.org/abs/2306.15595

LLaMA is pre-trained with rotary embeddings on sequences of length 2048. This method achieves length extrapolation by compressing larger position indices into the range between 0 and 2048.

LongChat fine-tunes the model so that the context length is expanded to 16384, a compression ratio of 8. For example, the token with position_ids = 10000 becomes position_ids = 10000 / 8 = 1250, and the adjacent token 10001 becomes 10001 / 8 = 1250.125.

The drawback of this method is that a certain amount of fine-tuning is required to allow the model to adapt to this change.

import torch
import transformers
import transformers.models.llama.modeling_llama

from functools import partial

class CondenseRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, ratio, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        
        # Build here to make `torch.jit.trace` work.
        self.ratio = ratio
        max_position_embeddings *= ratio
        print(f"Condensing Positional embeddings from {max_position_embeddings} to {max_position_embeddings // ratio}")
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) / ratio
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) / self.ratio
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def replace_llama_with_condense(ratio):
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial(CondenseRotaryEmbedding, ratio=ratio)
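A usage sketch: the patch must run before the model is constructed, because `LlamaRotaryEmbedding` is looked up when the LLaMA modules are built. The checkpoint name here is illustrative (LongChat-16k was fine-tuned with ratio 8).

from transformers import AutoModelForCausalLM, AutoTokenizer

# Patch first, then load; patching after loading has no effect.
replace_llama_with_condense(8)

tokenizer = AutoTokenizer.from_pretrained("lmsys/longchat-13b-16k")
model = AutoModelForCausalLM.from_pretrained("lmsys/longchat-13b-16k")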

3. NTK-Aware Scaled RoPE

NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

RoPE is a kind of β-base (radix-β) encoding: https://spaces.ac.cn/archives/9675

An intuitive explanation: RoPE behaves like a clock. A 12-hour wall clock is basically a RoPE of dimension 3 with a base of 60: every second the minute hand turns 1/60 of a minute, and every minute the hour hand turns 1/60 of an hour. Now, if you slow time down by a factor of 4, you get the linear RoPE scaling of the second method. Unfortunately, it now becomes hard to distinguish individual seconds, because the second hand barely moves each second; if someone gives you two times that differ by only one second, you cannot tell them apart from a distance. NTK-Aware RoPE scaling does not slow down time. A second is still a second, but it slows the minutes down by a factor of 1.5 and the hours by a factor of 2. This way you can fit 90 minutes into an hour and 24 hours into half a day, so you end up with a clock that measures 129.6k seconds instead of 43.2k seconds. Since the hour hand does not need to be read precisely when telling the time, it is crucial to scale the hours much more than the seconds: we don't want to lose the precision of the second hand, but we can afford to lose precision on the minute or even the hour hand.
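To put numbers on the clock analogy, here is a small sketch (head dimension 128 and α = 8 are assumed, matching the snippet below) comparing per-dimension rotation frequencies before and after the base change: the fastest dimension (the second hand) keeps its precision, while the slowest (the hour hand) slows down by roughly a factor of α.

import torch

dim, base, alpha = 128, 10000.0, 8  # assumed values for illustration
new_base = base * alpha ** (dim / (dim - 2))  # the base-change formula below

freqs_old = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
freqs_new = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))

slowdown = freqs_old / freqs_new  # per-dimension slowdown factor
print(slowdown[0].item())   # ~1.0: highest-frequency dimension barely changes
print(slowdown[-1].item())  # ~8.0: lowest-frequency dimension stretched by ~α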

import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # The method is just these three lines
    max_position_embeddings = 16384
    a = 8  # alpha value
    base = base * a ** (dim / (dim - 2))  # base-change formula

    old_init(self, dim, max_position_embeddings, base, device)

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
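As with the linear-interpolation patch, this must be applied before the model is instantiated, since `LlamaRotaryEmbedding.__init__` runs at construction time; after that, the model can handle 8k+ contexts without any fine-tuning.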

4. Dynamically Scaled RoPE

https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

For the second and third methods above, a hyperparameter α is involved that adjusts the scaling ratio. This method instead selects the correct scaling parameter dynamically from the current sequence length. (The original post shows the effect in a comparison figure, omitted here.)

For dynamic linear interpolation, the exact position values are kept for the first 2k of context, and the position vector is recomputed for each new sequence length as the model generates tokens one by one. Essentially, the scale is set to original model context length / current sequence length.

For dynamic NTK, the α scaling is set to (α · current sequence length / original model context length) − (α − 1); that is, the hyperparameter is scaled up dynamically as the sequence length increases.
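As a quick illustration of how the two dynamic scales evolve (an original context of 2048 and α = 2 are assumed here purely for illustration):

original_ctx = 2048
alpha = 2  # assumed α for the dynamic NTK variant

for seq_len in (2048, 4096, 8192):
    # dynamic linear interpolation: positions are multiplied by this factor
    linear_scale = original_ctx / max(seq_len, original_ctx)
    # dynamic NTK: the effective α grows with the sequence length
    ntk_alpha = alpha * seq_len / original_ctx - (alpha - 1)
    print(seq_len, linear_scale, ntk_alpha)
# -> 2048: 1.0, 1.0 | 4096: 0.5, 3.0 | 8192: 0.25, 7.0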

import math
import torch

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        # `ntk` doubles as the α value: False selects dynamic linear
        # interpolation, while a number enables dynamic NTK with that α.
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                # dynamic NTK: grow the rotary base as the sequence gets longer
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                # dynamic linear interpolation: compress positions by
                # original context length / current sequence length
                t *= self.max_position_embeddings / seq_len
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

5. Consistent Dynamically Scaled RoPE

https://github.com/NormXU/Consistent-DynamicNTKRoPE

A problem with method four: because α is dynamic and decoding caches already-rotated keys, the α computed when the 100th token is generated is inconsistent with the one at the 200th token, so the rotation bases of the query and the cached keys no longer match. With a single base, RoPE's attention score depends only on the relative distance:

$(R_m q)^\top (R_n k) = q^\top R_{n-m}\, k$

but when cached keys were rotated under bases different from the current query's, this relative-position property breaks. The correct approach is to cache the keys without RoPE and re-rotate the entire key history with the current base at every step, which is what the code below does.

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import LlamaAttention

# Drop-in replacement for LlamaAttention.forward: keys are cached without
# RoPE and re-rotated with the current base on every decoding step.
def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    if self.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    if past_key_value is not None:
        # reuse k w/o RoPE
        key_states = torch.cat([past_key_value[0], key_states], dim=2)

    # apply RoPE after retrieving all keys and queries
    query_states, rotated_key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        # reuse v, self_attention
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None # cache the key w/o RoPE

    # repeat k/v heads if n_kv_heads < n_heads
    rotated_key_states = repeat_kv(rotated_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, rotated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def replace_llama_attn_with_consistent_ntk_rope():
    LlamaAttention.forward = forward
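A usage sketch, assuming the patch is combined with a dynamically scaled rotary embedding such as the one from section 4 (the checkpoint name is illustrative). Note the trade-off visible in the code above: keys are cached un-rotated and RoPE is re-applied to the whole key history at every step, so consistency costs extra computation per generated token.

from transformers import AutoModelForCausalLM

# Patch attention so keys are cached without RoPE and re-rotated with the
# current (dynamically scaled) base on every step.
replace_llama_attn_with_consistent_ntk_rope()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")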


Source: blog.csdn.net/qq_27590277/article/details/131990041