论文:TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION (ICLR2022)
最近,随着大模型的训练,关于token长度的论文和做法也出现了不少,今天在读了ALiBl的做法之后,稍微做个论文笔记。
正弦位置编码
在原始的Transformer论文中,作者采用了正弦位置编码方法( sinusoidal approach ),也就是下方的公式,
这个做法的好处就是,对于固定的偏移量K,比如和当前位置相距K的距离,直接利用 �����+� 就能计算他们的相对值,而且这个计算是线性的、不需要学习。Transformer的作者也尝试了learned postional embedding,不学习和学习的效果相差不大,最后作者选择了sinusoidal approach,因为作者认为它可以简单扩展到更长的训练长度,甚至长于训练时的长度。
至于像BERT中采用绝对位置编码,learned postional embedding我们就不赘述了,绝对位置编码就是0-511,他的问题就是当推理到更长序列时,由于超过当前长度后,embedding没有学习过,可想而知效果会明显下降。
例如BERT的self-attention代码
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
Rotary Position Embeddings
Jianlin Su提出旋转位置编码,这个做法被OPT,LLaMA等所采用,主要的做法就是:在每一层的self-attention计算中,我们对query和key做sinusoidal乘法,如下方代码所示
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
而 llama的embedding我们不需要学习,这样做的好处呢,就是我们不需要在训练中去学习embedding,即可以节省内存消耗,又可以在推理时cover token长度远远长于训练时token的情况,效果也比较好。
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding, 2021.
不过ALiBi的作者认为呢,这种做法会代码一定的惩罚计算,会让训练和推理的速度稍稍变慢。
T5 Bias
T5里面他直接删掉了position embedding,然后在self-attention计算的时候呢,他先让query和key做矩阵点乘计算得到attention-score,然后添加一个可学习的、共享的bias值,像attention-mask一样,把这个基于距离的值,加到attention-score上,从而体现对相对距离的影响;作者认为这种做法的代价也挺大,就是需要很多的embedding权重参数,训练会变慢很多。
ALiBi
本文的做法是不添加position embedding,然后添加一个静态的不学习的bias,如下图:
怎么理解呢,就是在query和key做矩阵点乘的基础上,加上一个常数负值,比如距离当前位置前1位为-1, 前两位为-2,这些常数要乘上 权重 m
对于有8个attention head的例子,每个头的权重m不同,范围为
如果是16个头,m的范围就是
Conclusion
从实验结果来看ALiBi有几个有点:
1、减少了需要训练的embedding,可以稍微加快训练速度,减小模型参数
2、在512上训练,到更长的token上推理时,表现相比于之前的方法更稳定
3、像MosaicLM用了这种技术可以直接拿来写小说,生成特别长的文本内容。
希望了解细节的同学可以把论文好好品一品,论文整体写的通俗易懂。