深層学習入門 (60) リカレント ニューラル ネットワーク - ゲート型リカレント ユニット GRU
序文
核心的な内容はブログリンク1ブログリンク2からです作者をたくさん応援していただければ幸いです
この記事は忘れないための記録用です
リカレント ニューラル ネットワーク - ゲート付きリカレント ユニット GRU
コースウェア
シーケンスに焦点を当てる
すべての観察が同じように重要であるわけではありません。
関連する観察のみを記憶するには、次のことが必要です。
- 従うことができる仕組み(アップデートゲート)
- 忘却の仕組み(リセットゲート)
ドア
隠れた状態の候補
隠れた状態
要約する
R t = σ ( X t W xr + H t − 1 W hr + br ) 、Z t = σ ( X t W xz + H t − 1 W hz + bz ) 、tanh ( X t W xh + ( R t ⊙ H t − 1 ) W hh + bh ) 、H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t \begin{整列}\begin{整列} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{ W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{ t-1} \mathbf{W}_{hz} + \mathbf{b}_z),\\ \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R }_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),\\\mathbf{H}_t = \mathbf{Z} _t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \assignment{\mathbf{H}}_t. \end{整列}\end{整列}Rた=s ( XたW× r+Ht − 1W時間_+br)、Zた=s ( XたWxz _+Ht − 1WHz _+bz)、怪しい( X .)たW×時間+( Rた⊙Ht − 1)Wふーん+bふ)、Hた=Zた⊙Ht − 1+( 1−Zた)⊙H~た.
教科書
「時間による逆伝播」のセクションでは、リカレント ニューラル ネットワークで勾配がどのように計算されるか、および連続する行列積によって勾配が消失または爆発する可能性があるという問題について説明しました。実際のこの勾配異常の重要性について簡単に考えてみましょう。
-
将来のすべての観測を予測するために、早期の観測が非常に重要である状況に遭遇する可能性があります。最初の観測値にチェックサムが含まれており、シーケンスの最後でチェックサムが正しいかどうかを確認することが目標であるという極端なケースを考えてみましょう。この場合、最初の補題の影響が重要です。重要な初期情報をメモリセルに保存するための何らかのメカニズムが必要です。このようなメカニズムがなければ、後続のすべての観測に影響を与えるため、この観測に非常に大きな勾配を割り当てる必要があります。
-
一部のトークンに関連する観測値がない状況が発生する場合があります。たとえば、Web ページのコンテンツに対してセンチメント分析が実行される場合、Web ページによって伝えられるセンチメントとは無関係の補助 HTML コードがいくつか存在する可能性があります。非表示の状態表現でそのようなトークンをスキップする何らかのメカニズムが必要です。
-
シーケンスの部分間に論理的な切れ目がある状況に遭遇する場合があります。たとえば、本の章と章の間、または証券の弱気市場と強気市場の間に移行がある場合があります。この場合、状態の内部表現をリセットする方法があれば便利です。
このような問題を解決するために、学界では多くの方法が提案されています。最も初期の方法の 1 つは「長短期記憶」(長短期記憶、LSTM)です。ゲート反復ユニット (GRU) は、わずかに簡素化されたバリアントで、通常は同等のパフォーマンスを提供し、計算が大幅に高速になります。ゲート反復ユニットの方が簡単なので、それから始めます。
1 ゲートされた隠れ状態
Gated Recurrent Unit と通常の RNN の主な違いは次のとおりです。前者は隠れ状態のゲートをサポートします。これは、モデルには、いつ隠れ状態を更新する必要があるか、いつ隠れ状態をリセットする必要があるかを決定するための特殊なメカニズムがあることを意味します。これらのメカニズムは学習可能であり、上記の問題に対処できます。たとえば、最初のトークンが非常に重要な場合、モデルは最初の観測後に隠れた状態を更新しないことを学習します。同様に、モデルは無関係な何気ない観察をスキップすることも学習できます。最後に、モデルは必要に応じて非表示状態をリセットする方法も学習します。以下では、さまざまな種類のゲートについて詳しく説明します。
1.1 リセットゲートとアップデートゲート
まず、重置门(reset gate)
ゲートを導入して更新します(update gate)
。それらを( 0 , 1 ) (0, 1)として設計します。( 0 ,1 )凸の組み合わせができるように区間内のベクトル。リセット ゲートを使用すると、「まだ覚えておきたい」過去の状態の量を制御でき、更新ゲートを使用すると、古い状態のコピーである新しい状態の数を制御できます。
これらのゲートを構築することから始めます。次の図は、ゲート反復ユニットのリセット ゲートと更新ゲートの入力を示しています。入力は、現在のタイム ステップの入力と前のタイム ステップの隠れ状態によって与えられます。2 つのゲートの出力は、シグモイド活性化関数を使用した 2 つの完全に接続された層によって与えられます。
ゲート反復ユニットの数学的表現を見てみましょう。特定のタイムステップttに対してt、入力がミニバッチであると仮定しますX t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d}バツた∈Rn × d (サンプル数nnn、数値ddd )、最後のタイムステップの隠れ状態はH t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} です。Ht − 1∈Rn × h (隠れユニットの数)。すると、リセットゲートR t ∈ R n × h \mathbf{R}_t \in \mathbb{R}^{n \times h}Rた∈Rn × hおよび更新ゲートZ t ∈ R n × h \mathbf{Z}_t \in \mathbb{R}^{n \times h}Zた∈Rn × hの関数を決定します
。 R t = σ ( X t W xr + H t − 1 W hr + br )、 Z t = σ ( X t W xz + H t − 1 W hz + bz ) 、 \ begin {split}\begin{align} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W } _{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t - 1} \mathbf{W}_{hz} + \mathbf{b}_z)、\end{整列}\end{分割}Rた=s ( XたW× r+Ht − 1W時間_+br)、Zた=s ( XたWxz _+Ht − 1WHz _+bz)、
W xr の場合、W xz ∈ R d × h \mathbf{W}_{xr}, \mathbf{W}_{xz} \mathbb{R}^{d \times h }W× r、Wxz _∈Rd × h和W hr , W hz ∈ R h × h \mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}W時間_、WHz _∈Rh × hは重みパラメータ、br , bz ∈ R 1 × h \mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}br、bz∈R1 × hはバイアスパラメータです。ブロードキャスト メカニズムが合計中にトリガーされることに注意してください。シグモイド関数を使用して、入力値を区間( 0 , 1 ) (0, 1)( 0 ,1 )。
1.2 隠れた状態の候補
次に、ゲートR t \mathbf{R}_tをリセットしましょう。RたとRNN中H t = ϕ (X t W xh + H t − 1 Wh + bh) 。\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{ b}_h)。Hた=ϕ ( XたW×時間+Ht − 1Wふーん+bふ) . .に通常の隠れ状態更新メカニズムを統合し、タイム ステップtttの候选隐状态(candidate hidden state)
H ~ t ∈ R n × h \tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}H~た∈Rn × h
H ~ t = Tanh ( X t W xh + ( R t ⊙ H t − 1 ) W hh + bh ) , \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b} _h)、H~た=怪しい( X .)たW×時間+( Rた⊙Ht − 1)Wふーん+bふ) ,
其中W xh ∈ R d × h \mathbf{W}_{xh} \in \mathbb{R}^{d \times h}W×時間∈Rd × h和W hh ∈ R h × h \mathbf{W}_{hh} \in \mathbb{R}^{h \times h}Wふーん∈Rh × hは重みパラメータ、bh ∈ R 1 × h \mathbf{b}_h \in \mathbb{R}^{1 \times h}bふ∈R1 × hはバイアス項目、記号⊙ \odot⊙はアダマール積 (要素ごとの積) 演算子です。ここでは、tanh 非線形活性化関数を使用して、候補の隠れ状態の値が区間( - 1 , 1 ) (-1, 1)( − 1 、1 )。
与H t = ϕ (X t W xh + H t − 1 W hh + bh) 。\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{ b}_h)。Hた=ϕ ( XたW×時間+Ht − 1Wふーん+bふ) .上の式のR t \mathbf{R}_tRた和H t − 1 \mathbf{H}_{t-1}Ht − 1の要素を乗算すると、以前の状態の影響を軽減できます。ゲートがリセットされるたびにR t \mathbf{R}_tRたin の項が 1 に近い場合、通常の RNN と同様に、通常のリカレント ニューラル ネットワークが復元されます。リセットゲートR t \mathbf{R}_tの場合Rたすべての近い項目は 0 であり、隠れ状態の候補はX t \mathbf{X}_tですバツた入力としての多層パーセプトロンの結果。したがって、既存の非表示状態はすべて ** 重置
** ** デフォルト値になります。
以下の図は、リセット ゲート適用後の計算フローを示しています。
1.4 隠し状態
上記の計算結果は隠れ状態の候補にすぎません。更新ゲートZ t \mathbf{Z}_tを組み合わせる必要があります。Zた効果。このステップでは、新しい隠れ状態H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h} を決定します。Hた∈Rn × h が古い状態H t − 1 \mathbf{H}_{t-1}Ht − 1そして新しい候補状態H ~ t \tilde{\mathbf{H}}_tH~た。ゲートZ t \mathbf{Z}_tを更新しますZたH t − 1 \mathbf{H}_{t-1}でのみ必要ですHt − 1和H ~ t \チルダ{\mathbf{H}}_tH~たこの目標は、それらの間で要素ごとの凸の組み合わせを実行することで達成できます。これにより、ゲート型リカレント ユニットの最終更新式が導かれます: H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t . \mathbf{H}_t = \mathbf{Z}_t \ odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t。Hた=Zた⊙Ht − 1+( 1−Zた)⊙H~た.
ゲートZ t \mathbf{Z}_tZた1 に近い場合、モデルは古い状態のみを保持する傾向があります。この時点で、X t \mathbf{X}_tからバツたの情報は基本的に無視され、依存関係チェーン内のタイム ステップ t を効果的にスキップします。逆に、Z t \mathbf{Z}_tの場合Zた0 に近い場合、新しい隠れ状態H t \mathbf{H}_tHたそれは候補の隠れ状態H ~ t \tilde{\mathbf{H}}_tに近くなります。H~た。これらの設計は、リカレント ニューラル ネットワークにおける勾配消失問題に対処し、タイム ステップ距離が長いシーケンスの依存関係をより適切に捕捉するのに役立ちます。たとえば、サブシーケンス全体のすべてのタイム ステップで更新ゲートが 1 に近い場合、シーケンスの長さに関係なく、シーケンスの開始タイム ステップの古い隠れ状態が簡単に保持され、シーケンスの最後に渡されます。シーケンス。
以下の図は、アップデート ゲートが有効になった後の計算フローを示しています。
要約すると、ゲート付きリカレント ユニットには次の 2 つの顕著な特徴があります。
-
リセット ゲートは、シーケンス内の短期的な依存関係を把握するのに役立ちます。
-
更新ゲートは、シーケンス内の長期的な依存関係を取得するのに役立ちます。
2 ゼロからの実装
ゲート反復ユニット モデルをより深く理解するために、それを最初から実装します。まず、Medium Time Machine データセットを読み取ります。
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
2.1 モデルパラメータの初期化
次のステップは、モデル パラメーターを初期化することです。標準偏差 0.01 のガウス分布から重みを引き出し、バイアス項を 0 に設定します。ハイパーパラメーターはnum_hiddens
隠れユニットの数を定義し、更新ゲート、リセット ゲート、候補隠れ状態、出力層とバイアスに関連するすべての重みをインスタンス化します。
def get_params(vocab_size, num_hiddens, device):
num_inputs = num_outputs = vocab_size
def normal(shape):
return torch.randn(size=shape, device=device)*0.01
def three():
return (normal((num_inputs, num_hiddens)),
normal((num_hiddens, num_hiddens)),
torch.zeros(num_hiddens, device=device))
W_xz, W_hz, b_z = three() # 更新门参数
W_xr, W_hr, b_r = three() # 重置门参数
W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数
W_hq = normal((num_hiddens, num_outputs))
b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
for param in params:
param.requires_grad_(True)
return params
2.2 モデルの定義
次に、非表示状態の初期化関数を定義しますinit_gru_state
。セクション「ゼロからの RNN 実装」で定義された関数と同様にinit_rnn_state
、この関数は(批量大小,隐藏单元个数)
、値がすべて 0 であるshape のテンソルを返します。
def init_gru_state(batch_size, num_hiddens, device):
return (torch.zeros((batch_size, num_hiddens), device=device), )
これで、Gated Recurrent Unit モデルを定義する準備が整いました。モデルのアーキテクチャは、重み更新式がより複雑であることを除いて、基本的な RNN ユニットと同じです。
def gru(inputs, state, params):
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
H, = state
outputs = []
for X in inputs:
Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
H = Z * H + (1 - Z) * H_tilda
Y = H @ W_hq + b_q
outputs.append(Y)
return torch.cat(outputs, dim=0), (H,)
2.3 トレーニングと予測
トレーニングと予測は以前とまったく同じように機能します。トレーニング後、トレーニング セットのパープレキシティと、それぞれ「タイム トラベラー」と「トラベラー」というプレフィックスが付いた予測シーケンスのパープレキシティを出力します。
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
出力:
perplexity 1.3, 28030.1 tokens/sec on cuda:0
time traveller wetheving of my investian of the fromaticalllesp
travellery celaner betareabreart of the three dimensions an
3 簡潔な実装
高レベル API には、前に紹介したすべての構成の詳細が含まれているため、ゲートされた反復ユニット モデルを直接インスタンス化できます。このコードは、Python の代わりにコンパイルされた演算子を使用して、前に説明した詳細の多くを処理するため、はるかに高速に実行されます。
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
出力:
perplexity 1.1, 334788.1 tokens/sec on cuda:0
time traveller with a slight accession ofcheerfulness really thi
travelleryou can show black is white by argument said filby
4 まとめ
-
ゲートされたリカレント ニューラル ネットワークは、タイム ステップ距離が長いシーケンスに対する依存関係をより適切に捕捉できます。
-
リセット ゲートは、シーケンス内の短期的な依存関係をキャプチャするのに役立ちます。
-
更新ゲートは、シーケンス内の長期的な依存関係を取得するのに役立ちます。
-
リセット ゲートが開いている場合、ゲート付きリカレント ユニットには基本的なリカレント ニューラル ネットワークが含まれており、更新ゲートが開いている場合、ゲート付きリカレント ユニットはサブシーケンスをスキップできます。
参考文献
[1] Cho, K.、Van Merriënboer, B.、Bahdanau, D.、Bengio, Y. (2014)。ニューラル機械翻訳の特性について: エンコーダーとデコーダーのアプローチ。arXiv プレプリント arXiv:1409.1259。
[2] Chung, J.、Gulcehre, C.、Cho, K.、Bengio, Y. (2014)。シーケンスモデリングにおけるゲートリカレントニューラルネットワークの経験的評価。arXiv プレプリント arXiv:1412.3555。