From: NLP Workstation
Foreword
As large-model applications continue to develop, plugging in external knowledge ("knowledge plug-ins") has become an important technique. However, this approach is often limited by the maximum input length the model accepts and by the model's ability to extrapolate beyond that length. Today we bring you a brief discussion of length extrapolation in LLMs, written by @uuuuu (Zhihu).
https://zhuanlan.zhihu.com/p/645770522
It covers extrapolation strategies up to 2023-07-24: NBCE, linear interpolation, NTK-Aware Scaled RoPE, Dynamically Scaled RoPE, and Consistent Dynamically Scaled RoPE.
From the second method onward, each method is essentially an optimization of the previous one, and all of them apply to any language model that uses RoPE.
1. NBCE
NBCE: Using Naive Bayes to Extend the Context Length of LLMs
https://kexue.fm/archives/9617
NBCE, first proposed by Su Jianlin, extends an LLM's context with a formula inspired by Naive Bayes. In practical tests it works very well for question answering, and its reading comprehension over longer contexts is also decent.
Its limitation is that it is order-agnostic: it cannot recognize the order in which the contexts are supplied, so it may perform poorly in scenarios such as story continuation. It is also less effective for tasks whose answer depends on every context together, such as document summarization.
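As a reading aid, here is the formula that the core lines of the following code implement, reconstructed from Su Jianlin's derivation as we understand it (the notation $S_k$ for the $k$-th context, $T$ for the next token and $H$ for entropy is ours). Assuming the contexts are independent given $T$ (the Naive Bayes assumption), Bayes' rule gives the first line. Replacing the sum by a pooled term with a tunable $\beta$ gives the second line; average pooling with $\beta = n-1$ recovers the first line exactly. The code pools by picking the single context with the lowest-entropy prediction, uses $\beta = 0.25$, and reads $\log p(T)$ from a copy of the prompt with no context.

```latex
\log p(T \mid S_1,\dots,S_n) = \sum_{k=1}^{n} \log p(T \mid S_k) - (n-1)\,\log p(T) + \text{const}

\log p(T \mid S_1,\dots,S_n) \approx (\beta+1)\,\log p(T \mid S_{k^*}) - \beta\,\log p(T) + \text{const},
\qquad k^* = \arg\min_{k} H\!\left(p(\cdot \mid S_k)\right)
```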
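import torch

# One decoding step of NBCE. The batch holds n + 1 copies of the prompt:
# row 0 has no context (unconditional), rows 1..n each carry one context.
# `model`, `input_ids`, `attention_mask` and `past_key_values` come from the surrounding generation loop.
outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                use_cache=True,
                past_key_values=past_key_values
                )
past_key_values = outputs.past_key_values
# ===== core code begins =====
beta = 0.25
# log_softmax avoids log(0) = -inf, which would turn p * log p into NaN
logits = torch.log_softmax(outputs.logits[:, -1], dim=-1)
probas = logits.exp()
# sum(p * log p) is the negative entropy: pick the context (rows 1..n) with the lowest entropy
k = (probas * logits).sum(dim=-1)[1:].argmax() + 1
logits_max = logits[k]
logits_uncond = logits[0]
logits = (1 + beta) * logits_max - beta * logits_uncond
# ===== core code ends =====
# Build the distribution and sample
tau = 0.01  # tau = 1 is standard random sampling; tau -> 0 approaches greedy search
probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)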
The code, figures, and text in this section are all taken from Su Jianlin's blog, Scientific Spaces.
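To see the combination step in isolation, here is a minimal pure-Python sketch on made-up toy distributions (a four-token vocabulary, not real model outputs). `nbce_combine` and `softmax` are hypothetical helpers for illustration; as in the torch code, row 0 is the no-context distribution and rows 1..n are the per-context distributions.

```python
import math


def nbce_combine(dists, beta=0.25):
    """NBCE pooling on plain lists of probabilities.

    dists[0] is the next-token distribution without any context;
    dists[1:] are the distributions conditioned on each context separately.
    Returns (combined log-scores, index of the chosen context).
    """
    logs = [[math.log(p) for p in d] for d in dists]
    # sum(p * log p) is the negative entropy; the largest value is the most confident context
    neg_entropy = [sum(p * lp for p, lp in zip(d, l)) for d, l in zip(dists[1:], logs[1:])]
    k = 1 + max(range(len(neg_entropy)), key=lambda i: neg_entropy[i])
    combined = [(1 + beta) * lk - beta * lu for lk, lu in zip(logs[k], logs[0])]
    return combined, k


def softmax(scores, tau=1.0):
    """Temperature softmax; tau -> 0 approaches greedy selection."""
    scaled = [s / tau for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


dists = [
    [0.25, 0.25, 0.25, 0.25],  # no context
    [0.40, 0.30, 0.20, 0.10],  # context 1: uncertain
    [0.05, 0.85, 0.05, 0.05],  # context 2: confident
]
scores, k = nbce_combine(dists)
probs = softmax(scores, tau=0.01)
print(k, probs.index(max(probs)))  # context 2 is chosen; token 1 dominates
```

Lowering `tau` toward 0 makes the sampled token approach the argmax of the combined scores, matching the comment in the original code.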
2. Linear interpolation
https://kaiokendev.github.io/context
https://lmsys.org/blog/2023-06-29-longchat/
https://arxiv.org/abs/2306.15595
LLaMA is pre-trained with a context length of 2048 using rotary position embedding (RoPE). This method achieves length extrapolation by compressing the positions of a longer sequence into the range 0 to 2048.
LongChat fine-tunes the model to extend the context length to 16384 with a compression ratio of 8. For example, the token at position_ids = 10000 gets position 10000 / 8 = 1250, and the adjacent token at 10001 gets 10001 / 8 = 1250.125.
The drawback of this method is that some fine-tuning is required for the model to adapt to the compressed positions.
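The position arithmetic above can be checked with a short pure-Python sketch. It assumes the standard RoPE frequencies base^(-2i/dim) with base 10000 and LLaMA's head dimension of 128; `rope_angles` is a hypothetical helper, not part of any library. Condensing by a ratio of 8 simply divides each position before it is multiplied by the frequencies.

```python
def rope_angles(position, dim=128, base=10000.0, ratio=1.0):
    """Rotation angles of one position for every RoPE frequency pair.

    Linear interpolation divides the position by `ratio` before applying
    the frequencies theta_i = base ** (-2 i / dim).
    """
    p = position / ratio
    return [p * base ** (-2 * i / dim) for i in range(dim // 2)]


ratio = 8  # LongChat: 2048 * 8 = 16384 positions
print(10000 / ratio, 10001 / ratio)  # 1250.0 1250.125

# After condensing, position 10000 reuses the angles position 1250 had in pre-training
same = rope_angles(10000, ratio=ratio) == rope_angles(1250)
# The last position of a 16384-token context still lands inside the pre-trained range
max_condensed = (16384 - 1) / ratio
print(same, max_condensed)
```

Because neighbouring tokens now differ by only 1/8 of a position, the model has to learn to resolve these finer steps, which is why fine-tuning is needed.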
import torch
import transformers
import transformers.models.llama.modeling_llama
from functools import partial


class CondenseRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, ratio, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self.ratio = ratio
        max_position_embeddings *= ratio
        print(f"Condensing Positional embeddings from {max_position_embeddings} to {max_position_embeddings // ratio}")
        self.max_seq_len_cached = max_position_embeddings
        # Linear interpolation: divide every position by `ratio` so the extended range maps back into the pre-trained one
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) / ratio
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) / self.ratio
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def replace_llama_with_condense(ratio):
    # Call before the model is instantiated so that LlamaAttention builds the condensed embedding
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial(CondenseRotaryEmbedding, ratio=ratio)
3. NTK-Aware Scaled RoPE
NTK-Aware Scaled RoPE allows LLaMA models to have an extended (8k+) context size without any fine-tuning and with minimal perplexity degradation.
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
RoPE is a β-ary (base-β) encoding: https://spaces.ac.cn/archives/9675
An intuitive explanation: RoPE behaves like a clock. A 12-hour wall clock is basically a RoPE with dimension 3 and base 60: each second, the minute hand turns 1/60 of a minute, and each minute, the hour hand turns 1/60 of an hour. If you slow time down by a factor of 4, that is the linear RoPE scaling of method 2. Unfortunately, it now becomes very hard to tell individual seconds apart, because the second hand barely moves each second. So if someone gives you two times that differ by only one second, you cannot distinguish them from a distance. NTK-Aware RoPE scaling does not slow down the seconds. One second is still one second, but minutes are slowed by a factor of, say, 1.5 and hours by a factor of 2. This way you can fit 90 minutes into an hour and 24 hours into half a day, so the clock now measures 129.6k seconds (24 x 90 x 60) instead of 43.2k seconds (12 x 60 x 60). Because you do not need a precise reading of the hour hand to tell the time, scaling the hours more than the seconds is crucial: you do not want to lose precision on the second hand, but you can afford to lose some on the minute hand and even more on the hour hand.
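The clock picture can be made concrete with the base-change formula used in the following code, base' = base * alpha^(dim / (dim - 2)). This pure-Python sketch assumes dim = 128, base = 10000 and alpha = 8 (the values in that code); the function names are hypothetical. Frequency pair i is slowed down by (base'/base)^(2i/dim) = alpha^(2i/(dim-2)), so the fastest pair (the "second hand", i = 0) is untouched while the slowest pair (the "hour hand", i = dim/2 - 1) is slowed by exactly alpha, the factor that linear interpolation would apply to every pair.

```python
def rope_inv_freq(dim, base):
    """Standard RoPE frequencies theta_i = base ** (-2 i / dim), i = 0 .. dim/2 - 1."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def ntk_scaled_base(base, alpha, dim):
    """NTK-aware base change: base' = base * alpha ** (dim / (dim - 2))."""
    return base * alpha ** (dim / (dim - 2))


def slowdown_per_pair(dim=128, base=10000.0, alpha=8.0):
    """Ratio old_frequency / new_frequency for each RoPE frequency pair."""
    old = rope_inv_freq(dim, base)
    new = rope_inv_freq(dim, ntk_scaled_base(base, alpha, dim))
    return [o / n for o, n in zip(old, new)]


slow = slowdown_per_pair()
# Fastest pair ("second hand") keeps its speed; slowest pair ("hour hand") is slowed by alpha
print(round(slow[0], 6), round(slow[-1], 6))
```

The slowdown grows geometrically from 1 to alpha across the pairs, which is the "scale hours more than seconds" behaviour described above.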
import transformers
import transformers.models.llama.modeling_llama

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__


def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # The method is just these three lines
    max_position_embeddings = 16384
    a = 8  # Alpha value
    base = base * a ** (dim / (dim - 2))  # Base change formula
    old_init(self, dim, max_position_embeddings, base, device)


# Patch before the model is instantiated
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
4. Dynamically Scaled RoPE
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
Methods 2 and 3 each involve a hyperparameter (the ratio in linear interpolation, α in NTK-aware scaling) that controls the amount of scaling. This method instead selects the scaling parameter dynamically from the current sequence length. Its effect is shown in the figure in the original post.
For linear interpolation, keep the exact position values for the first 2k tokens of context, then recompute the position embeddings for each new sequence length as the model generates tokens one by one. In essence, the scale is set to original model context length / current sequence length.
For dynamic NTK, the effective α is set to (α * current sequence length / original model context length) - (α - 1), so the hyperparameter is scaled dynamically as the sequence grows.
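The two dynamic rules can be written as small functions of the current sequence length. This pure-Python sketch assumes an original context length of 2048 and mirrors the `if seq_len > self.max_seq_len_cached` guard in the following code by leaving lengths up to 2048 unscaled; the function names and the example α = 2 are hypothetical. Note that the dynamic NTK α equals 1 exactly at 2048 and grows linearly from there, so the base changes continuously instead of jumping to the full α.

```python
def dynamic_linear_scale(seq_len, original_len=2048):
    """Factor applied to positions: original_len / seq_len once the sequence is longer than original_len."""
    return original_len / seq_len if seq_len > original_len else 1.0


def dynamic_ntk_alpha(seq_len, alpha, original_len=2048):
    """Effective alpha: alpha * seq_len / original_len - (alpha - 1), or 1 (no scaling) up to original_len."""
    if seq_len <= original_len:
        return 1.0
    return alpha * seq_len / original_len - (alpha - 1)


def dynamic_ntk_base(seq_len, alpha, dim=128, base=10000.0, original_len=2048):
    """RoPE base after applying the NTK base change with the effective alpha."""
    return base * dynamic_ntk_alpha(seq_len, alpha, original_len) ** (dim / (dim - 2))


for n in (2048, 3072, 4096, 8192):
    print(n, dynamic_linear_scale(n), dynamic_ntk_alpha(n, alpha=2), dynamic_ntk_base(n, alpha=2))
```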
import torch


class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    # ntk=False -> dynamic linear interpolation; ntk=<alpha> (a number) -> dynamic NTK with that alpha
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Rescale only when the sequence grows past the longest length cached so far (initially the pre-training length)
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                # Dynamic NTK: effective alpha = alpha * seq_len / L - (alpha - 1), then the NTK base change
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim - 2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                # Dynamic linear interpolation: scale = original context length / current sequence length
                t *= self.max_position_embeddings / seq_len
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
5. Consistent Dynamically Scaled RoPE
https://github.com/NormXU/Consistent-DynamicNTKRoPE
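A problem with method 4 is that α is dynamic while decoding reuses cached keys. The keys cached when the 100th token was generated were rotated with the α computed at that length, which differs from the α computed when the 200th token is generated, so queries and keys end up rotated with inconsistent bases. The correct approach caches keys before RoPE is applied and re-rotates all of them with the current base at every step, as follows: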
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import LlamaAttention


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    if self.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)
        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)
        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    if past_key_value is not None:
        # reuse k w/o RoPE
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
    # apply RoPE after retrieving all keys and queries, so every key is rotated with the current base.
    # Queries use position_ids; keys use positions 0..kv_seq_len-1 (assumes no left padding). Passing
    # position_ids for the keys as well would give every cached key the query's position while decoding.
    key_position_ids = torch.arange(kv_seq_len, device=key_states.device).unsqueeze(0)
    query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin, position_ids)
    _, rotated_key_states = apply_rotary_pos_emb(key_states, key_states, cos, sin, key_position_ids)
    if past_key_value is not None:
        # reuse v, self_attention
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    past_key_value = (key_states, value_states) if use_cache else None  # cache the key w/o RoPE
    # repeat k/v heads if n_kv_heads < n_heads
    rotated_key_states = repeat_kv(rotated_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    attn_weights = torch.matmul(query_states, rotated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value


def replace_llama_attn_with_consistent_ntk_rope():
    # Patch every LlamaAttention; meant to be paired with a dynamic NTK rotary embedding such as method 4's
    LlamaAttention.forward = forward
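Why mixing bases breaks attention can be seen on a single RoPE frequency pair. The pure-Python sketch below uses made-up 2-D query/key vectors and hypothetical frequencies `theta_old` and `theta_new` (standing in for the base before and after α grows). When query and key are rotated with the same frequency, the score depends only on the offset between their positions, which is the property RoPE relies on; when a cached key keeps the older frequency while the query uses the new one, the same offset yields different scores.

```python
import math


def rotate(vec, angle):
    """Rotate a 2-D vector (one RoPE frequency pair) by `angle` radians."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def score(q, k, q_pos, k_pos, q_theta, k_theta):
    """Dot product after rotating q with frequency q_theta and k with frequency k_theta."""
    qr = rotate(q, q_pos * q_theta)
    kr = rotate(k, k_pos * k_theta)
    return qr[0] * kr[0] + qr[1] * kr[1]


q, k = (1.0, 0.5), (0.3, -0.8)
theta_old, theta_new = 0.01, 0.005  # hypothetical frequencies before / after the base grows

# Consistent: both rotated with the current base, offset 10 in both cases
same_a = score(q, k, 100, 90, theta_new, theta_new)
same_b = score(q, k, 200, 190, theta_new, theta_new)

# Inconsistent: key cached with the old base, query rotated with the new base
mixed_a = score(q, k, 100, 90, theta_new, theta_old)
mixed_b = score(q, k, 200, 190, theta_new, theta_old)

print(abs(same_a - same_b) < 1e-9, round(mixed_a - mixed_b, 3))
```

Caching keys before RoPE, as the patched forward does, lets every key be rotated with the same, current base at each step, at the cost of re-rotating the whole key cache on every decoding step.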