NLP 入門チュートリアル シリーズ
記事ディレクトリ
序文
この章では、rnn の順伝播プロセスを単純に実装し、その正しさを検証します。
1. RNN 順伝播の簡単な例
import torch
import torch.nn as nn
batch_size, T = 2, 4 #批大小 序列长度
input_size, hidden_size = 3, 4
input = torch.randn(batch_size, T, input_size)
h_pre = torch.zeros(batch_size, hidden_size) #隐藏层
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
output, hn = rnn(input, h_pre.unsqueeze(0)) #h.shape :D*layers,batch_size, hidden_size
print(output)
print(hn)
def rnn_forward(input,h_pre,w_ih,w_hh,bias_ih,bias_hh):
batch_size, T, input_size = input.shape
h_dim = w_ih.shape[0] #根据矩阵相乘的维度关系,wih*xt
h_out = torch.zeros(batch_size, T, h_dim)
for t in range(T):
x = input[:, t, :].unsqueeze(2) #当前序列
w_ih_batch = w_ih.unsqueeze(0).tile(batch_size,1,1) #w_ih的维度[batch_size,h_dim ,input_size]
w_hh_batch = w_hh.unsqueeze(0).tile(batch_size, 1, 1)
w_t_x = torch.bmm(w_ih_batch, x).squeeze(-1)
w_t_h = torch.bmm(w_hh_batch, h_pre.unsqueeze(2)).squeeze(-1)
h_pre = torch.tanh(w_t_x + bias_ih + w_t_h + bias_hh)
h_out[:, t, :] = h_pre
return h_out, h_pre.unsqueeze(0)
my_rnn_output, my_hn = rnn_forward(input, h_pre, rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0,
rnn.bias_hh_l0)
print('=================================')
print(my_rnn_output)
print(my_hn)
コードに難しいことはありません。主なことは、いくつかの次元の変換を覚えておくことです。
手書き出力と比較すると
問題ないことが分かります。
要約する
以上が本日の内容を簡単にまとめたものです。