Re-exploring the Shared Embedding at the Output End of Language Models


By Su Jianlin. Originally published on PaperWeekly (Beijing), 2023-07-26.


When pre-training was just emerging, reusing the Embedding weights at the output end of the language model was very common; for example, BERT, the first version of T5, and early GPT all used this technique. The reason is that when the model backbone is not large but the vocabulary is, the Embedding layer accounts for a considerable share of the parameters, and adding an independent weight matrix of the same size at the output end would cause a surge in memory consumption.

However, as model sizes have grown, the relative share of the Embedding layer has become smaller. In addition, studies such as "Rethinking Embedding Coupling in Pre-trained Language Models" [1] show that sharing the Embedding may have some negative effects, so the practice of sharing Embeddings has become less and less common.

This article analyzes the problems that may be encountered when sharing Embedding weights and explores how to initialize and parameterize them more effectively. Although the shared Embedding seems to be "outdated", it is still an interesting research topic.


Shared Weights

Reusing the Embedding weights at the output end of a language model is known in English as "Tied Embeddings" or "Coupled Embeddings". The main idea is that the Embedding matrix and the projection matrix that maps the output to logits have the same size (they differ only by a transpose); since this matrix is relatively large, the same weight is simply shared to avoid unnecessary waste, as shown in the following figure:


▲ Schematic diagram of Transformer sharing Embedding weights
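As a minimal sketch of this weight tying (with hypothetical module and size names, not the exact architecture in the figure), it might look like this in PyTorch:

```python
import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    """Minimal sketch of a language model whose output projection reuses the input Embedding."""
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)              # Embedding matrix, shape (n, d)
        self.backbone = nn.Identity()                               # stand-in for the Transformer body
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)   # output projection, weight shape (n, d)
        self.lm_head.weight = self.embed.weight                     # tie: the same tensor, used transposed

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        h = self.norm(self.backbone(self.embed(token_ids)))
        return self.lm_head(h)                                      # logits over the vocabulary

model = TinyTiedLM(vocab_size=32000, d_model=1024)
print(model.lm_head.weight is model.embed.weight)                   # True: a single shared weight
```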

The most direct consequence of sharing the Embedding is probably that the initial pre-training loss becomes very large. This is because we usually use techniques such as DeepNorm to reduce training difficulty, and these all initialize the residual branches of the model close to zero. In other words, at the initial stage the model is approximately an identity function, which makes the initial model equivalent to a 2-gram model with a shared Embedding. Next we derive why such a 2-gram model has a large loss, and analyze some solutions.


Preparation

Before formally starting the derivation, we need to prepare a few basic results.

First of all, it should be clear that we mainly analyze the initial stage, when all weights are sampled i.i.d. from a distribution with "mean 0 and variance $\sigma^2$". This allows us to estimate certain sums via expectations. For example, for $\boldsymbol{w}=(w_1,w_2,\cdots,w_d)$ we have

$$\Vert\boldsymbol{w}\Vert^2 = \sum_{i=1}^d w_i^2 = d\cdot\frac{1}{d}\sum_{i=1}^d w_i^2 \approx d\,\mathbb{E}[w_i^2] = d\sigma^2 \tag{1}$$

Therefore we can take $\Vert\boldsymbol{w}\Vert\approx\sqrt{d}\,\sigma$. How big is the error? We can gauge it by the variance of $\frac{1}{d}\sum_{i=1}^d w_i^2$. To do this, we first compute its second moment:

$$\mathbb{E}\left[\left(\frac{1}{d}\sum_{i=1}^d w_i^2\right)^2\right] = \frac{1}{d^2}\sum_{i=1}^d\sum_{j=1}^d\mathbb{E}[w_i^2 w_j^2] = \frac{d\,\mathbb{E}[w^4] + d(d-1)\sigma^4}{d^2} \tag{2}$$

If the sampling distribution is a normal distribution, then $\mathbb{E}[w^4]=3\sigma^4$ can be computed directly, so the variance is

$$\mathbb{E}\left[\left(\frac{1}{d}\sum_{i=1}^d w_i^2\right)^2\right] - \left(\mathbb{E}\left[\frac{1}{d}\sum_{i=1}^d w_i^2\right]\right)^2 = \frac{(d+2)\sigma^4}{d} - \sigma^4 = \frac{2\sigma^4}{d} \tag{3}$$

This variance also measures the quality of the approximation: the smaller the sampling variance $\sigma^2$, the better the approximation. In particular, for the common sampling variance $\sigma^2=1/d$ (corresponding to $\Vert\boldsymbol{w}\Vert\approx 1$, i.e. a unit vector), substituting into the above formula gives $2/d^3$, which means the higher the dimension, the better the approximation. In addition, if the sampling distribution is not a normal distribution, one can recompute $\mathbb{E}[w^4]$, or simply use the normal-distribution result as a reference; it is only an estimate anyway.
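A quick Monte Carlo check of estimates (1) and (3) (a sketch with an arbitrarily chosen $d$ and $\sigma^2=1/d$, the "unit vector" case):

```python
import numpy as np

d = 1024
sigma = 1.0 / np.sqrt(d)                       # sigma^2 = 1/d, so ||w|| should be close to 1
rng = np.random.default_rng(0)
W = rng.normal(0.0, sigma, size=(10_000, d))   # 10k i.i.d. normal vectors

sq_norm_over_d = (W ** 2).sum(axis=1) / d      # (1/d) * ||w||^2 for each vector
print(sq_norm_over_d.mean(), sigma ** 2)       # both ≈ 1/d ≈ 9.77e-4, matching (1)
print(sq_norm_over_d.var(), 2 / d ** 3)        # both ≈ 1.86e-9, matching (3)
```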

If $\boldsymbol{v}=(v_1,v_2,\cdots,v_d)$ is another vector, independent of $\boldsymbol{w}$ but identically distributed, then we can estimate the inner product in the same way; the result is

$$\boldsymbol{w}\cdot\boldsymbol{v} = \sum_{i=1}^d w_i v_i = d\cdot\frac{1}{d}\sum_{i=1}^d w_i v_i \approx d\,\mathbb{E}[w_i v_i] = 0 \tag{4}$$

as well as the corresponding variance

$$\mathbb{E}\left[\left(\frac{1}{d}\sum_{i=1}^d w_i v_i\right)^2\right] = \frac{1}{d^2}\sum_{i=1}^d\mathbb{E}[w_i^2]\,\mathbb{E}[v_i^2] = \frac{\sigma^4}{d} \tag{5}$$

Similarly, if $\sigma^2=1/d$, then this variance is $1/d^3$, and the higher the dimension, the better the approximation. These two results can be regarded as statistical versions of the conclusions in "The Angle Distribution of Two Random Vectors in n-Dimensional Space" [2] and "The Amazing Johnson-Lindenstrauss Lemma: Theory".
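And the same kind of check for the inner-product estimates (4) and (5), under the same assumed $d$ and $\sigma$:

```python
import numpy as np

d = 1024
sigma = 1.0 / np.sqrt(d)
rng = np.random.default_rng(1)
W = rng.normal(0.0, sigma, size=(10_000, d))
V = rng.normal(0.0, sigma, size=(10_000, d))   # independent of W

dot_over_d = (W * V).sum(axis=1) / d           # (1/d) * <w, v> for each pair
print(dot_over_d.mean())                       # ≈ 0, matching (4)
print(dot_over_d.var(), 1 / d ** 3)            # both ≈ 9.3e-10, matching (5)
```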


Loss Analysis

A language model ultimately outputs, token by token, a distribution over $n$ categories, where $n$ is the vocabulary size. If we directly output the uniform distribution, i.e. assign each token probability $1/n$, it is not hard to compute that the cross-entropy loss is $\log n$. This means that a reasonable initialization should not make the initial loss significantly exceed $\log n$: $\log n$ corresponds to the most naive uniform distribution, and clearly exceeding it means doing far worse than the uniform distribution, which is like making a deliberate mistake and is unreasonable.
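As a concrete number (with a hypothetical vocabulary size of $n=32000$), this uniform baseline is about 10.4:

```python
import numpy as np

n = 32000                          # hypothetical vocabulary size
p = np.full(n, 1.0 / n)            # uniform distribution over the vocabulary
print(-np.log(p[0]), np.log(n))    # cross entropy of a uniform prediction = log(n) ≈ 10.37
```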

So why does this happen with a shared Embedding? Denote the initial Embedding by $\{\boldsymbol{w}_1,\boldsymbol{w}_2,\cdots,\boldsymbol{w}_n\}$. As mentioned above, the residual branches are close to zero at the initial stage, so when token $i$ is input, the model output is (approximately) the Embedding after Normalization. Common Normalization layers are Layer Norm and RMS Norm; since the initialization distribution is zero-mean, Layer Norm is roughly equivalent to RMS Norm, so the output is

$$\boldsymbol{h}_i = \frac{\boldsymbol{w}_i}{\Vert\boldsymbol{w}_i\Vert/\sqrt{d}} \approx \frac{\boldsymbol{w}_i}{\sigma} \tag{6}$$

Next, the Embedding is reused: taking inner products and then applying Softmax, the distribution being modeled is essentially

$$p(j\mid i) = \frac{e^{\boldsymbol{h}_i\cdot\boldsymbol{w}_j}}{\sum\limits_{k=1}^n e^{\boldsymbol{h}_i\cdot\boldsymbol{w}_k}} \approx \frac{e^{\boldsymbol{w}_i\cdot\boldsymbol{w}_j/\sigma}}{\sum\limits_{k=1}^n e^{\boldsymbol{w}_i\cdot\boldsymbol{w}_k/\sigma}} \tag{7}$$

The corresponding loss function is

$$-\log p(j\mid i) = \log\sum_{k=1}^n e^{\boldsymbol{w}_i\cdot\boldsymbol{w}_k/\sigma} - \frac{\boldsymbol{w}_i\cdot\boldsymbol{w}_j}{\sigma} \tag{8}$$

The language model's task is to predict the next token, and we know that the proportion of adjacent repeated tokens in natural sentences is very small, so we can basically assume $j\neq i$; then according to result (4) we have $\boldsymbol{w}_i\cdot\boldsymbol{w}_j/\sigma\approx 0$. So the initial loss function is

$$-\log p(j\mid i) \approx \log\sum_{k=1}^n e^{\boldsymbol{w}_i\cdot\boldsymbol{w}_k/\sigma} \approx \log\left(e^{d\sigma} + (n-1)\right) \tag{9}$$

Here formulas (1) and (4) are used again: the $k=i$ term in the sum is $e^{\Vert\boldsymbol{w}_i\Vert^2/\sigma}\approx e^{d\sigma}$, while each of the remaining $n-1$ terms is approximately $e^0=1$. For the common choices of initialization variance — $\sigma^2=1/d$ (in which case $d\sigma=\sqrt{d}$) or a constant such as $0.02^2$ — when $d$ is large, $e^{d\sigma}$ dominates, so the loss is on the order of $d\sigma$, which easily exceeds the uniform-distribution value $\log n$.
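A small simulation of this initial-stage 2-gram picture (a sketch with hypothetical $n$, $d$, and batch size; the residual branches are assumed to be exactly zero, and Layer Norm is replaced by RMS Norm as above):

```python
import math
import torch
import torch.nn.functional as F

n, d, batch = 32000, 1024, 256
for sigma in (1.0 / d ** 0.5, 0.02):                      # sigma^2 = 1/d, or a constant std of 0.02
    W = torch.randn(n, d) * sigma                         # shared Embedding at initialization
    i = torch.randint(0, n, (batch,))                     # current tokens
    j = torch.randint(0, n, (batch,))                     # next tokens (almost surely j != i)
    h = W[i] / W[i].pow(2).mean(-1, keepdim=True).sqrt()  # RMS Norm, ≈ W[i] / sigma as in (6)
    logits = h @ W.t()                                    # reuse the Embedding as the output projection
    loss = F.cross_entropy(logits, j).item()
    print(f"sigma={sigma:.4f}  d*sigma={d * sigma:.1f}  loss={loss:.1f}  log(n)={math.log(n):.1f}")
```

With these assumed sizes the initial loss comes out around $d\sigma$ (roughly 32 and 20), well above $\log n \approx 10.4$.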


Some Countermeasures

Based on the above derivation, we can design some targeted countermeasures. The most direct solution is to adjust the initialization: according to formula (9), we only need $d\sigma$ to stay at or below the $\log n$ level for the initial loss to be of order $\log n$; that is to say, the initialization standard deviation should be changed to the order of $\frac{\log n}{d}$ or smaller.
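Under the same hypothetical $n$ and $d$ as before, this adjusted standard deviation is noticeably smaller than the usual choices:

```python
import math

n, d = 32000, 1024
sigma_fix = math.log(n) / d                # ≈ 0.0101, so e^{d*sigma} is about the same size as n
print(sigma_fix, 1 / math.sqrt(d), 0.02)   # compare with the common 1/sqrt(d) ≈ 0.0312 and 0.02
print(math.log(math.exp(d * sigma_fix) + n - 1), math.log(n))   # initial loss ≈ log(2n) vs log(n)
```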

Generally speaking, we want the initialization variance of the parameters to be as large as possible, so that gradients are relatively less likely to underflow, and $\frac{\log n}{d}$ sometimes appears too small. We can therefore change our way of thinking: clearly, formula (9) becomes too large because the term $\boldsymbol{w}_i\cdot\boldsymbol{w}_i$ appears; because the two vectors are identical, their inner product becomes a squared norm and thus very large. If we can make them different, this dominant term will not appear.

To this end, the simplest approach is of course not to share the Embedding at all. In that case the logit involves the inner product of two independent vectors rather than $\boldsymbol{w}_i\cdot\boldsymbol{w}_i$, so (4) rather than (1) applies as the approximation, and formula (9) is asymptotically $\log n$. If we still want to keep the shared Embedding, we can attach an orthogonally initialized projection layer after the final Normalization, so that $\boldsymbol{h}_i$ becomes $\boldsymbol{h}_i\boldsymbol{P}$. By the Johnson-Lindenstrauss lemma, a randomly projected vector behaves like an independent vector, so the situation is again similar to not sharing; this is in fact BERT's solution. In particular, this projection layer can also be generalized by adding a bias and an activation function.
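A sketch of the orthogonally initialized projection (hypothetical sizes, with the bias and activation variants omitted), checking that it brings the initial loss back toward the $\log n$ level while the Embedding stays shared:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n, d, sigma = 32000, 1024, 0.02
W = torch.randn(n, d) * sigma                             # shared Embedding at initialization
P = nn.init.orthogonal_(torch.empty(d, d))                # orthogonally initialized projection matrix

i = torch.randint(0, n, (256,))                           # current tokens
j = torch.randint(0, n, (256,))                           # next tokens
h = W[i] / W[i].pow(2).mean(-1, keepdim=True).sqrt()      # RMS Norm output, ≈ W[i] / sigma

for name, logits in [("tied", h @ W.t()),
                     ("tied + orthogonal projection", (h @ P) @ W.t())]:
    print(name, F.cross_entropy(logits, j).item())        # ≈ 20.5 without P, ≈ log(n) ≈ 10.4 with P
```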

If we do not want to introduce even a small number of extra parameters, we can consider "shuffling" the dimensions after Normalization:

$$\boldsymbol{h}_i = \big[\boldsymbol{h}_i^{(1)}, \boldsymbol{h}_i^{(2)}\big] \quad\to\quad \tilde{\boldsymbol{h}}_i = \big[\boldsymbol{h}_i^{(2)}, \boldsymbol{h}_i^{(1)}\big] \tag{10}$$

Here $[\cdot,\cdot]$ denotes concatenation and $\boldsymbol{h}_i^{(1)},\boldsymbol{h}_i^{(2)}$ are the two halves of $\boldsymbol{h}_i$. After the swap, $\tilde{\boldsymbol{h}}_i$ and $\boldsymbol{w}_i$ are also nearly orthogonal, so their inner product is naturally approximately 0. This is equivalent to (at the initial stage) splitting the original Embedding matrix into two matrices and then building a 2-gram model that does not share the Embedding. We can also consider other shuffling operations, for example ShuffleNet's [3] reshape, then transpose, then reshape back.
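A sketch of the half-swap in (10), under the same hypothetical sizes as before, again checking the initial loss:

```python
import torch
import torch.nn.functional as F

n, d, sigma = 32000, 1024, 0.02
W = torch.randn(n, d) * sigma                                    # shared Embedding at initialization
i = torch.randint(0, n, (256,))
j = torch.randint(0, n, (256,))

h = W[i] / W[i].pow(2).mean(-1, keepdim=True).sqrt()             # RMS Norm output
h_swapped = torch.cat([h[:, d // 2:], h[:, :d // 2]], dim=-1)    # swap the two halves, as in (10)

for name, logits in [("tied", h @ W.t()), ("tied + half swap", h_swapped @ W.t())]:
    print(name, F.cross_entropy(logits, j).item())               # ≈ 20.5 vs ≈ log(n) ≈ 10.4
```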

In the author's experiments, directly changing the initialization standard deviation to the $\frac{\log n}{d}$ level converges the slowest, while the other methods converge at essentially the same speed. As for the final effect, all methods seem to be similar.


Article Summary

This article revisits the practice of sharing Embedding weights at the output end of language models, derives why directly reusing the Embedding to project the output can cause an excessively large initial loss, and discusses some solutions.


References


[1] Rethinking Embedding Coupling in Pre-trained Language Models. https://arxiv.org/abs/2010.12821

[2] The Angle Distribution of Two Random Vectors in n-Dimensional Space. https://kexue.fm/archives/7076

[3] ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices. https://arxiv.org/abs/1707.01083
