記事とコードは [Github リポジトリ: https://github.com/timerring/dive-into-AI ] またはパブリック アカウント [AIShareLab]にアーカイブされています。Reply Neural Network Basicsも入手できます。
記事ディレクトリ
リカレントニューラルネットワーク
シーケンスデータ
シーケンスデータは一般的なデータタイプであり、通常、前後のデータは関連しています
例「猫の睡眠時間は1日平均15時間」
言語モデル
言語モデルは、自然言語処理 (NLP、自然言語処理) の重要なテクノロジーです。
NLP では、テキストは離散時系列とみなされます。長さ T のテキストの単語は、W 1 、 W 2 、…、 WT \mathrm{W}_{1}、 \mathrm{~W}_{ となります。 2}、\ldots、\mathrm{W}_{\mathrm{T}}W1、 W2、…、WT, その中wt ( 1 ≤ t ≤ T ) \mathrm{w}_{\mathrm{t}}(1 \leq \mathrm{t} \leq \mathrm{T})wた( 1≤t≤T )は **時間ステップ**t の出力またはラベルです。
言語モデルは、系列確率P ( w 1 , w 2 , … , w T ) \mathrm{P}\left(\mathrm{w}_{1}, \mathrm{w}_{2}, \ を計算します。 mathrm{w}_{2}、\ ldots、\mathrm{w}_{\mathrm{T}}\right)P( w1、w2、…、wT) 、猫の睡眠時間は 1 日平均15などです
言語モデルは系列確率を計算します:
P ( w 1 , w 2 , … , w T ) = ∏ t = 1 TP ( wt ∣ w 1 , … , wt − 1 ) \mathrm{P}\left(\mathrm{ w} _{1}, \mathrm{w}_{2}, \ldots, \mathrm{w}_{\mathrm{T}}\right)=\prod_{\mathrm{t}=1}^{ \mathrm {T}} \mathrm{P}\left(\mathrm{w}_{\mathrm{t}} \mid \mathrm{w}_{1}, \ldots, \mathrm{w}_{\ mathrm{ t}-1}\right)P( w1、w2、…、wT)=t = 1∏TP( wた∣w1、…、wt − 1)
例: P(私、アット、聞く、クラス) = P(私) * P(アット | 私) * P (聞く | 私、アット) * P (レッスン | 私、アット、聞く)
**コーパス (Corpus)** 内の単語の頻度を数え、上記の確率を取得し、最終的に P(I, in, listens, class) を取得します。
短所: タイムステップ t の単語はステップ t -1 の単語を考慮する必要があり、計算量は t とともに指数関数的に増加します。
RNN - リカレント ニューラル ネットワーク
RNN はシーケンス データ用に開発されたニューラル ネットワーク構造です。その核心は、タイム ステップの増加によるパラメータの急増を回避するためにネットワーク層パラメータを再利用し、履歴情報を記録するために **Hidden State** を導入することで効果的です。データのコンテキスト性に対処します。
**Hidden State (Hidden State)** は、履歴情報を記録し、データのコンテキストを効果的に処理するために使用されます。 **アクティベーション関数は、Tanh を使用して出力値の範囲を (-1, 1) に制限し、値が不正確になるのを防ぎます。指数関数的に変化します。**次の比較が可能です。
RNN は言語モデルを構築し、テキスト生成を実現します。「think」、「want」、「have」、「straight」、「liter」、「machine」というテキスト シーケンスを想定します。
RNN の特徴:
- リカレント ニューラル ネットワークの隠れ状態は、現在のタイム ステップまでのシーケンスの履歴情報をキャプチャできます。
- RNN モデルのパラメーターの数は時間ステップとともに増加しません。(いつもW hh W_{hh}Wふーん、W xh W_{xh}W×時間、W hq W_{hq}Wああ_)
H t = ϕ ( X t W xh + H t − 1 W hh + bh ) O t = H t W hq + bq \begin{aligned} \boldsymbol{H}_{t} & =\phi\left(\ボールドシンボル{X}_{t} \boldsymbol{W}_{xh}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{hh}+\boldsymbol{b}_{h}\right ) \\ \boldsymbol{O}_{t} & =\boldsymbol{H}_{t} \boldsymbol{W}_{hq}+\boldsymbol{b}_{q} \end{aligned}Hた○た=ϕ( XたW×時間+Ht − 1Wふーん+bふ)=HたWああ_+bq
RNN の時間によるバックプロパゲーション(時間によるバックプロパゲーション)
いくつかの経路があり、いくつかのアイテムが追加されます。
便宜上、これらを式 1 ~ 4 と呼びます。
∂ L ∂ W qh = ∑ t = 1 T prod ( ∂ L ∂ ot , ∂ ot ∂ W qh ) = ∑ t = 1 T ∂ L ∂ otht ⊤ ∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W qh ⊤ ∂ L ∂ o T ∂ L ∂ ht = prod ( ∂ L ∂ ht + 1 , ∂ ht + 1 ∂ ht ) + prod ( ∂ L ∂ ot , ∂ ot ∂ ht ) = W hh ⊤ ∂ L ∂ ht + 1 + W qh ⊤ ∂ L ∂ ot ∂ L ∂ ht = ∑ i = t T ( W hh ⊤ ) T − i W qh ⊤ ∂ L ∂ o T + t − i \begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{qh}} & =\sum_{t=1}^{T} \operatorname{prod}\left(\frac{ \partial L}{\partial \boldsymbol{o}_{t}}, \frac{\partial \boldsymbol{o}_{t}}{\partial \boldsymbol{W}_{qh}}\right)= \sum_{t=1}^{T} \frac{\partial L}{\partial \boldsymbol{o}_{t}} \boldsymbol{h}_{t}^{\top} \\ \frac{ \partial L}{\partial \boldsymbol{h}_{T}} & =\operatorname{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_{T}},\frac{\partial \boldsymbol{o}_{T}}{\partial \boldsymbol{h}_{T}}\right)=\boldsymbol{W}_{qh}^{\top} \frac{\部分 L}{\partial \boldsymbol{o}_{T}} \\ \frac{\partial L}{\partial \boldsymbol{h}_{t}} & =\operatorname{prod}\left(\frac {\partial L}{\partial \boldsymbol{h}_{t+1}}、\frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_{t} }\right)+\operatorname{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_{t}},\frac{\partial \boldsymbol{o}_{t}}{\partial \boldsymbol{h}_{t}}\right) \\ & =\boldsymbol{W}_{hh}^{\top} \ frac{\partial L}{\partial \boldsymbol{h}_{t+1}}+\boldsymbol{W}_{qh}^{\top} \frac{\partial L}{\partial \boldsymbol{o }_{t}} \\ \frac{\partial L}{\partial \boldsymbol{h}_{t}} & =\sum_{i=t}^{T}\left(\boldsymbol{W}_ {hh}^{\top}\right)^{Ti} \boldsymbol{W}_{qh}^{\top} \frac{\partial L}{\partial \boldsymbol{o}_{T+ti} \end{整列}∂W _qh _∂L _∂h _T∂L _∂h _た∂L _∂h _た∂L _=t = 1∑Tプロッド(∂おた∂L _、∂W _qh _∂おた)=t = 1∑T∂おた∂L _ht⊤=プロッド(∂おT∂L _、∂h _T∂おT)=Wああ_⊤∂おT∂L _=プロッド(∂h _t + 1∂L _、∂h _た∂h _t + 1)+プロッド(∂おた∂L _、∂h _た∂おた)=Wふーん⊤∂h _t + 1∂L _+Wああ_⊤∂おた∂L _=i = t∑T( Wふーん⊤)T − iWああ_⊤∂おT + t − i∂L _
上と同様、T=3。LL は2 番目の式から計算できます。Lはh T h_ThT部分的なガイド。
次に、 3 番目の式でLLを計算します。Lはht h_thたの偏導関数、LLに注意してください。Lはh T h_ThTの偏導関数は計算されており、直接取り込むことができます。
次に、類推して、LL を取得します。Lはht h_thた偏導関数の一般式については、第 4 式を参照してください。
ここでは、4 番目の一般式を使用してLLを計算できます。L はh 1 h_1の場合h1の偏微分は次のようになります。
残りの 2 つのパラメータの計算は、パスが多すぎるため、ここでパスを計算するのは比較的複雑ですが、問題を分解するには逆伝播の考え方を使用するだけで済みます。しかし、結果は依然として非常に複雑です。以下に示すように、W hx W_{hx}を計算します。Wh ×ht h_tの勾配が使用されますhたの偏導関数を上記と同時に求め、ht h_tを求めます。hた偏導関数にはht + 1 h_{t+1}も含まれます。ht + 1偏微分、再帰的...勾配の計算は時間を超えて進みます。
便宜上、これらを式 5 ~ 6 と呼びます。
∂ L ∂ W hx = ∑ t = 1 T prod ( ∂ L ∂ ht , ∂ ht ∂ W hx ) = ∑ t = 1 T ∂ L ∂ htxt ⊤ , ∂ L ∂ W hh = ∑ t = 1 T prod ( ∂ L ∂ ht , ∂ ht ∂ W hh ) = ∑ t = 1 T ∂ L ∂ htht − 1 ⊤ 。\begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} & =\sum_{t=1}^{T} \operatorname{prod}\left(\frac{\部分 L}{\partial \boldsymbol{h}_{t}}, \frac{\partial \boldsymbol{h}_{t}}{\partial \boldsymbol{W}_{hx}}\right)=\ sum_{t=1}^{T} \frac{\partial L}{\partial \boldsymbol{h}_{t}} \boldsymbol{x}_{t}^{\top}, \\ \frac{ \partial L}{\partial \boldsymbol{W}_{hh}} & =\sum_{t=1}^{T} \operatorname{prod}\left(\frac{\partial L}{\partial \boldsymbol {h}_{t}}、\frac{\partial \boldsymbol{h}_{t}}{\partial \boldsymbol{W}_{hh}}\right)=\sum_{t=1}^{ T} \frac{\partial L}{\partial \boldsymbol{h}_{t}} \boldsymbol{h}_{t-1}^{\top} 。\end{整列}∂W _h ×∂L _∂W _ふーん∂L _=t = 1∑Tプロッド(∂h _た∂L _、∂W _h ×∂h _た)=t = 1∑T∂h _た∂L _バツt⊤、=t = 1∑Tプロッド(∂h _た∂L _、∂W _ふーん∂h _た)=t = 1∑T∂h _た∂L _ht − 1⊤.
したがって、勾配は時間 t とともに指数関数的に変化し、勾配の消失や勾配の爆発が起こりやすいという問題があります。(例: W hh W_{hh}Wふーん、式 4 を参照W hh ⊤ {W}_{hh}^{\top}Wふーん⊤電源の問題が関係している場合、W hh W_{hh}Wふーん< 1 は最終的に 0 になる傾向があり、勾配が消失します。W hh W_{hh}の場合Wふーん> 1 は最終的に無限大に達し、勾配爆発を引き起こします)。
GRU - ゲート付きリカレント ユニット
RNN 勾配の消失によって引き起こされる問題を軽減するためにゲートのリカレント ネットワークが導入され、情報の流れを制御するためにゲートの概念が導入され、モデルが長期情報をよりよく記憶し、勾配の消失を軽減できるようになります。 。
- ゲートのリセット: 忘れるべき情報とは
- Update Gate : 注意が必要な情報
活性化関数はシグモイドで、値の範囲は (0, 1) です。0 は忘れることを意味し、1 は保持を意味します。ゲートの値の範囲が (0, 1) の間にあることがわかります。
リセット ゲートは、候補隠れ状態の計算プロセス中に、前のタイム ステップの隠れ状態のどの情報を忘れるかに使用されます。
更新ゲートの機能は、現在のタイム ステップの隠れ状態を更新するときに、前のタイム ステップH t − 1 \boldsymbol{H}_{t-1} の隠れ状態を結合することです。Ht − 1。現在のタイムステップの候補隠れ状態H ~ t \tilde{\boldsymbol{H}}_{\mathrm{t}}H~た最終的なH t \boldsymbol{H}_{t}を取得しますHたR t = σ ( X t W xr + H t − 1 W hr + br ) Z t = σ ( X t W xz + H t − 1 W hz + bz ) \begin{aligned} \boldsymbol{R} _
{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xr}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{hr}+ \boldsymbol{b}_{r}\right) \\ \boldsymbol{Z}_{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xz}+\ボールドシンボル{H}_{t-1} \boldsymbol{W}_{hz}+\boldsymbol{b}_{z}\right) \end{aligned}RたZた=p( XたW× r+Ht − 1W時間_+br)=p( XたWxz _+Ht − 1WHz _+bz)
隠れた状態の候補
入力は前のタイム ステップの隠れ状態と共同計算されて、隠れ状態の計算に使用される候補隠れ状態が取得されます。ゲートをリセットすると、前のタイム ステップの隠れた状態が選択的に忘れられ、履歴情報がより適切に選択されます。
GRU:
H ~ t = Tanh ( X t W xh + ( R t ⊙ H t − 1 ) W hh + bh ) \tilde{\boldsymbol{H}}_{\mathrm{t}}=\tanh \left (\boldsymbol{X}_{\mathrm{t}} \boldsymbol{W}_{\mathrm{xh}}+\left(\boldsymbol{R}_{\mathrm{t}} \odot \boldsymbol{H }_{\mathrm{t}-1}\right) \boldsymbol{W}_{\mathrm{hh}}+\boldsymbol{b}_{\mathrm{h}}\right)H~た=胡散臭い( XたWxh+( Rた⊙Ht − 1)Wふーん+bふ)
元の RNN と比較してみます。
H t = ϕ ( X t W xh + H t − 1 W hh + bh ) \boldsymbol{H}_{\mathrm{t}}=\phi\left(\boldsymbol{X }_ {\mathrm{t}} \boldsymbol{W}_{\mathrm{xh}}+\boldsymbol{H}_{\mathrm{t}-1} \boldsymbol{W}_{\mathrm{hh} }+ \boldsymbol{b}_{\mathrm{h}}\right)Hた=ϕ( XたWxh+Ht − 1Wふーん+bふ)
隠れた状態
隠れ状態は、候補の隠れ状態と前のタイム ステップの隠れ状態を組み合わせることによって取得されます。
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t \boldsymbol{H}_{\mathrm{t }}= \boldsymbol{Z}_{\mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}+\left(1-\boldsymbol{Z}_{\mathrm{t }}\ 右) \odot \tilde{\boldsymbol{H}}_{\mathrm{t}}Hた=Zた⊙Ht − 1+( 1−Zた)⊙H~た
GRU の特徴:
ゲート機構はシグモイド活性化関数を採用しており、ゲート値は(0,1)となり、0は忘却、1は保持を意味します。
最初のタイム ステップから t-1 時間まで更新ゲートが 1 に保たれる場合、情報は現在のタイム ステップに効果的に送信されます。
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t \boldsymbol{H}_{\mathrm{t}}=\boldsymbol{Z}_{\mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}+\left(1-\boldsymbol{Z}_{\mathrm{t}}\right) \odot \tilde{\boldsymbol{H}}_{ \mathrm{t}}Hた=Zた⊙Ht − 1+( 1−Zた)⊙H~た
リセット ゲート: 最後のタイム ステップの隠れた状態を忘れるために使用されます
H ~ t = Tanh ( X t W xh + ( R t ⊙ H t − 1 ) W hh + bh ) \tilde{\boldsymbol{H}}_ { \mathrm{t}}=\tanh \left(\boldsymbol{X}_{\mathrm{t}} \boldsymbol{W}_{\mathrm{xh}}+\left(\boldsymbol{R}_{ \ mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}\right) \boldsymbol{W}_{\mathrm{hh}}+\boldsymbol{b}_{\mathrm{ h }}\右)H~た=胡散臭い( XたWxh+( Rた⊙Ht − 1)Wふーん+bふ)
更新ゲート: 現在のタイム ステップの隠れ状態を更新するために使用されます
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t \boldsymbol{H}_{\mathrm{t}}=\boldsymbol { Z}_{\mathrm{t}} \odot \boldsymbol{H}_{\mathrm{t}-1}+\left(1-\boldsymbol{Z}_{\mathrm{t}}\right) \ ドット \チルダ{\boldsymbol{H}}_{\mathrm{t}}Hた=Zた⊙Ht − 1+( 1−Zた)⊙H~た
LSTM — 長短期記憶ネットワーク
LSTM
情報伝達を制御する3つのゲートとメモリセルを導入
- 忘れられた門: 忘れるべき情報とは
- 入力ゲート: 現在のメモリセルにどのような情報を流す必要があるか
- 出力ゲート: どのメモリ情報が隠れ状態に流れ込むか
- メモリセル: 特別な隠れた状態、履歴情報を記憶
I t = σ ( X t W xi + H t − 1 W hi + bi ) F t = σ ( X t W xf + H t − 1 W hf + bf ) O t = σ ( X t W xo + H t − 1 W ho + bo ) \begin{aligned} \boldsymbol{I}_{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xi}+\boldsymbol{ H}_{t-1} \boldsymbol{W}_{hi}+\boldsymbol{b}_{i}\right) \\ \boldsymbol{F}_{t} & =\sigma\left(\boldsymbol {X}_{t} \boldsymbol{W}_{xf}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{hf}+\boldsymbol{b}_{f}\right) \\ \boldsymbol{O}_{t} & =\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{xo}+\boldsymbol{H}_{t-1} \boldsymbol {W}_{ho}+\boldsymbol{b}_{o}\right) \end{aligned}私たFた○た=p( XたW×私+Ht − 1Wこんにちは+b私は)=p( XたW× f+Ht − 1Wふふ_+bふ)=p( XたW× ○+Ht − 1Wほら_+bああ)
候補メモリセル
メモリセル:過去の瞬間情報を保存する特別な隠れ状態として理解できます
C ~ t = Tanh ( X t W xc + H t − 1 W hc + bc ) \tilde{\boldsymbol{C}}_{\mathrm {t }}=\tanh \left(\boldsymbol{X}_{\mathrm{t}} \boldsymbol{W}_{\mathrm{xc}}+\boldsymbol{H}_{\mathrm{t}- 1} \boldsymbol{W}_{\mathrm{hc}}+\boldsymbol{b}_{\mathrm{c}}\right)C~た=胡散臭い( XたWxc+Ht − 1WHC+bc)
メモリセルと隠れ状態
メモリ セルは、候補メモリ セルと前のタイム ステップのメモリ セルを組み合わせて取得されます
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t \boldsymbol{C}_{t}=\boldsymbol{F}_{t } \odot \boldsymbol{C}_{t-1}+\boldsymbol{I}_{t} \odot \tilde{\boldsymbol{C}}_{\mathrm{t}}Cた=Fた⊙Ct − 1+私た⊙C~た
メモリ セルの情報は出力ゲートによって制御され、隠れ状態
に流れます。 H t = O t ⊙ Tanh ( C t ) \boldsymbol{H}_{\mathrm{t}}=\boldsymbol{O}_{\ mathrm{t}} \ odot \tanh \left(\boldsymbol{C}_{\mathrm{t}}\right)Hた=○た⊙胡散臭い( Cた)