リカレント ニューラル ネットワークを始めるための基本

記事とコードは [Github リポジトリ: https://github.com/timerring/dive-into-AI ] またはパブリック アカウント [AIShareLab]にアーカイブされています。Reply Neural Network Ba​​sicsも入手できます。

リカレントニューラルネットワーク

シーケンスデータ

シーケンスデータは一般的なデータタイプであり、通常、前後のデータは関連しています

例「猫の睡眠時間は1日平均15時間」

言語モデル

言語モデルは、自然言語処理 (NLP、自然言語処理) の重要なテクノロジーです。

NLP では、テキストは離散時系列とみなされます。長さ T のテキストの単語は、W 1 、 W 2 、…、 WT \mathrm{W}_{1}、 \mathrm{~W}_{ となります。 2}、\ldots、\mathrm{W}_{\mathrm{T}}W1 W2WT, その中wt ( 1 ≤ t ≤ T ) \mathrm{w}_{\mathrm{t}}(1 \leq \mathrm{t} \leq \mathrm{T})w( 1tT )は **時間ステップ**t の出力またはラベルです。

言語モデルは、系列確率P ( w 1 , w 2 , … , w T ) \mathrm{P}\left(\mathrm{w}_{1}, \mathrm{w}_{2}, \ を計算します。 mathrm{w}_{2}、\ ldots、\mathrm{w}_{\mathrm{T}}\right)P( w1w2wT) 、猫の睡眠時間は 1 日平均15などです

言語モデルは系列確率を計算します:
P ( w 1 , w 2 , … , w T ) = ∏ t = 1 TP ( wt ∣ w 1 , … , wt − 1 ) \mathrm{P}\left(\mathrm{ w} _{1}, \mathrm{w}_{2}, \ldots, \mathrm{w}_{\mathrm{T}}\right)=\prod_{\mathrm{t}=1}^{ \mathrm {T}} \mathrm{P}\left(\mathrm{w}_{\mathrm{t}} \mid \mathrm{w}_{1}, \ldots, \mathrm{w}_{\ mathrm{ t}-1}\right)P( w1w2wT)=t = 1TP( ww1wt 1)

例: P(私、アット、聞く、クラス) = P(私) * P(アット | 私) * P (聞く | 私、アット) * P (レッスン | 私、アット、聞く)

**コーパス (Corpus)** 内の単語の頻度を数え、上記の確率を取得し、最終的に P(I, in, listens, class) を取得します。

短所: タイムステップ t の単語はステップ t -1 の単語を考慮する必要があり、計算量は t とともに指数関数的に増加します。

RNN - リカレント ニューラル ネットワーク

RNN はシーケンス データ用に開発されたニューラル ネットワーク構造です。その核心は、タイム ステップの増加によるパラメータの急増を回避するためにネットワーク層パラメータを再利用し、履歴情報を記録するために **Hidden State** を導入することで効果的です。データのコンテキスト性に対処します。

**Hidden State (Hidden State)** は、履歴情報を記録し、データのコンテキストを効果的に処理するために使用されます。 **アクティベーション関数は、Tanh を使用して出力値の範囲を (-1, 1) に制限し、値が不正確になるのを防ぎます。指数関数的に変化します。**次の比較が可能です。

RNN は言語モデルを構築し、テキスト生成を実現します。「think」、「want」、「have」、「straight」、「liter」、「machine」というテキスト シーケンスを想定します。

RNN の特徴:

  1. リカレント ニューラル ネットワークの隠れ状態は、現在のタイム ステップまでのシーケンスの履歴情報をキャプチャできます。
  2. RNN モデルのパラメーターの数は時間ステップとともに増加しません。(いつもW hh W_{hh}WふーんW xh W_{xh}W×時間W hq W_{hq}Wああ_

H t = ϕ ( X t W xh + H t − 1 W hh + bh ) O t = H t W hq + bq \begin{aligned} \boldsymbol{H}_{t} & =\phi\left(\ボールドシンボル{X}_{t} \boldsymbol{W}_{xh}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{hh}+\boldsymbol{b}_{h}\right ) \\ \boldsymbol{O}_{t} & =\boldsymbol{H}_{t} \boldsymbol{W}_{hq}+\boldsymbol{b}_{q} \end{aligned}H=ϕ( XW×時間+Ht 1Wふーん+b)=HWああ_+bq

RNN の時間によるバックプロパゲーション(時間によるバックプロパゲーション)

いくつかの経路があり、いくつかのアイテムが追加されます。

便宜上、これらを式 1 ~ 4 と呼びます。

∂ L ∂ W qh = ∑ t = 1 T prod ⁡ ( ∂ L ∂ ot , ∂ ot ∂ W qh ) = ∑ t = 1 T ∂ L ∂ otht ⊤ ∂ L ∂ h T = prod ⁡ ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W qh ⊤ ∂ L ∂ o T ∂ L ∂ ht = prod ⁡ ( ∂ L ∂ ht + 1 , ∂ ht + 1 ∂ ht ) + prod ⁡ ( ∂ L ∂ ot , ∂ ot ∂ ht ) = W hh ⊤ ∂ L ∂ ht + 1 + W qh ⊤ ∂ L ∂ ot ∂ L ∂ ht = ∑ i = t T ( W hh ⊤ ) T − i W qh ⊤ ∂ L ∂ o T + t − i \begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{qh}} & =\sum_{t=1}^{T} \operatorname{prod}\left(\frac{ \partial L}{\partial \boldsymbol{o}_{t}}, \frac{\partial \boldsymbol{o}_{t}}{\partial \boldsymbol{W}_{qh}}\right)= \sum_{t=1}^{T} \frac{\partial L}{\partial \boldsymbol{o}_{t}} \boldsymbol{h}_{t}^{\top} \\ \frac{ \partial L}{\partial \boldsymbol{h}_{T}} & =\operatorname{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_{T}},\frac{\partial \boldsymbol{o}_{T}}{\partial \boldsymbol{h}_{T}}\right)=\boldsymbol{W}_{qh}^{\top} \frac{\部分 L}{\partial \boldsymbol{o}_{T}} \\ \frac{\partial L}{\partial \boldsymbol{h}_{t}} & =\operatorname{prod}\left(\frac {\partial L}{\partial \boldsymbol{h}_{t+1}}、\frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_{t} }\right)+\operatorname{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_{t}},\frac{\partial \boldsymbol{o}_{t}}{\partial \boldsymbol{h}_{t}}\right) \\ & =\boldsymbol{W}_{hh}^{\top} \ frac{\partial L}{\partial \boldsymbol{h}_{t+1}}+\boldsymbol{W}_{qh}^{\top} \frac{\partial L}{\partial \boldsymbol{o }_{t}} \\ \frac{\partial L}{\partial \boldsymbol{h}_{t}} & =\sum_{i=t}^{T}\left(\boldsymbol{W}_ {hh}^{\top}\right)^{Ti} \boldsymbol{W}_{qh}^{\top} \frac{\partial L}{\partial \boldsymbol{o}_{T+ti} \end{整列}∂W _qh _∂L _∂h _T∂L _∂h _∂L _∂h _∂L _=t = 1Tプロッド(∂L _∂W _qh _)=t = 1T∂L _ht=プロッド(T∂L _∂h _TT)=Wああ_T∂L _=プロッド(∂h _t + 1∂L _∂h _∂h _t + 1)+プロッド(∂L _∂h _)=Wふーん∂h _t + 1∂L _+Wああ_∂L _=i = tT( Wふーん)T iWああ_T + t i∂L _

上と同様、T=3。LL は2 番目の式から計算できます。Lh T h_ThT部分的なガイド。

次に、 3 番目の式でLLを計算します。Lht h_thの偏導関数、LLに注意してください。Lh T h_ThTの偏導関数は計算されており、直接取り込むことができます。

次に、類推して、LL を取得します。Lht h_th偏導関数の一般式については、第 4 式を参照してください。

ここでは、4 番目の一般式を使用してLLを計算できます。L はh 1 h_1の場合h1の偏微分は次のようになります。

残りの 2 つのパラメータの計算は、パスが多すぎるため、ここでパスを計算するのは比較的複雑ですが、問題を分解するには逆伝播の考え方を使用するだけで済みます。しかし、結果は依然として非常に複雑です。以下に示すように、W hx W_{hx}を計算します。Wh ×ht h_tの勾配が使用されますhの偏導関数を上記と同時に求め、ht h_tを求めます。h偏導関数にはht + 1 h_{t+1}も含まれます。ht + 1偏微分、再帰的...勾配の計算は時間を超えて進みます。

便宜上、これらを式 5 ~ 6 と呼びます。

∂ L ∂ W hx = ∑ t = 1 T prod ⁡ ( ∂ L ∂ ht , ∂ ht ∂ W hx ) = ∑ t = 1 T ∂ L ∂ htxt ⊤ , ∂ L ∂ W hh = ∑ t = 1 T prod ⁡ ( ∂ L ∂ ht , ∂ ht ∂ W hh ) = ∑ t = 1 T ∂ L ∂ htht − 1 ⊤ 。\begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} & =\sum_{t=1}^{T} \operatorname{prod}\left(\frac{\部分 L}{\partial \boldsymbol{h}_{t}}, \frac{\partial \boldsymbol{h}_{t}}{\partial \boldsymbol{W}_{hx}}\right)=\ sum_{t=1}^{T} \frac{\partial L}{\partial \boldsymbol{h}_{t}} \boldsymbol{x}_{t}^{\top}, \\ \frac{ \partial L}{\partial \boldsymbol{W}_{hh}} & =\sum_{t=1}^{T} \operatorname{prod}\left(\frac{\partial L}{\partial \boldsymbol {h}_{t}}、\frac{\partial \boldsymbol{h}_{t}}{\partial \boldsymbol{W}_{hh}}\right)=\sum_{t=1}^{ T} \frac{\partial L}{\partial \boldsymbol{h}_{t}} \boldsymbol{h}_{t-1}^{\top} 。\end{整列}∂W _h ×∂L _∂W _ふーん∂L _=t = 1Tプロッド(∂h _∂L _∂W _h ×∂h _)=t = 1T∂h _∂L _バツt=t = 1Tプロッド(∂h _∂L _∂W _ふーん∂h _)=t = 1T∂h _∂L _ht 1.

したがって、勾配は時間 t とともに指数関数的に変化し、勾配の消失勾配の爆発が起こりやすいという問題があります(例: W hh W_{hh}Wふーん、式 4 を参照W hh ⊤ {W}_{hh}^{\top}Wふーん電源の問題が関係している場合、W hh W_{hh}Wふーん< 1 は最終的に 0 になる傾向があり、勾配が消失します。W hh W_{hh}の場合Wふーん> 1 は最終的に無限大に達し、勾配爆発を引き起こします)。

GRU - ゲート付きリカレント ユニット

RNN 勾配の消失によって引き起こされる問題を軽減するためにゲートのリカレント ネットワークが導入され、情報の流れを制御するためにゲートの概念が導入され、モデルが長期情報をよりよく記憶し、勾配の消失を軽減できるようになります。 。

  • ゲートのリセット: 忘れるべき情報とは
  • Update Gate : 注意が必要な情報

活性化関数はシグモイドで、値の範囲は (0, 1) です。0 は忘れることを意味し、1 は保持を意味します。ゲートの値の範囲が (0, 1) の間にあることがわかります。

リセット ゲートは、候補隠れ状態の計算プロセス中に、前のタイム ステップの隠れ状態のどの情報を忘れるかに使用されます。

更新ゲートの機能は、現在のタイム ステップの隠れ状態を更新するときに、前のタイム ステップH t − 1 \boldsymbol{H}_{t-1} の隠れ状態を結合することです。Ht 1現在のタイムステップの候補隠れ状態H ~ t \tilde{\boldsymbol{H}}_{\mathrm{t}}H最終的なH t \boldsymbol{H}_{t}を取得しますHR t = σ ( X t W xr + H t − 1 W hr + br ) Z t = σ ( X t W xz + H t − 1 W hz + bz ) \begin{aligned} \boldsymbol{R} _
{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xr}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{hr}+ \boldsymbol{b}_{r}\right) \\ \boldsymbol{Z}_{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xz}+\ボールドシンボル{H}_{t-1} \boldsymbol{W}_{hz}+\boldsymbol{b}_{z}\right) \end{aligned}RZ=p( XW× r+Ht 1W時間_+br)=p( XWxz _+Ht 1WHz _+bz)
隠れた状態の候補

入力は前のタイム ステップの隠れ状態と共同計算されて、隠れ状態の計算に使用される候補隠れ状態が取得されます。ゲートをリセットすると、前のタイム ステップの隠れた状態が選択的に忘れられ、履歴情報がより適切に選択されます。

GRU:
H ~ t = Tanh ⁡ ( X t W xh + ( R t ⊙ H t − 1 ) W hh + bh ) \tilde{\boldsymbol{H}}_{\mathrm{t}}=\tanh \left (\boldsymbol{X}_{\mathrm{t}} \boldsymbol{W}_{\mathrm{xh}}+\left(\boldsymbol{R}_{\mathrm{t}} \odot \boldsymbol{H }_{\mathrm{t}-1}\right) \boldsymbol{W}_{\mathrm{hh}}+\boldsymbol{b}_{\mathrm{h}}\right)H=胡散臭い( XWxh+( RHt 1)Wふーん+b)

元の RNN と比較してみます。
H t = ϕ ( X t W xh + H t − 1 W hh + bh ) \boldsymbol{H}_{\mathrm{t}}=\phi\left(\boldsymbol{X }_ {\mathrm{t}} \boldsymbol{W}_{\mathrm{xh}}+\boldsymbol{H}_{\mathrm{t}-1} \boldsymbol{W}_{\mathrm{hh} }+ \boldsymbol{b}_{\mathrm{h}}\right)H=ϕ( XWxh+Ht 1Wふーん+b)

隠れた状態

隠れ状態は、候補の隠れ状態前のタイム ステップの隠れ状態を組み合わせることによって取得されます。
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t \boldsymbol{H}_{\mathrm{t }}= \boldsymbol{Z}_{\mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}+\left(1-\boldsymbol{Z}_{\mathrm{t }}\ 右) \odot \tilde{\boldsymbol{H}}_{\mathrm{t}}H=ZHt 1+( 1Z)H

GRU の特徴:

ゲート機構はシグモイド活性化関数を採用しており、ゲート値は(0,1)となり、0は忘却、1は保持を意味します。

最初のタイム ステップから t-1 時間まで更新ゲートが 1 に保たれる場合、情報は現在のタイム ステップに効果的に送信されます
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t \boldsymbol{H}_{\mathrm{t}}=\boldsymbol{Z}_{\mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}+\left(1-\boldsymbol{Z}_{\mathrm{t}}\right) \odot \tilde{\boldsymbol{H}}_{ \mathrm{t}}H=ZHt 1+( 1Z)H
リセット ゲート: 最後のタイム ステップの隠れた状態を忘れるために使用されます
H ~ t = Tanh ⁡ ( X t W xh + ( R t ⊙ H t − 1 ) W hh + bh ) \tilde{\boldsymbol{H}}_ { \mathrm{t}}=\tanh \left(\boldsymbol{X}_{\mathrm{t}} \boldsymbol{W}_{\mathrm{xh}}+\left(\boldsymbol{R}_{ \ mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}\right) \boldsymbol{W}_{\mathrm{hh}}+\boldsymbol{b}_{\mathrm{ h }}\右)H=胡散臭い( XWxh+( RHt 1)Wふーん+b)
更新ゲート: 現在のタイム ステップの隠れ状態を更新するために使用されます
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t \boldsymbol{H}_{\mathrm{t}}=\boldsymbol { Z}_{\mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}+\left(1-\boldsymbol{Z}_{\mathrm{t}}\right) \ ドット \チルダ{\boldsymbol{H}}_{\mathrm{t}}H=ZHt 1+( 1Z)H

LSTM — 長短期記憶ネットワーク

LSTM

情報伝達を制御する3つのゲートメモリセルを導入

  • 忘れられた門: 忘れるべき情報とは
  • 入力ゲート: 現在のメモリセルにどのような情報を流す必要があるか
  • 出力ゲート: どのメモリ情報が隠れ状態に流れ込むか
  • メモリセル: 特別な隠れた状態、履歴情報を記憶


I t = σ ( X t W xi + H t − 1 W hi + bi ) F t = σ ( X t W xf + H t − 1 W hf + bf ) O t = σ ( X t W xo + H t − 1 W ho + bo ) \begin{aligned} \boldsymbol{I}_{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xi}+\boldsymbol{ H}_{t-1} \boldsymbol{W}_{hi}+\boldsymbol{b}_{i}\right) \\ \boldsymbol{F}_{t} & =\sigma\left(\boldsymbol {X}_{t} \boldsymbol{W}_{xf}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{hf}+\boldsymbol{b}_{f}\right) \\ \boldsymbol{O}_{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xo}+\boldsymbol{H}_{t-1} \boldsymbol {W}_{ho}+\boldsymbol{b}_{o}\right) \end{aligned}F=p( XW×+Ht 1Wこんにちは+b私は)=p( XW× f+Ht 1Wふふ_+b)=p( XW×+Ht 1Wほら_+bああ)
候補メモリセル

メモリセル:過去の瞬間情報を保存する特別な隠れ状態として理解できます
C ~ t = Tanh ⁡ ( X t W xc + H t − 1 W hc + bc ) \tilde{\boldsymbol{C}}_{\mathrm {t }}=\tanh \left(\boldsymbol{X}_{\mathrm{t}} \boldsymbol{W}_{\mathrm{xc}}+\boldsymbol{H}_{\mathrm{t}- 1} \boldsymbol{W}_{\mathrm{hc}}+\boldsymbol{b}_{\mathrm{c}}\right)C=胡散臭い( XWxc+Ht 1WHC+bc)
メモリセルと隠れ状態

メモリ セルは、候補メモリ セル前のタイム ステップのメモリ セルを組み合わせて取得されます
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t \boldsymbol{C}_{t}=\boldsymbol{F}_{t } \odot \boldsymbol{C}_{t-1}+\boldsymbol{I}_{t} \odot \tilde{\boldsymbol{C}}_{\mathrm{t}}C=FCt 1+C
メモリ セルの情報は出力ゲートによって制御され、隠れ状態
に流れます。 H t = O t ⊙ Tanh ⁡ ( C t ) \boldsymbol{H}_{\mathrm{t}}=\boldsymbol{O}_{\ mathrm{t}} \ odot \tanh \left(\boldsymbol{C}_{\mathrm{t}}\right)H=胡散臭い( C)

要約する

おすすめ

転載: blog.csdn.net/m0_52316372/article/details/131473844