Revisiting Shared Embeddings at the Output End of Language Models
Originally published by Su Jianlin on PaperWeekly (Beijing), 2023-07-26 23:42
When pre-training was just emerging, it was very common to reuse the Embedding weights at the output end of the language model. For example, BERT, the first version of T5, and early GPT all used this trick. The reason is that when the model backbone is not large but the vocabulary is, the Embedding layer accounts for a considerable share of the parameters; adding an independent weight matrix of the same size at the output end would cause a surge in memory consumption.

However, as model sizes have grown, the relative share of the Embedding layer has shrunk. Moreover, studies such as "Rethinking embedding coupling in pre-trained language models" [1] show that sharing the Embedding may have some negative effects, so the practice has become less and less common.

This article analyzes the problems that may arise when sharing Embedding weights, and explores how to initialize and parameterize more effectively. Although shared Embeddings may seem "outdated", they remain an interesting research topic.
Shared Weights
Reusing the Embedding weights at the output end of a language model is known in English as "Tied Embeddings" or "Coupled Embeddings". The idea is that the Embedding matrix has the same size as the projection matrix that maps the output to logits (up to a transpose), and since this matrix is fairly large, we can simply share one set of weights to avoid unnecessary waste, as shown in the following figure:
▲ Schematic diagram of Transformer sharing Embedding weights
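Concretely, tying can be sketched as follows (a NumPy sketch with hypothetical sizes `V`, `d`; the backbone is pretended to be the identity, which is exactly the initial-stage setting analyzed below): the output projection simply reuses the Embedding matrix, transposed.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 1000, 64                               # hypothetical vocab size and hidden dim
W = rng.normal(0.0, d ** -0.5, size=(V, d))   # the single shared Embedding matrix

def embed(token_ids):
    return W[token_ids]                       # input end: look up rows of W

def logits(hidden):
    return hidden @ W.T                       # output end: reuse W, transposed

h = embed(np.array([3, 7]))                   # pretend the backbone is the identity
z = logits(h)                                 # shape (2, V)
```

Note that `z[0, 3]` is then the inner product of `W[3]` with itself, i.e. a squared norm; this self-inner-product is precisely the term analyzed in the loss derivation below.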
The most direct consequence of sharing the Embedding is that the initial pre-training loss can be very large. This is because we usually use techniques like DeepNorm to reduce training difficulty, and they all initialize the model's residual branches close to zero, so the model is approximately an identity function at the initial stage. This makes the initial model equivalent to a 2-gram model with a shared Embedding. Below we derive why such a 2-gram model has a large loss, and analyze some solutions.
Preparation
Before officially starting the derivation, we need to prepare some basic conclusions.
First, note that we mainly analyze the initial stage, when all weights are sampled i.i.d. from a distribution with mean 0 and variance $\sigma^2$. This lets us estimate certain sums by their expectations. For example, for $\boldsymbol{w}=(w_1,w_2,\cdots,w_d)$, we have

$$\Vert\boldsymbol{w}\Vert^2 = \sum_{i=1}^d w_i^2 \approx \mathbb{E}\left[\sum_{i=1}^d w_i^2\right] = d\sigma^2\tag{1}$$
Therefore we can take $\Vert\boldsymbol{w}\Vert\approx\sqrt{d}\sigma$. How big is the error? We can gauge it through the variance. To do so, we first compute the second moment:

$$\mathbb{E}\left[\left(\sum_{i=1}^d w_i^2\right)^2\right] = \sum_{i=1}^d \mathbb{E}[w_i^4] + \sum_{i\neq j}\mathbb{E}[w_i^2]\,\mathbb{E}[w_j^2] = d\,\mathbb{E}[w^4] + d(d-1)\sigma^4\tag{2}$$

If the sampling distribution is normal, then $\mathbb{E}[w^4]=3\sigma^4$ can be computed directly, so

$$\mathbb{V}ar\left[\Vert\boldsymbol{w}\Vert^2\right] = 3d\sigma^4 + d(d-1)\sigma^4 - (d\sigma^2)^2 = 2d\sigma^4\tag{3}$$
This variance represents the degree of approximation: the smaller the original sampling variance, the better the approximation. In particular, for the common sampling variance $\sigma^2=1/d$ (corresponding to $\Vert\boldsymbol{w}\Vert\approx 1$, i.e. a unit vector), substituting into the above formula gives $2/d$, which means that the higher the dimension, the better the approximation. If the sampling distribution is not normal, one can recompute $\mathbb{E}[w^4]$, or simply use the normal-distribution result as a reference; it is only an estimate anyway.
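A quick Monte-Carlo check of (1) and (3) (a sketch with hypothetical `d` and sample count; `sigma2 = 1/d` is the common choice discussed above):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4096
sigma2 = 1.0 / d                  # common init variance 1/d, so ||w|| ~ 1
n = 10000                         # number of sampled vectors

w = rng.normal(0.0, sigma2 ** 0.5, size=(n, d))
sq_norms = (w ** 2).sum(axis=1)   # ||w||^2 for each sample

mean_est = sq_norms.mean()        # formula (1): should be close to d*sigma2 = 1
var_est = sq_norms.var()          # formula (3): should be close to 2*d*sigma2^2 = 2/d
```

With `d = 4096` the empirical variance is about `2/4096 ≈ 5e-4`, illustrating how tightly $\Vert\boldsymbol{w}\Vert^2$ concentrates around $d\sigma^2$ in high dimension.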
If $\boldsymbol{v}$ is another independent, identically distributed vector, we can estimate the inner product in the same way:

$$\langle\boldsymbol{w},\boldsymbol{v}\rangle = \sum_{i=1}^d w_i v_i \approx \mathbb{E}\left[\sum_{i=1}^d w_i v_i\right] = 0\tag{4}$$

as well as

$$\mathbb{E}\left[\langle\boldsymbol{w},\boldsymbol{v}\rangle^2\right] = d\sigma^4\tag{5}$$
Similarly, if $\sigma^2=1/d$, then the variance is $1/d$, and the higher the dimension, the better the approximation. These two results can be regarded as statistical versions of the conclusions in "The Angle Distribution of Two Random Vectors in n-Dimensional Space" [2] and "The Amazing Johnson-Lindenstrauss Lemma: Theory".
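The inner-product estimates (4) and (5) can be checked the same way (same hypothetical `d` and sample count as before):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4096
sigma = d ** -0.5                 # sigma^2 = 1/d
n = 10000

w = rng.normal(0.0, sigma, size=(n, d))
v = rng.normal(0.0, sigma, size=(n, d))   # independent of w
dots = (w * v).sum(axis=1)        # <w, v> for each independent pair

mean_est = dots.mean()            # formula (4): close to 0
var_est = dots.var()              # formula (5): close to d*sigma^4 = 1/d
```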
Loss Analysis
A language model ultimately outputs, token by token, a distribution over $V$ classes, where $V$ is the vocabulary size. Suppose we directly output the uniform distribution, i.e. each token has probability $1/V$; then it is not hard to see that the cross-entropy loss is $\log V$. This means a reasonable initialization should not make the initial loss significantly exceed $\log V$: since $\log V$ corresponds to the most naive uniform distribution, clearly exceeding it would mean doing far worse than a uniform guess, which is like making a deliberate mistake and is unreasonable.
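The $\log V$ baseline is easy to verify numerically (a sketch with a hypothetical vocabulary size and an arbitrary target token): all-equal logits produce the uniform distribution, whose cross entropy is exactly $\log V$.

```python
import numpy as np

V = 32000                          # hypothetical vocabulary size
logits = np.zeros(V)               # all-equal logits -> uniform distribution
target = 123                       # arbitrary next token

# cross entropy = logsumexp(logits) - logits[target]
loss = np.log(np.exp(logits).sum()) - logits[target]
# equals log V for the uniform distribution (about 10.37 here)
```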
So why does this happen with a shared Embedding? Suppose the initial Embedding matrix is $[\boldsymbol{w}_1,\boldsymbol{w}_2,\cdots,\boldsymbol{w}_V]$. As noted above, the residual branches are close to zero at the initial stage, so when token $i$ is fed in, the model output is its Embedding after Normalization. Common choices of Normalization are Layer Norm and RMS Norm; since the initialization distribution has zero mean, Layer Norm is roughly equivalent to RMS Norm, so the output is

$$\frac{\sqrt{d}\,\boldsymbol{w}_i}{\Vert\boldsymbol{w}_i\Vert}\tag{6}$$

Next, the Embedding is reused: take inner products and then apply Softmax. The resulting distribution is essentially

$$p(j|i) = \frac{e^{\langle\sqrt{d}\,\boldsymbol{w}_i/\Vert\boldsymbol{w}_i\Vert,\,\boldsymbol{w}_j\rangle}}{\sum\limits_{k=1}^V e^{\langle\sqrt{d}\,\boldsymbol{w}_i/\Vert\boldsymbol{w}_i\Vert,\,\boldsymbol{w}_k\rangle}}\tag{7}$$

The corresponding loss function is

$$-\log p(j|i) = \log\sum_{k=1}^V e^{\langle\sqrt{d}\,\boldsymbol{w}_i/\Vert\boldsymbol{w}_i\Vert,\,\boldsymbol{w}_k\rangle} - \frac{\sqrt{d}\,\langle\boldsymbol{w}_i,\boldsymbol{w}_j\rangle}{\Vert\boldsymbol{w}_i\Vert}\tag{8}$$

The language model's task is to predict the next token, and the proportion of adjacent repeated tokens in natural sentences is very small, so we can basically assume $i\neq j$; then by result (4), $\langle\boldsymbol{w}_i,\boldsymbol{w}_j\rangle\approx 0$. So the initial loss is

$$-\log p(j|i) \approx \log\sum_{k=1}^V e^{\langle\sqrt{d}\,\boldsymbol{w}_i/\Vert\boldsymbol{w}_i\Vert,\,\boldsymbol{w}_k\rangle} \approx \log\left(e^{\sqrt{d}\,\Vert\boldsymbol{w}_i\Vert} + V - 1\right) \approx \log\left(e^{d\sigma} + V - 1\right)\tag{9}$$
Here formulas (1) and (4) are used again: the $k=i$ term contributes $e^{\sqrt{d}\,\Vert\boldsymbol{w}_i\Vert}\approx e^{d\sigma}$, while the other $V-1$ terms are each approximately $e^0=1$. The common initialization variance $\sigma^2$ is either a constant or $1/d$ (in which case $d\sigma=\sqrt{d}$); either way, when $d$ is large, $e^{d\sigma}$ dominates, so the loss will be on the order of $d\sigma$, which easily exceeds the uniform-distribution loss $\log V$.
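The blow-up predicted by (9) can be simulated directly (a sketch with hypothetical `V`, `d`; the backbone is the identity, as in the initial stage):

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 10000, 1024
sigma = d ** -0.5                 # common init: variance 1/d, so d*sigma = sqrt(d) = 32
W = rng.normal(0.0, sigma, size=(V, d))

i, j = 5, 6                       # current token and a distinct next token
x = np.sqrt(d) * W[i] / np.linalg.norm(W[i])   # RMS-Norm output (6) for token i
z = x @ W.T                       # shared-Embedding logits
loss = np.log(np.exp(z).sum()) - z[j]          # cross entropy (8)

predicted = np.log(np.exp(d * sigma) + V - 1)  # formula (9), about 32 here
uniform = np.log(V)                            # about 9.2
```

The simulated initial loss is on the order of $d\sigma=32$, far above the uniform baseline $\log V\approx 9.2$, matching formula (9).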
Some Countermeasures
Based on the above derivation we can design some targeted countermeasures. The most direct is to adjust the initialization: according to formula (9), we only need $d\sigma=\mathcal{O}(1)$ for the initial loss to be on the order of $\log V$; that is, the initialization standard deviation should be changed to $1/d$.
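Repeating the previous simulation with the standard deviation changed to $1/d$ (same hypothetical `V`, `d` as before) brings the initial loss back to the $\log V$ level:

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 10000, 1024
sigma = 1.0 / d                   # proposed fix: std 1/d, so d*sigma = 1
W = rng.normal(0.0, sigma, size=(V, d))

i, j = 5, 6
x = np.sqrt(d) * W[i] / np.linalg.norm(W[i])   # RMS-Norm output of token i
z = x @ W.T
loss = np.log(np.exp(z).sum()) - z[j]          # now close to log V ~ 9.21
```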
Generally, though, we want the initialization variance of the parameters to be as large as possible so that gradients are less likely to underflow, and $1/d$ can be too small. We can therefore change the line of thinking: formula (9) is too large precisely because the term $e^{d\sigma}$ appears, and it appears because the two vectors in the inner product are the same, so the inner product becomes the (scaled) norm and thus very large. If we can make them different, this dominant term disappears.
The simplest way is, of course, not to share the Embedding at all. Then the self-inner-product is replaced by an inner product of independent vectors, approximation (4) applies instead of (1), and formula (9) asymptotically approaches $\log V$. If we still want to keep the shared Embedding, we can append an orthogonally initialized projection layer after the final Normalization; by the Johnson-Lindenstrauss lemma, randomly projected vectors behave like independent vectors, so the situation is similar to not sharing. This is in fact BERT's solution. In particular, this projection layer can also be generalized by adding a bias and an activation function.
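A sketch of the projection-layer remedy (hypothetical sizes; the orthogonal matrix is obtained here via QR decomposition of a random Gaussian matrix, one common way to realize orthogonal initialization):

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 10000, 1024
sigma = d ** -0.5                 # back to the "large" init, variance 1/d
W = rng.normal(0.0, sigma, size=(V, d))
P, _ = np.linalg.qr(rng.normal(size=(d, d)))   # orthogonally initialized projection

i, j = 5, 6
x = np.sqrt(d) * W[i] / np.linalg.norm(W[i])
z = (x @ P) @ W.T                 # project after Normalization, then reuse W
loss = np.log(np.exp(z).sum()) - z[j]
```

Because `x @ P` is a random rotation of `x`, it behaves like a vector independent of every row of `W`, so no single logit dominates and the loss stays near $\log V$ rather than near $d\sigma=32$.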
If you don't want to introduce even this small number of extra parameters, you can consider "shuffling" the dimensions after Normalization:

$$\boldsymbol{x} = [\boldsymbol{x}_1, \boldsymbol{x}_2] \,\to\, \hat{\boldsymbol{x}} = [\boldsymbol{x}_2, \boldsymbol{x}_1]\tag{10}$$

Here $[\cdot,\cdot]$ is the concatenation operation and $\boldsymbol{x}_1,\boldsymbol{x}_2$ are the two halves of $\boldsymbol{x}$. Then $\hat{\boldsymbol{x}}$ and $\boldsymbol{x}$ are close to orthogonal, so the troublesome self-inner-product is naturally approximately 0. This is equivalent (at the initial stage) to splitting the original Embedding matrix into two matrices and building a 2-gram model that does not share the Embedding. Other shuffling operations can also be considered, such as ShuffleNet's [3]: first reshape, then transpose, then reshape back.
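A sketch of the half-swap shuffle (same hypothetical setup as the earlier simulations):

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 10000, 1024
sigma = d ** -0.5                 # the "large" init, variance 1/d
W = rng.normal(0.0, sigma, size=(V, d))

def swap_halves(x):
    # [x1, x2] -> [x2, x1]: a parameter-free shuffle of the dimensions
    return np.concatenate([x[d // 2:], x[:d // 2]])

i, j = 5, 6
x = np.sqrt(d) * W[i] / np.linalg.norm(W[i])
z = swap_halves(x) @ W.T          # shuffled output against the shared Embedding
loss = np.log(np.exp(z).sum()) - z[j]
```

The self term is now $2\langle\boldsymbol{x}_1,\boldsymbol{x}_2\rangle$ between two independent halves, so it is no longer dominant and the loss stays near $\log V$.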
In the author's experiments, directly changing the initialization standard deviation to $1/d$ converged most slowly, while the other methods converged at about the same speed. As for the final effect, all methods seemed similar.
Summary
This article revisited the practice of sharing Embedding weights at the output end of language models, derived why directly reusing the Embedding to project the output can cause an excessively large initial loss, and discussed some solutions.
References
[1] https://arxiv.org/abs/2010.12821
[2] https://kexue.fm/archives/7076
[3] https://arxiv.org/abs/1707.01083