Detailed explanation of the ELECTRA model

The innovation of ELECTRA can be understood, in essence, as an innovation in its pre-training method. Now that Transformer-based models are everywhere, what distinguishes the large pre-trained models from one another is mainly the pre-training task each of them chooses.

Masked language modeling (MLM) pre-training methods such as BERT produce excellent results on downstream NLP tasks, but they require a large amount of compute to be effective. These methods corrupt the input by replacing some tokens with [MASK] and then train the model to recover the original tokens.

ELECTRA proposes a more sample-efficient pre-training task called replaced token detection (RTD). The input is corrupted by replacing some of its tokens with plausible alternatives sampled from a small generator. Instead of training a model to predict the original identity of each corrupted token, ELECTRA trains a discriminator that predicts, for every token in the input, whether it was replaced by the generator or not.
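Concretely, the corrupt-then-detect pipeline on the example sentence used in the generator and discriminator figures described below ("the chef cooked the meal") works like this:

| Original input | the | chef | cooked | the | meal |
| --- | --- | --- | --- | --- | --- |
| Generator input (masked) | [MASK] | chef | [MASK] | the | meal |
| Generator output (sampled) | the | chef | ate | the | meal |
| Discriminator label | original | original | replaced | original | original |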


ELECTRA structure

ELECTRA consists of two parts, a generator and a discriminator. Both use the Transformer encoder architecture, but the two differ in size:

Generator

The generator is a small masked language model (usually about 1/4 the size of the discriminator) that uses BERT's classic MLM approach. Concretely:

1. First, randomly select 15% of the tokens and replace them with [MASK]. (BERT's 80% [MASK] / 10% unchanged / 10% random-replacement scheme is dropped; BERT uses it to ease the mismatch between pre-training and fine-tuning, but it is unnecessary in ELECTRA, because ELECTRA fine-tunes only its discriminator, so no such mismatch exists.)

2. Train the generator to predict the masked tokens; its predictions yield the corrupted tokens.

3. The generator's objective is the same as BERT's: restore each masked token to the original token.

As shown in the figure above, the generator randomly masks two tokens, "the" and "cooked". The generator then makes its predictions, producing the corrupted tokens: "the" is predicted correctly, while "cooked" becomes "ate".
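A minimal PyTorch-style sketch of this masking and sampling step (the tensor shapes and the `generator` callable, assumed to return per-position vocabulary logits, are illustrative assumptions, not the official implementation):

```python
import torch

def mask_and_sample(input_ids, generator, mask_token_id, mask_prob=0.15):
    """Mask ~15% of positions (no 80/10/10 split) and let the small
    generator propose replacement tokens for the masked positions."""
    # Choose which positions to mask (special tokens ignored for brevity).
    is_masked = torch.rand(input_ids.shape) < mask_prob

    # Generator input: masked positions become [MASK].
    gen_input = input_ids.clone()
    gen_input[is_masked] = mask_token_id

    # The small MLM generator outputs vocabulary logits for every position.
    logits = generator(gen_input)                      # (batch, seq_len, vocab)

    # Sample plausible tokens from the generator's distribution.
    sampled = torch.distributions.Categorical(logits=logits).sample()

    # Corrupted sequence: only the masked positions take the sampled tokens.
    corrupted = torch.where(is_masked, sampled, input_ids)
    return gen_input, corrupted, is_masked
```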

Discriminator

The discriminator's input is the generator's output after the tokens have been corrupted, e.g. "the chef ate the meal" in the example above. The discriminator's role is to decide, for each input token, whether it is original or replaced. Note: if a token produced by the generator happens to match the original token, that token is still labeled original, as with "the" in the example above. The discriminator therefore performs a binary classification for every token, and the loss is computed over all tokens.
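Continuing the sketch above, the discriminator's labels and loss could be assembled like this (again an illustration under the same assumptions, with `discriminator` assumed to return one logit per position):

```python
import torch
import torch.nn.functional as F

def rtd_labels_and_loss(original_ids, corrupted_ids, discriminator):
    """Build replaced/original labels and the binary RTD loss over ALL tokens."""
    # A token counts as "replaced" only if it differs from the original;
    # if the generator happens to reproduce the original token, the label
    # stays "original", exactly as noted above.
    labels = (corrupted_ids != original_ids).float()   # 1 = replaced, 0 = original

    # One logit per input token from the discriminator.
    logits = discriminator(corrupted_ids)              # (batch, seq_len)

    # Binary cross-entropy over every position, not just the masked 15%.
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    return labels, loss
```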


Model extension

Weight sharing

The generator and the discriminator share weights, but not all parameters are shared: sharing everything would require the two networks to be the same size, so the model shares only the generator's embedding weights.

Why share the embedding weights? The main reason is that the generator is trained with MLM, which predicts each token from its surrounding context and therefore learns the embedding representations very well.
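A minimal sketch of what sharing the token-embedding table looks like (names and sizes are illustrative; the official implementation also deals with the generator being smaller, e.g. via a projection layer, which is omitted here):

```python
import torch.nn as nn

vocab_size, embed_dim = 30522, 256        # illustrative sizes only

# One token-embedding table is created once...
shared_embeddings = nn.Embedding(vocab_size, embed_dim)

# ...and used by both networks, so the same weights receive gradients from
# the generator's MLM loss and from the discriminator's RTD loss.
generator_embeddings = shared_embeddings
discriminator_embeddings = shared_embeddings

assert generator_embeddings.weight is discriminator_embeddings.weight
```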

Generator size selection

In the paper, the authors train the model by jointly training the generator and the discriminator. The paper notes that if the generator is too strong, the discriminator cannot be trained successfully. This is easy to understand: if the generator is very strong, the tokens it predicts are almost all correct, i.e. almost all original tokens, and the discriminator can converge without learning anything, since it only needs to classify every token as 1 (assuming 1 means original).

Therefore, the generator cannot be too large, otherwise it becomes too powerful. Moreover, if it were as large as the discriminator, one training run would be equivalent to training the parameters of two MLMs, and the efficiency gain would be lost.

As can be seen from the figure, the model works best when the generator size is 1/4-1/2 the size of the discriminator.
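For reference, joint training optimizes the summed objective from the paper over the pre-training corpus, with the discriminator loss up-weighted by a factor $\lambda$ (the paper uses $\lambda = 50$):

$$
\min_{\theta_G,\,\theta_D} \sum_{x \in \mathcal{X}} \mathcal{L}_{\text{MLM}}(x, \theta_G) + \lambda\,\mathcal{L}_{\text{Disc}}(x, \theta_D)
$$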

Comparison of training methods

Other training algorithms for ELECTRA were also proposed in the paper, although these ultimately did not improve the results.

Two-Stage ELECTRA:

  1. Train only the generator, using the MLM loss, for n steps.

  2. Initialize the discriminator's weights with the generator's weights, then train only the discriminator, using the discriminator (RTD) loss, for n steps, keeping the generator's weights frozen.

Note that this weight initialization requires the generator and the discriminator to be the same size. We found that without the weight initialization, the discriminator sometimes fails to learn at all, possibly because the generator starts so far ahead of the discriminator. Joint training, on the other hand, naturally provides a curriculum for the discriminator: the generator starts off weak and gets better over the course of training.
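A schematic of this two-stage procedure (a sketch only, reusing the hypothetical mask_and_sample() and rtd_labels_and_loss() helpers from the earlier sketches; the optimizer settings are arbitrary):

```python
import torch
import torch.nn.functional as F

def two_stage_electra(generator, discriminator, batches, n_steps, mask_token_id):
    """Stage 1: generator only (MLM). Stage 2: discriminator only (RTD)."""
    # Stage 1: train only the generator with the MLM loss for n steps.
    opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
    for _, input_ids in zip(range(n_steps), batches):
        gen_input, _, is_masked = mask_and_sample(input_ids, generator, mask_token_id)
        logits = generator(gen_input)          # recomputed here just for clarity
        loss = F.cross_entropy(logits[is_masked], input_ids[is_masked])
        opt_g.zero_grad()
        loss.backward()
        opt_g.step()

    # Initialize the discriminator from the generator's weights (this is why
    # both networks must be the same size here), then freeze the generator.
    discriminator.load_state_dict(generator.state_dict(), strict=False)
    for p in generator.parameters():
        p.requires_grad = False

    # Stage 2: train only the discriminator with the RTD loss for n steps.
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    for _, input_ids in zip(range(n_steps), batches):
        with torch.no_grad():
            _, corrupted, _ = mask_and_sample(input_ids, generator, mask_token_id)
        _, loss = rtd_labels_and_loss(input_ids, corrupted, discriminator)
        opt_d.zero_grad()
        loss.backward()
        opt_d.step()
```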

As the figure shows, downstream task performance improves significantly after switching from the generative to the discriminative objective during two-stage training, but it ultimately falls short of joint training. It can also be seen that although adversarially trained ELECTRA (a GAN-style variant also explored in the paper) beats BERT, it is still not as good as joint training.

Why does joint training make the model work better? We can think of the generator as a question setter and the discriminator as an answerer: over the course of training, the questions become progressively harder and the answerer steadily improves. If instead the questions were already very difficult at the start, the answerer would not be able to learn at all.

In BERT, masking is random, so it can easily happen that the masked tokens are very simple ones. In ELECTRA, by contrast, the corrupted tokens carry some real difficulty rather than being simple masks, which lets the discriminator learn better.

For example, if the input is 一个聪明的模型 ("a clever model"), a random mask might produce 一[MASK]聪明[MASK]模型, which is very easy for the model to fill in. By contrast, 一个[MASK][MASK]的模型 is much harder. Training with such higher-quality masks lets the model learn better.

The discriminator's binary classification, applied on top of the MLM-style corruption, does not need to model the full token distribution at every position, which makes training more efficient.
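The two per-token losses from the paper make this contrast explicit: the generator's MLM loss needs a softmax over the whole vocabulary and only the masked positions $m$ contribute, while the discriminator needs just one sigmoid output $D(\tilde{x}, t)$ per position and all $n$ positions of the corrupted input $\tilde{x}$ contribute:

$$
\mathcal{L}_{\text{MLM}}(x, \theta_G) = \mathbb{E}\Big[\sum_{i \in m} -\log p_G\big(x_i \mid x^{\text{masked}}\big)\Big]
$$

$$
\mathcal{L}_{\text{Disc}}(x, \theta_D) = \mathbb{E}\Big[\sum_{t=1}^{n} -\mathbb{1}\big(\tilde{x}_t = x_t\big)\log D(\tilde{x}, t) - \mathbb{1}\big(\tilde{x}_t \neq x_t\big)\log\big(1 - D(\tilde{x}, t)\big)\Big]
$$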

An analogy: when learning Chinese at school, the teacher, to deepen the students' understanding of the language, would hand out a passage with some words removed (the teacher picks the words deliberately, whereas BERT picks them at random) and ask the students to fill in the blanks from context. The students can usually fill in the blanks quickly from context or common sense; that is MLM. Then the teacher raises the difficulty: given a passage, the students must point out where the wording is inappropriate. That is the pre-training task of ELECTRA's discriminator (RTD).


Analysis of model effectiveness

To better understand where ELECTRA's gains come from, the authors compare a series of other pre-training objectives:

  1. ELECTRA 15%: the discriminator's loss comes only from 15% of the tokens, i.e. the tokens subject to replacement rather than all tokens;

  2. Replace MLM: the same as MLM, except that tokens are replaced with tokens produced by a generator model instead of with [MASK]; this tests how much ELECTRA gains from resolving MLM's pre-training/fine-tuning mismatch;

  3. All-Tokens MLM: like Replace MLM, except that the model predicts the identity of all tokens, not just the masked ones.

First, ELECTRA benefits greatly from computing the loss over all input tokens rather than just a subset: ELECTRA 15% performs much worse than ELECTRA.

Second, BERT's pre-training/fine-tuning mismatch slightly hurts its performance, since Replace MLM performs somewhat better than BERT. BERT already includes a trick to reduce the discrepancy between pre-training and fine-tuning: a masked token is replaced with a random token 10% of the time and kept unchanged 10% of the time. The experimental results show, however, that these simple heuristics are not enough to fully solve the problem.

Finally, All-Tokens MLM closes most of the gap between BERT and ELECTRA.

Overall, these results indicate that ELECTRA's large improvement comes mainly from learning over all tokens, and to a smaller extent from alleviating the pre-training/fine-tuning mismatch.


ELECTRA VS BERT

ELECTRA's innovations:

  1. It proposes a new pre-training framework that combines a generator and a discriminator, yet differs from a GAN;

  2. It replaces the masked language model objective with replaced token detection;

  3. Because a masked language model predicts a token from its surrounding context, it learns contextual information effectively and thus learns good embeddings; weight sharing is therefore used to pass the generator's embedding information to the discriminator;

  4. The discriminator predicts, for every token output by the generator, whether it is original (a binary classification), which updates all of the Transformer's parameters efficiently and speeds up convergence;

  5. ELECTRA trains a small generator jointly with the discriminator, summing their two losses, so that the difficulty the discriminator faces rises gradually and it learns from harder, more plausible tokens;

  6. At fine-tuning time, the generator is discarded and only the discriminator is used (see the sketch below).
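A minimal sketch of what point 6 looks like at fine-tuning time (the `discriminator_encoder` argument and the use of the first token's hidden state are assumptions of this illustration; in practice one loads a pre-trained ELECTRA discriminator checkpoint):

```python
import torch.nn as nn

class ElectraForClassification(nn.Module):
    """Fine-tuning wrapper: the generator is discarded entirely; only the
    pre-trained discriminator encoder is kept, plus a small task head."""

    def __init__(self, discriminator_encoder, hidden_dim, num_labels):
        super().__init__()
        self.encoder = discriminator_encoder          # pre-trained RTD encoder
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids):
        hidden = self.encoder(input_ids)              # (batch, seq_len, hidden_dim)
        return self.classifier(hidden[:, 0])          # first ([CLS]) token
```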

Shortcomings of BERT:

  1. BERT's MLM is not very efficient: only 15% of the tokens are useful for updating the parameters, and the other 85% do not take part in the gradient updates;

  2. BERT suffers from a mismatch between pre-training and fine-tuning, because no [MASK] tokens appear during fine-tuning.


Sources:

https://zhuanlan.zhihu.com/p/118135466?utm_id=0

https://zhuanlan.zhihu.com/p/90494415
