REALM: Retrieval-Augmented Language Model Pre-Training

REALM: Retrieval-Augmented Language Model Pre-Training


在这里插入图片描述

从ELMO、BERT和GPT到如今多种多样的预训练模型的横空出世,pre-training + fine-tune逐渐成为了NLP中建模的新范式,众多研究人员也不断的针对于Transformer和预训练做出改进,其中包括对Transformer本身结构的改进、预训练策略的调整、预训练结合多任务学习……甚至已有文章将其和视觉等其他的领域任务进行融合,同样取得了不错的效果。

本文提出了一种利用检索技术增强的预训练模型REALM,目的是希望通过中间步骤的检索任务来增强模型的学习能力,同时增强模型的可解释性。总所周知,预训练模型两个显著的特点便是:在大规模的文本数据上的无监督训练依靠强大的算力。因此,当模型训练所依赖的数据量极大时,为了使模型更好的学习到其中的知识,一种暴力的方法便是不断的增大模型的容量和提升算力。但即便有条件可以做到这些,模型所学到的东西只是隐式的保存在模型的众多参数中,我们并不能以一种直观的方式明白模型是否真正的学到了,以及它学到了什么。

因此,本文在pre-training和fine-tune两个阶段都增加了一个知识检索(knowledge retrieval)的步骤。

  • 预训练时,模型利用MLM策略进行预测MASK的区域正确的内容前,首先通过检索模型从文档集中检索相关的文档,然后利用检索结果中TOP-K文档中的内容来进行正确的预测。
  • fine-tune阶段,模型同样是先从完整的文档集中进行相应的检索,最后利用检索结果来完成Open-QA任务,从而实现对于模型的fine-tune。

由于检索过程的存在,模型通过retrieve-then-predict的过程提供了一种间接且直观化的方供我们理解模型的学习过程,而且检索模型和预训练模型是共同训练的。当检索模型可以正确的检索到对应的文档时,模型就可以更好的进行预测和完成QA任务;反之如何模型利用检索的结果很好的实现了后续的任务,它就会给检索模型一个正的反馈,从而促使模型不断朝好的方向发展。


对于之前的预训练模型的pre-training和fine-tune两个阶段来说,模型的目标都是希望通过 x x 来预测正确的 y y ,即计算条件概率 p ( y x ) p(y|x) 。其中预训练阶段的 x x 是被MASK策略处理后的文本, y y 为MASK的部分;fine-tune阶段的 x x 为问题(question), y y 为对应的答案(answer)

针对于Open-QA任务而言

而REALM通过retrieve-then-predict的过程将 p ( y x ) p(y|x) 的计算分解为计算 p ( z x ) p(z|x) p ( y x , z ) p(y|x,z) 两部分。首先从知识库中检索相应的文档 z z ,即相当于计算 p ( z x ) p(z|x) ,然后根据 x x 和检索结果 z z 来共同预测 y y ,即计算 p ( y x , z ) p(y|x,z) 。因此,模型整体的流程可表示为: p ( y x ) = z Z p ( y x , z ) p ( z x ) p(y|x) = \sum_{z \in Z} p(y|x,z)p(z|x)

通过上式可以看出,模型整体从判别式转换到了生成式。

如何理解生成模型和判别模型

下面从和 p ( z x ) p(z|x) p ( y x , z ) p(y|x,z) 计算相对应的模型的两个关键部分:知识检索模型(neural knowledge retriever)和知识增强编码器(Knowledge-Augmented Encoder)出发来理解REALM。

在预测阶段之前,模型需要完成从完整的文档集中检索出相关的部分文档。REALM这里使用了一种简单的检索模型 – Dense Inner Product Model来完成相应的工作。首先分别将 x x 和文档集中的文档转换为对应的表示向量,然后通过两者的内积来计算 p ( z x ) p(z|x) ,从而找出和 x x 最为相关的部分文档。 p ( z x ) = exp f ( x , z ) z exp f ( x , z ) f ( x , z ) = Embed input ( x ) T Embed doc ( z ) p(z|x) = \frac{\exp f(x,z)}{\sum_{z'}\exp f(x,z')} \\ f(x,z) = \text{Embed}_{\text{input}}(x)^{T}\text{Embed}_{\text{doc}}(z)

其中
 Embed input  ( x ) = W input  BERT CLS  ( j oin BERT  ( x ) )  Embed  doc  ( z ) = W doc BERT  CLS  (  join  BERT  ( z title  , z body  ) ) \begin{array}{l} {\text { Embed input }(x)=\mathbf{W}_{\text {input }} \operatorname{BERT}_{\text {CLS }}\left(j \operatorname{oin}_{\text {BERT }}(x)\right)} \\ {\quad \text { Embed }_{\text {doc }}(z)=\mathbf{W}_{\text {doc BERT }_{\text {CLS }}\left(\text { join }_{\text {BERT }}\left(z_{\text {title }}, z_{\text {body }}\right)\right)}} \end{array} join B E R T ( x ) = [ C L S ] x [ S E P ] join B E R T ( x 1 , x 2 ) = [ C L S ] x 1 [ S E P ] x 2 [ S E P ] \begin{aligned} \operatorname{join}_{\mathrm{BERT}}(x) &=[\mathrm{CLS}] x[\mathrm{SEP}] \\ \operatorname{join}_{\mathrm{BERT}}\left(x_{1}, x_{2}\right) &=[\mathrm{CLS}] x_{1}[\mathrm{SEP}] x_{2}[\mathrm{SEP}] \end{aligned}

当得到检索结果 z z 后,模型需要使用 z z x x 来预测 y y 。由于pre-training和fine-tune阶段需要解决的任务不同,因此两阶段计算 p ( y z , x ) p(y|z,x) 有所区别。

  • pre-training:这里采用和BERT一致的MLM的策略进行建模,因此模型需要预测每个被 [ M A S K ] [MASK] 标记部分的内容。


    在这里插入图片描述

    具体的计算 p ( y z , x ) p(y|z,x) 可表示为: p ( y z , x ) = j = 1 J x p ( y j z , x ) p ( y j z , x ) exp ( w j BERT MaSK ( j ) ( join B E R T ( x , z b o d y ) ) ) \begin{aligned} p(y | z, x) &=\prod_{j=1}^{J_{x}} p\left(y_{j} | z, x\right) \\ p\left(y_{j} | z, x\right) & \propto \exp \left(w_{j}^{\top} \operatorname{BERT}_{\operatorname{MaSK}(j)}\left(\operatorname{join}_{\mathrm{BERT}}\left(x, z_{\mathrm{body}}\right)\right)\right) \end{aligned}

  • fine-tune:这里解决的Open-QA任务,因此预测 y y 的过程相当于预测答案对应的起始索引(start index)和终止索引(end index)。


    在这里插入图片描述

    计算 p ( y z , x ) p(y|z,x) 可表示为: p ( y z , x ) s S ( z , y ) exp ( M L P ( [ h S T A R T ( s ) ; h E N D ( s ) ] ) ) h S T A R T ( s ) = BERT S T A R T ( s ) ( Join B E R T ( x , z b o d y ) ) h E N D ( s ) = BERT E H D ( s ) ( join B E R T ( x , z b o d y ) ) \begin{aligned} p(y | z, x) & \propto \sum_{s \in S(z, y)} \exp \left(\mathrm{MLP}\left(\left[h_{\mathrm{START}(s)} ; h_{\mathrm{END}(\mathrm{s})}\right]\right)\right) \\ h_{\mathrm{START}(\mathrm{s})} &=\operatorname{BERT}_{\mathrm{START}(\mathrm{s})}\left(\operatorname{Join}_{\mathrm{BERT}}\left(x, z_{\mathrm{body}}\right)\right) \\ h_{\mathrm{END}(\mathrm{s})} &=\operatorname{BERT}_{\mathrm{EHD}(\mathrm{s})}\left(\operatorname{join}_{\mathrm{BERT}}\left(x, z_{\mathrm{body}}\right)\right) \end{aligned}

另外由于训练所使用的文档集规模巨大,在检索阶段的计算量就会特别的大。为了使模型能够顺利的跑起来,作者使用了一些机制来保证:

  • 预先计算好预料库中所有文档的嵌入向量以及构建有效的索引机制
  • 使用表示向量内积来对检索到的文档进行排序
  • 在训练几百步后再异步的更新表示向量所对应的索引

其中索引更新过程又分为两步进行,首先进行参数的更新,然后进行索引的构建。


在这里插入图片描述

为了减少模型训练过程中带来的偏差,作者同样提出了几种策略来进行处理,包括:

  • Salient span masking:即采用和ERNIE类似的方式,通过MASK比较重要的部分来迫使模型学习到更重要的东西

    在这里插入图片描述
  • Null document:当预测的部分不重要时,将检索结果设置为空文档 ϕ \phi 来减少计算量
  • Prohibiting trivial retrievals:避免检索结果和完整文档集一致
  • Initialization:采用更好的初始化方式来期望得到更好的表示向量

实验部分通过在不同的Open-QA数据集上进行实验和已有的预训练模型进行比较证明了REALM的优异性。


在这里插入图片描述

发布了295 篇原创文章 · 获赞 103 · 访问量 20万+

猜你喜欢

转载自blog.csdn.net/Forlogen/article/details/104343229