RWKV: A linear transformer model that gets both the fish and the bear's paw

As we all know, the transformer and its variants have taken both NLP and CV by storm. But the core self-attention mechanism has long been criticized for its O(N²) time complexity (the quadratic-dependency problem).

Under the premise of not changing the overall structure of the transformer block, there are currently two main ideas in academia for solving the quadratic-dependency problem. The first is to linearize self-attention. There is a lot of work in this direction, such as Performer [5], Reformer [6], Linformer [7], Nyströmformer [9], and AdaMRA [10]; Su Jianlin's blog [8] gives a good overview of this line of work. Although linear attention has been studied extensively, according to the comparison figure in the AdaMRA paper [10], only Nyströmformer [9] and AdaMRA [10] improve on the Transformer in both speed and quality, while most of the others trade some quality for a certain speedup. Moreover, because those two models use average pooling for feature clustering, they cannot mask future information and therefore lose the ability to do autoregression. In short, speeding up the transformer by switching to linear attention comes at a price.

The other idea is to replace self-attention with some other linear-complexity component. For example, Google recently found that replacing self-attention with dilated convolution also works well [1]. There are also MLP-Mixer [2], which has taken the CV field by storm, gMLP and aMLP [3], which handle both CV and NLP, and Synthesizer [4], which can be seen as the NLP counterpart of MLP-Mixer. But all of these have shortcomings of one kind or another. For example, Synthesizer and gMLP are still somewhat worse than self-attention on NLP tasks, and although aMLP performs better, it still relies on self-attention, so the goal of speeding things up is not achieved. Then, this past summer, Apple proposed the AFT model [11], claimed to be the fastest transformer variant.
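For reference, the AFT-full update (restated from [11], with ⊙ denoting element-wise multiplication) is:

$$ Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})} $$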

The above is the standard AFT formula, where σ is the sigmoid function, Q, K, V play the same roles as in self-attention, and w is a trained parameter matrix. It is not hard to see that AFT implements attention through element-wise multiplication. For autoregression, it is only necessary to mask the W matrix. The W matrix also carries its own position information, which not only solves the problem that some linear attention variants cannot do autoregression, but also takes care of position encoding in the transformer along the way; AFT kills three birds with one stone, so to speak. But the W matrix, the core of AFT's success, is also its biggest shortcoming. In general, W is a square matrix of size [max_len, max_len]; in other words, the text length AFT can handle is limited by the size of W. To process a long text of 10,000 tokens, the parameter count of the W matrix would quickly approach that of BERT. To solve this problem, RWKV, the protagonist of this article, comes on stage. The original write-up of RWKV is "RWKV is all you need? A new language model, improved Transformer" on Zhihu, but that post is very terse and hard to follow, so this article was written to explain how RWKV gets both the fish and the bear's paw.

RWKV

Overall structure

The overall structure of RWKV still follows the idea of the transformer block, as shown in the figure. Compared with the original transformer block, RWKV replaces self-attention with Position Encoding plus TimeMix, and replaces the FFN with ChannelMix. The rest is the same as in a transformer.
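As a rough illustration of this layout, here is a minimal PyTorch sketch of one RWKV block. The pre-LayerNorm residual wiring is an assumption, not something stated in the original article, and the TimeMix / ChannelMix sub-modules are passed in from outside (for example the sketches later in this article, or even nn.Identity() for a quick shape check).

import torch
import torch.nn as nn

class RWKVBlock(nn.Module):
    # One block: LayerNorm -> TimeMix (replaces self-attention) -> residual,
    # then LayerNorm -> ChannelMix (replaces the FFN) -> residual.
    # Sketch only: the pre-LN wiring is assumed; any nn.Module can be plugged in.
    def __init__(self, embed_size, time_mix: nn.Module, channel_mix: nn.Module):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_size)
        self.ln2 = nn.LayerNorm(embed_size)
        self.time_mix = time_mix
        self.channel_mix = channel_mix

    def forward(self, x):                          # x: [batch, seq_len, embed_size]
        x = x + self.time_mix(self.ln1(x))
        x = x + self.channel_mix(self.ln2(x))
        return x

# quick shape check with placeholder sub-modules
block = RWKVBlock(64, nn.Identity(), nn.Identity())
print(block(torch.randn(2, 16, 64)).shape)         # torch.Size([2, 16, 64])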

Position Matrix

The position encoding adopted by RWKV is similar in form to ALiBi [12]. The author of the original article did not give this position encoding a name; since its main characteristic is distance decay, this article will call it distance encoding for convenience. For the j-th token of the i-th head, the position code is given by the following formula, where n_head denotes the number of heads and max_len denotes the maximum allowed length.

The current mainstream view is that an RNN is naturally a sequential structure and does not need the position encoding that a transformer requires. Looking at the computation of an RNN, we can see that it only considers the current token and the preceding information, and that the influence of earlier information gradually fades as the distance grows. Distance encoding is designed with these temporal characteristics of RNNs in mind.

In the RWKV model, however, this computation is not applied directly to the input X. Instead, it is used to produce a W matrix similar to the one in AFT, which then participates in the subsequent TimeMix computation. The shape of the W matrix is [n_head, seq_len, seq_len], and its values are given by the following formula.

It is not hard to see that in RWKV the W matrix is obtained from a formula rather than through training, which solves AFT's problem with long texts, namely the parameter explosion that comes with handling them.
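Since the decay constants themselves are not reproduced above, the following PyTorch sketch only illustrates the idea: each head gets a lower-triangular matrix whose entries decay exponentially with token distance, at a rate that depends on the head index. The specific decay schedule here is an assumption for illustration, not the formula from the original article.

import torch

def build_distance_w(n_head, seq_len):
    # Formula-derived (not trained) weights, shape [n_head, seq_len, seq_len].
    # W[h, i, j] weighs token j when producing the output at position i; entries
    # with j > i are zeroed so the matrix stays causal. The per-head decay
    # schedule below is an illustrative assumption, not the original constants.
    pos = torch.arange(seq_len)
    dist = (pos[:, None] - pos[None, :]).clamp(min=0).float()   # distance i - j for j <= i, else 0
    decay = torch.linspace(1.0, 0.1, n_head)[:, None, None]     # assumed: later heads decay more slowly
    w = torch.exp(-decay * dist)                                # exponential distance decay
    causal = torch.tril(torch.ones(seq_len, seq_len))           # lower-triangular mask
    return w * causal

w = build_distance_w(n_head=4, seq_len=8)
print(w.shape)   # torch.Size([4, 8, 8])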

Of course, there are settings where the length of the text being processed is bounded, for example machine translation, or the AI novel-writing application RWKV is currently used for. In such scenarios there is no ultra-long text, so more position information can be folded into the W matrix. A reference formula is as follows, writing the two learned vectors as α and β:

$$ W \leftarrow W \odot (\alpha \, \beta) $$

where α and β have shapes [n_head, seq_len, 1] and [n_head, 1, seq_len] respectively and are initialized to all ones, so the distance-coded matrix serves as the initialization of W. After this step, the W matrix formally combines the distance-decay information of distance encoding with learned relative position information.

It is worth noting that when designing distance encoding, the original author deliberately reserved one head that applies no decay to position information; that head's W matrix is simply an all-ones lower-triangular matrix.
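Continuing the sketch above (and reusing the hypothetical build_distance_w helper from it), the no-decay head and the learned α/β rescaling could look as follows; again this is an illustrative reading of the description, not the original implementation.

import torch
import torch.nn as nn

class PositionWeight(nn.Module):
    # Wraps the formula-derived W: the last head is kept as a plain all-ones
    # lower-triangular matrix (no distance decay), and two learned vectors,
    # initialized to ones, rescale W along its rows and columns.
    def __init__(self, n_head, seq_len):
        super().__init__()
        w = build_distance_w(n_head, seq_len)                    # hypothetical helper from the previous sketch
        w[-1] = torch.tril(torch.ones(seq_len, seq_len))         # the head that ignores distance decay
        self.register_buffer("w", w)
        self.alpha = nn.Parameter(torch.ones(n_head, seq_len, 1))
        self.beta = nn.Parameter(torch.ones(n_head, 1, seq_len))

    def forward(self):
        # all-ones initialization means the initial W equals the distance-coded matrix
        return self.w * self.alpha * self.beta                   # [n_head, seq_len, seq_len]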

Time-shift

Before introducing TimeMix, we must first introduce the time-shift trick used by RWKV.

Original post: Time-shift: one line of code that improves Transformer performance for free (no parameters, no extra time cost) - bazyd

Time-shift is a trick proposed by the original author to improve model quality at almost zero cost. The implementation code is as follows.

PyTorch implementation
# inside a module: replace the first half of each token's channels with the previous token's values
C = x.shape[-1]
self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))   # pad one step of zeros at the start of the time axis
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)
Keras implementation
# same idea with Keras backend ops (note that K.concatenate takes a list)
d = K.int_shape(x)[-1]
x = K.concatenate([K.temporal_padding(x, (1, 0))[:, :-1, :d//2], x[:, :, d//2:]], axis=-1)

It can be seen that either framework needs only two lines. To help the reader see what it does, suppose there is a 3x4 matrix (3 tokens, 4 channels):

[[ 1,  2,  3,  4],
 [ 5,  6,  7,  8],
 [ 9, 10, 11, 12]]

After time-shift it becomes

[[ 0,  0,  3,  4],
 [ 1,  2,  7,  8],
 [ 5,  6, 11, 12]]

that is, the first half of the channels is taken from the previous token (zeros for the first token) while the second half is unchanged. In effect this inserts a small RNN-like dependency on the previous step. Experiments show that this simple trick makes the model converge faster and to a better result.
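For a quick check, the snippet below runs the PyTorch version on the 3x4 example above and reproduces the shifted matrix:

import torch
import torch.nn as nn

x = torch.tensor([[[ 1.,  2.,  3.,  4.],
                   [ 5.,  6.,  7.,  8.],
                   [ 9., 10., 11., 12.]]])        # [batch=1, seq_len=3, channels=4]
C = x.shape[-1]
time_shift = nn.ZeroPad2d((0, 0, 1, 0))           # one step of zero padding at the start of time
y = torch.cat([time_shift(x)[:, :-1, :C // 2], x[:, :, C // 2:]], dim=-1)
print(y[0])
# tensor([[ 0.,  0.,  3.,  4.],
#         [ 1.,  2.,  7.,  8.],
#         [ 5.,  6., 11., 12.]])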

TimeMix

TimeMix is the component that replaces self-attention in RWKV. It builds on AFT and improves it, achieving linear speed together with better quality. Before this step, the input x is first time-shifted.

As with the Q, K, V matrices in self-attention, RWKV has corresponding R, K, V matrices. For the j-th token of the i-th head of the output matrix, the calculation steps are as follows.

The first of these is a square matrix of size [hidden_size, hidden_size], used for the final output projection just as in regular attention. The second is a matrix of size [seq_len, hidden_size], whose role is presumably similar to that of a bias.
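Since the formulas themselves are only described above, the following is a minimal sketch of what a TimeMix layer could look like under an AFT-style reading: the input is time-shifted, R/K/V projections are computed, the formula-derived W (the hypothetical build_distance_w helper from earlier) aggregates exp(K) ⊙ V over past positions, a causal cumulative sum of exp(K) serves as the normalizer, sigmoid(R) gates the result, and a [hidden_size, hidden_size] matrix produces the final output. The normalizer is an assumption, and the [seq_len, hidden_size] bias-like term mentioned above is omitted; treat this as one plausible reading, not the original implementation.

import torch
import torch.nn as nn

class TimeMix(nn.Module):
    # Sketch of an AFT-style TimeMix (illustrative reading, see the caveats above).
    def __init__(self, hidden_size, n_head, seq_len):
        super().__init__()
        assert hidden_size % n_head == 0
        self.n_head = n_head
        self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))
        self.receptance = nn.Linear(hidden_size, hidden_size)    # R
        self.key = nn.Linear(hidden_size, hidden_size)           # K
        self.value = nn.Linear(hidden_size, hidden_size)         # V
        self.output = nn.Linear(hidden_size, hidden_size)        # the [hidden_size, hidden_size] projection
        self.register_buffer("w", build_distance_w(n_head, seq_len))  # hypothetical helper from earlier

    def forward(self, x):                                        # x: [B, T, C]
        B, T, C = x.shape
        # time-shift: the first half of the channels comes from the previous token
        x = torch.cat([self.time_shift(x)[:, :-1, :C // 2], x[:, :, C // 2:]], dim=-1)
        r = torch.sigmoid(self.receptance(x))                    # gate
        k = torch.exp(torch.clamp(self.key(x), max=30))          # clamp for numerical safety
        v = self.value(x)
        # split channels across heads: [B, T, C] -> [B, n_head, T, head_dim]
        k = k.view(B, T, self.n_head, -1).transpose(1, 2)
        v = v.view(B, T, self.n_head, -1).transpose(1, 2)
        num = self.w[:, :T, :T] @ (k * v)                        # W-weighted sum over past positions
        den = torch.cumsum(k, dim=2) + 1e-8                      # causal normalizer (an assumption)
        out = (num / den).transpose(1, 2).reshape(B, T, C)
        return self.output(r * out)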

ChannelMix

ChannelMix is the component that replaces the FFN in RWKV. Much as aMLP's tiny attention relates to full attention, ChannelMix is essentially a tiny TimeMix.

Before this step, a time-shift is performed, just as in TimeMix. The R, K, V matrices and the W weight still need to be computed, but with a difference: in this step, assuming the input x has dimension embed_size, R keeps the same dimension as x, K and V have a user-defined dimension hidden_size, and the shape of W is [hidden_size, embed_size].

A tiny version of TimeMix can be obtained by choosing a smaller hidden_size, which speeds things up with little impact on quality. When hidden_size == embed_size, ChannelMix can be viewed either as a TimeMix without position information and normalization, or as an element-wise (point-multiplication) FFN.

The specific calculation formula is as follows
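Since the formula is only described by its shapes above, here is a minimal sketch of one plausible ChannelMix: R keeps the input dimension, K and V project to hidden_size, the element-wise product K ⊙ V is mapped back by the [hidden_size, embed_size] matrix W, and sigmoid(R) gates the result, mirroring TimeMix. Again, an illustrative reading rather than the original code.

import torch
import torch.nn as nn

class ChannelMix(nn.Module):
    # Sketch of ChannelMix as a "tiny TimeMix" without position weighting.
    def __init__(self, embed_size, hidden_size):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))
        self.receptance = nn.Linear(embed_size, embed_size)      # R keeps the input dimension
        self.key = nn.Linear(embed_size, hidden_size)            # K projects to hidden_size
        self.value = nn.Linear(embed_size, hidden_size)          # V projects to hidden_size
        self.weight = nn.Linear(hidden_size, embed_size)         # the [hidden_size, embed_size] W

    def forward(self, x):                                        # x: [B, T, embed_size]
        C = x.shape[-1]
        x = torch.cat([self.time_shift(x)[:, :-1, :C // 2], x[:, :, C // 2:]], dim=-1)
        r = torch.sigmoid(self.receptance(x))
        kv = self.key(x) * self.value(x)                         # element-wise product, [B, T, hidden_size]
        return r * self.weight(kv)                               # gate and project back to embed_size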

Summary

This article presents a model that has its cake and eats it too: it matches AFT in quality and speed, while the design of distance encoding also gives the model the ability to handle ultra-long text.

For actual experimental results, see the original write-up; this article only introduces the structure. In general, though, the author has tried both GPT-based and RWKV-based AI novel writing: the text written by RWKV reads more fluently, and the model converges faster during training.

References

[1] Are Pre-trained Convolutions Better than Pre-trained Transformers https://arxiv.org/pdf/2105.03322.pdf

[2] MLP-Mixer: An all-MLP Architecture for Vision https://arxiv.org/pdf/2105.01601.pdf

[3] Pay Attention to MLPs https://arxiv.org/pdf/2105.08050.pdf

[4] Synthesizer: Rethinking Self-Attention in Transformer Models https://arxiv.org/abs/2005.00743

[5] Rethinking Attention with Performers https://arxiv.org/abs/2009.14794

[6] Reformer: The Efficient Transformer https://arxiv.org/abs/2001.04451

[7] Linformer: Self-Attention with Linear Complexity https://arxiv.org/abs/2006.04768

[8] Exploration of linear Attention: Does Attention have to have a Softmax? https://spaces.ac.cn/archives/7546

[9] Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention https://arxiv.org/abs/2102.03902

[10] Adaptive Multi-Resolution Attention with Linear Complexity https://arxiv.org/abs/2108.04962

[11] An Attention Free Transformer https://arxiv.org/abs/2105.14103

[12] Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation https://arxiv.org/abs/2108.12409



