长短期记忆(LSTM)系列_LSTM的数据准备(5)——如何配置Keras中截断反向传播预测的输入序列步长

导读:

这篇文章是介绍了BPTT的概念,说明了数据截断的原因和方法,即提高网络的学习效率。以及如何找到最好的截断方法,即利用网格搜索。

文中都是一些概念介绍,这里直接把原文贴上来了。

原文链接:https://machinelearningmastery.com/truncated-backpropagation-through-time-in-keras/

正文翻译如下:

递归神经网络能够在序列预测问题中学习跨越多个时间步的时间依赖性。

像长短期记忆或LSTM网络这样的现代递归神经网络通过称为反向传播时间的反向传播算法的变体进行训练。为了提高具有很长序列的序列预测问题的效率,该算法已被进一步修改,并且被称为截断反向传播。

使用Truncated Backpropagation Through Time训练像LSTM这样的递归神经网络时,一个重要的配置参数决定了使用多少次步进作为输入。也就是说,如何将非常长的输入序列分成子序列以获得最佳性能。

在这篇文章中,您将发现6种不同的方法,您可以使用Keras在Python中使用Truncated Backpropagation Through Time来分割非常长的输入序列以有效地训练递归神经网络。

阅读这篇文章后,你会知道:

  • 通过时间截断反向传播是什么以及如何在Python深度学习库Keras中实现它。
  • 输入时间步数的选择究竟如何影响递归神经网络中的学习。
  • 您可以使用6种不同的技术来分割非常长的序列预测问题,以充分利用截断反向传播时间训练算法。

让我们开始吧。

 

通过时间截断反向传播

反向传播是用于更新神经网络中的权重的训练算法,以便最小化预期输出和给定输入的预测输出之间的误差。

对于序列预测问题,其中观察之间存在顺序依赖性,使用递归神经网络代替经典前馈神经网络。使用Backpropagation算法的变体训练递归神经网络,该算法称为Backpropagation Through Time,简称BPTT。

实际上,BPTT展开递归神经网络并在整个输入序列上向后传播误差,一次一个步骤。然后用累积的梯度更新权重。

对于输入序列很长的问题,BPTT可能很难训练复现神经网络。除了速度之外,在很多时间步长上累积梯度可能导致值缩小到零,或者最终溢出或爆炸的值增长。

BPTT的修改是限制在向后传递上使用的时间步数,并且实际上估计用于更新权重的梯度而不是完全计算它。

这种变化称为截断反向传播时间或TBPTT。

TBPTT训练算法有两个参数:

  • k1:定义正向传递上显示给网络的时间步数。
  • k2:定义在向后传递上估计渐变时要查看的时间步数。

因此,当考虑如何配置训练算法时,我们可以使用符号TBPTT(k1,k2),其中k1 = k2 = n,其中n是经典非截断BPTT的输入序列长度。

TBPTT配置对RNN序列模型的影响

像LSTM这样的现代递归神经网络可以使用它们的内部状态来记住非常长的输入序列。如超过数千次的步骤。

这意味着TBPTT的配置不一定通过选择时间步长来定义您正在优化的网络内存。您可以选择何时将网络的内部状态与用于更新网络权重的机制分开重置。

相反,TBPTT参数的选择会影响网络如何估计用于更新权重的误差梯度。更一般地,配置定义了可以考虑网络来对序列问题建模的时间步数。

我们可以正式说明这样的事情:

yhat(t) = f(X(t), X(t-1), X(t-2), ... X(t-n))

它在概念上与在时间序列问题上训练的多层感知器上的窗口大小相似(但在实践中完全不同),或者与ARIMA等线性时间序列模型的p和q参数相似。TBPTT定义了训练期间模型输入序列的范围。如果yhat是特定时间步的输出,则f(...)是递归神经网络近似的关系,X(t)是特定时间步的观测值。

 

Keras实施TBPTT

Keras深度学习库提供TBPTT的实现,用于训练复现神经网络。

实施比上面列出的一般版本更受限制。

具体地,k1和k2值彼此相等并固定。

  • TBPTT(k1,k2),其中k1 = k2

这是通过训练诸如长短期记忆网络或LSTM之类的递归神经网络所需的固定大小的三维输入来实现的。

LSTM期望输入数据具有尺寸:样本,时间步和特征。

它是此输入格式的第二个维度,即时间步长,用于定义用于序列预测问题的前向和后向传递的时间步数。

因此,在为Keras中的序列预测问题准备输入数据时,必须仔细选择指定的时间步数。

时间步的选择将影响两者:

  • 在前进过程中积累的内部状态。
  • 梯度估计用于更新后向传递的权重。

请注意,默认情况下,每次批处理后都会重置网络的内部状态,但可以通过使用所谓的有状态LSTM并手动调用重置操作来实现对内部状态重置的更明确控制。

有关Keras中有状态LSTM的更多信息,请参阅帖子:

在Keras中准备TBPTT的序列数据

分解序列数据的方式将定义BPTT前向和后向传递中使用的时间步数。

因此,您必须仔细考虑如何准备训练数据。

本节列出了您可以考虑的6种技术。

1.使用数据原样

如果每个序列中的时间步数是适度的,例如几十或几百步,则可以按原样使用输入序列。

已经提出TBPTT的实际限制为约200至400倍步长。

如果序列数据小于或等于此范围,则可以将序列观察值重新整形为输入数据的时间步长。

例如,如果您有一个包含25个时间步长的100个单变量序列的集合,则可以将其重新整形为100个样本,25个时间步长和1个特征或[100,25,1]。

2.朴素的数据拆分

如果您有很长的输入序列,例如数千个时间步长,则可能需要将长输入序列分成多个连续的子序列。

这将需要在Keras中使用有状态LSTM,以便在子序列的输入上保持内部状态,并且仅在真正更充分的输入序列的末尾处重置。

例如,如果您有100个输入序列的50,000个步骤,则每个输入序列可以分为100个子步骤,500个步骤。一个输入序列将变为100个样本,因此100个原始样本将变为10,000。Keras输入的维数为10,000个样本,500个步骤和1个特征或[10000,500,1]。需要注意保持每100个子序列的状态,并在每100个样本之后明确地或通过使用100的批量大小重置内部状态。

将整个序列整齐地划分为固定大小的子序列的划分是优选的。全序列因子(子序列长度)的选择是任意的,因此名称为“天真数据分裂”。

将序列分成子序列不考虑关于用于估计用于更新权重的误差梯度的合适数量的时间步的域信息。

3.特定于域的数据拆分

可能很难知道提供错误梯度的有用估计所需的正确时间步数。

我们可以使用天真的方法(上面)快速获得模型,但模型可能远未优化。

或者,我们可以使用特定于域的信息来估计在学习问题时与模型相关的时间步数。

例如,如果序列问题是回归时间序列,则可能对自相关和部分自相关图的检查可以通知选择时间步数。

如果序列问题是自然语言处理问题,则输入序列可以按句子划分,然后填充到固定长度,或者根据域中的平均句子长度进行划分。

广泛思考并考虑您可以使用哪些特定于您的域的知识将序列拆分为有意义的块。

4.系统数据拆分(例如网格搜索)

您可以系统地为序列预测问题评估一组不同的子序列长度,而不是猜测适当数量的时间步长。

您可以对每个子序列长度执行网格搜索,并采用导致平均性能最佳的模型的配置。

如果您正在考虑这种方法,请注意一些注意事项:

  • 从作为整个序列长度因子的子序列长度开始。
  • 如果探索不是整个序列长度因子的子序列长度,则使用填充和掩蔽。
  • 考虑使用略微过度规定的网络(更多的存储单元和更多的训练时期)来解决问题,以帮助排除网络容量作为实验的限制。
  • 获取每个不同配置的多次运行(例如30)的平均性能。

如果计算资源不是限制,则建议对不同时间步数进行系统调查。

5.使用TBPTT严重依靠内部状态(1,1)

您可以将序列预测问题重新表述为每个时间步一个输入和一个输出。

例如,如果您有100个50次步长的序列,则每个时间步长将成为新的样本。100个样本将变为5,000。三维输入将变为5,000个样本,1个步骤和1个特征,或[5000,1,1]。

同样,这将要求在序列的每个时间步长内保留内部状态,并在每个实际序列的末尾重置(50个样本)。

这将把学习序列预测问题的负担放在递归神经网络的内部状态上。根据问题的类型,它可能不仅仅是网络可以处理的,而且预测问题可能无法学习。

个人经验表明,这种表述可能适用于需要对序列进行记忆的预测问题,但是当结果是过去观察的复杂函数时表现不佳。

6.解耦前向和后向序列长度

Keras深度学习库用于支持通过时间截断反向传播的前向和后向传递的解耦的时间步长数。

本质上,k1参数可以通过输入序列上的时间步数来指定,并且k2参数可以通过LSTM层上的“truncate_gradient”参数来指定。

这不再受支持,但有一些愿望将此功能重新添加到库中。虽然有证据表明它是出于效率原因完成的,但目前还不清楚为何被删除

你可以在Keras探索这种方法。一些想法包括:

  • 安装并使用支持“truncate_gradient”参数的旧版Keras库(大约2015年)。
  • 在Keras中扩展LSTM层实现以支持“truncate_gradient”类型行为。

也许有可用于Keras的第三方扩展支持此行为。

猜你喜欢

转载自blog.csdn.net/yangwohenmai1/article/details/84823341