Introducción a NLP (5) Propagación directa de RNN

Serie de tutoriales de introducción a la PNL

Capítulo 1 Representación distribuida del lenguaje natural y las palabras

Capítulo 2 Versión mejorada basada en métodos de conteo

Capítulo 3 Una breve introducción a word2vec

Capítulo 4 RNN



prefacio

Este capítulo simplemente implementará un proceso de propagación hacia adelante de rnn y luego verificará la corrección


1. Un ejemplo simple de propagación directa de RNN

import torch
import torch.nn as nn
batch_size, T = 2, 4 #批大小 序列长度
input_size, hidden_size = 3, 4
input = torch.randn(batch_size, T, input_size)
h_pre = torch.zeros(batch_size, hidden_size) #隐藏层

rnn = nn.RNN(input_size, hidden_size, batch_first=True)
output, hn = rnn(input, h_pre.unsqueeze(0)) #h.shape :D*layers,batch_size, hidden_size

print(output)
print(hn)

def rnn_forward(input,h_pre,w_ih,w_hh,bias_ih,bias_hh):
    batch_size, T, input_size = input.shape
    h_dim = w_ih.shape[0] #根据矩阵相乘的维度关系,wih*xt
    h_out = torch.zeros(batch_size, T, h_dim)

    for t in range(T):
        x = input[:, t, :].unsqueeze(2) #当前序列
        w_ih_batch = w_ih.unsqueeze(0).tile(batch_size,1,1)  #w_ih的维度[batch_size,h_dim ,input_size]
        w_hh_batch = w_hh.unsqueeze(0).tile(batch_size, 1, 1)

        w_t_x = torch.bmm(w_ih_batch, x).squeeze(-1)
        w_t_h = torch.bmm(w_hh_batch, h_pre.unsqueeze(2)).squeeze(-1)
        h_pre = torch.tanh(w_t_x + bias_ih + w_t_h  + bias_hh)

        h_out[:, t, :] = h_pre
    return h_out, h_pre.unsqueeze(0)

my_rnn_output, my_hn = rnn_forward(input, h_pre, rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0,
                                   rnn.bias_hh_l0)

print('=================================')
print(my_rnn_output)
print(my_hn)

No hay dificultad en el código, lo principal es recordar la transformación de algunas dimensiones.
1

Comparándolo con la salida manuscrita,
inserte la descripción de la imagen aquí
se puede ver que esto no es un problema.


Resumir

Lo anterior es un breve resumen del contenido de hoy.

Supongo que te gusta

Origin blog.csdn.net/weixin_39524208/article/details/131949840
Recomendado
Clasificación