LSTM計算プロセスの控除
(1)LSTMの概要
LSTMは、一般的に使用されているリカレントニューラルネットワーク(つまり、LSTMcellを使用したリカレントネットワーク)であり、非常に便利です。ここでは、LSTMの計算プロセスを紙とコードと組み合わせて紹介します。
(2)LSTMノート
最初に写真を入れてください(ダニエルのメモから取られました)
(1)Ct-1-> Ct(一番上の水平線、セルの状態)は小さな線形変換しか受けていないので、情報が逆方向に送信されるのに便利です、それが長期シーケンスの理由を記憶することができます。
(2)状態変化はゲート情報によって制御されます。図から、乗算(ゲートを忘れる-元の状態が残っていることを示します)と加算(入力ゲート-新しい情報の量を示します)を見ることができます。の状態に追加されました)。
(3)忘却ベクトルの計算(Ht-1は最後のタイムスタンプの隠れ層であり、sigmodは0-1を出力します)
(4)新しい状態情報の計算(これは入力ゲートベクトルであり、離脱を示します)と新しい情報の残し)
(5)出力情報の計算(Otは出力ゲートベクトル情報です)
セルの状態をtanhに入れます(値を-1から1の間にプッシュするため)
(3)LSTMバリアント
(1)状態を覗くことができるLSTM
(2)入力ゲートと忘却ゲートを接続します
(3)ゲート付き回帰ユニット、またはGRUには状態がなく、非表示になっているだけです。
(4)LSTMコード
上記の原理から、3つの制御ゲートは学習するために3つの変数に加えて、保存された状態、合計4つの変数を学習する必要があることがわかります。
上記の式[ht-1、xt]はデータの2つの部分を使用するため、2つのカーネルが定義され、カーネルはxtのマッピングに使用され、recurrent_kernelはht-1の処理に使用されます。
最後に、ドア情報i、f、oがそれぞれ入力、忘却、出力されます。そして、ゲート情報と[ht-1、xt]から、メモリ状態であるセル状態cを取得します。
最後に、出力ゲートを使用して、セルh = o * self.activation©の出力を取得します。
h_tm1 = states[0]
c_tm1 = states[1]
if self.implementation == 1:
if 0 < self.dropout < 1.:
inputs_i = inputs * dp_mask[0]
inputs_f = inputs * dp_mask[1]
inputs_c = inputs * dp_mask[2]
inputs_o = inputs * dp_mask[3]
else:
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs
x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)
if self.use_bias:
x_i = K.bias_add(x_i, self.bias_i)
x_f = K.bias_add(x_f, self.bias_f)
x_c = K.bias_add(x_c, self.bias_c)
x_o = K.bias_add(x_o, self.bias_o)
if 0 < self.recurrent_dropout < 1.:
h_tm1_i = h_tm1 * rec_dp_mask[0]
h_tm1_f = h_tm1 * rec_dp_mask[1]
h_tm1_c = h_tm1 * rec_dp_mask[2]
h_tm1_o = h_tm1 * rec_dp_mask[3]
else:
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o,self.recurrent_kernel_o))