scheduled sampling

当前image  caption 存在的四个主要问题:
1、指标的提升

2、暴露偏差的累积。这个是指预测的时候,前面预测的结果是错的,后面的错得越来越离谱。

3、损失函数和评级指标没有直接挂钩。

4、只适用于配对的图像和文本。

这篇文章主要用于介绍针对第二个问题的一种解决方法。

序列生成任务的生成目标是在给定源输入的条件下,最大化目标序列的概率。训练时该模型将目标序列中的真实元素作为解码器每一步的输入,然后最大化下一个元素的概率。生成时上一步解码得到的元素被用作当前的输入,然后生成下一个元素。可见这种情况下训练阶段和生成阶段的解码器输入数据的概率分布并不一致。

Scheduled Sampling [1]是一种解决训练和生成时输入数据分布不一致的方法。在训练早期该方法主要使用目标序列中的真实元素作为解码器输入,可以将模型从随机初始化的状态快速引导至一个合理的状态。随着训练的进行,该方法会逐渐更多地使用生成的元素作为解码器输入,以解决数据分布不一致的问题。

标准的序列到序列模型中,如果序列前面生成了错误的元素,后面的输入状态将会收到影响,而该误差会随着生成过程不断向后累积。Scheduled Sampling以一定概率将生成的元素作为解码器输入,这样即使前面生成错误,其训练目标仍然是最大化真实目标序列的概率,模型会朝着正确的方向进行训练。因此这种方式增加了模型的容错能力

|2. 算法简介

Scheduled Sampling主要应用在序列到序列模型的训练阶段,而生成阶段则不需要使用。

训练阶段解码器在最大化第t个元素概率时,标准序列到序列模型使用上一时刻的真实元素yt−1作为输入。设上一时刻生成的元素为gt−1,Scheduled Sampling算法会以一定概率使用gt−1作为解码器输入。

设当前已经训练到了第i个mini-batch,Scheduled Sampling定义了一个概率ϵi控制解码器的输入。ϵi是一个随着i增大而衰减的变量,常见的定义方式有:

  • 线性衰减:ϵi=max(ϵ,k−c∗i),其中ϵ限制ϵi的最小值,k和c控制线性衰减的幅度。
  • 指数衰减:ϵi=ki,其中0<k<1,k控制着指数衰减的幅度。
  • 反向Sigmoid衰减:ϵi=k/(k+exp(i/k)),其中k>1,k同样控制衰减的幅度。

图1给出了这三种方式的衰减曲线,

图1. 线性衰减、指数衰减和

反向Sigmoid衰减的衰减曲线

如图2所示,在解码器的t时刻Scheduled Sampling以概率ϵi使用上一时刻的真实元素yt−1作为解码器输入,以概率1−ϵi使用上一时刻生成的元素gt−1作为解码器输入。从图1可知随着i的增大ϵi会不断减小,解码器将不断倾向于使用生成的元素作为输入,训练阶段和生成阶段的数据分布将变得越来越一致。

图2. Scheduled Sampling选择不同元素作为解码器输入示意图

猜你喜欢

转载自blog.csdn.net/zlrai5895/article/details/84748749