03- Implementación de LSTM desde cero

Una introducción a los algoritmos de memoria a largo plazo.

El algoritmo LSTM (memoria larga a corto plazo) es un algoritmo de red neuronal recurrente (RNN) común que se utiliza para procesar datos de secuencia y funciona muy bien al procesar datos de series de tiempo.

La idea principal del algoritmo LSTM es agregar una unidad de memoria basada en RNN, que puede ayudar a la red a recordar estados pasados ​​y actualizarlos cuando sea necesario. Al mismo tiempo, LSTM también controla el flujo de información a través de tres puertas: puerta de olvido, puerta de entrada y puerta de salida . Estas puertas permiten a la red elegir cuándo olvidar estados antiguos, aceptar nuevas entradas y generar el estado actual.

Implementación del segundo algoritmo.

2.1 Paquete de guía

# 从零实现
import torch
from torch import nn
import dltools

2.2 Importar datos de entrenamiento

batch_size, num_steps = 32, 35
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps=num_steps)

2.3 Inicializar los parámetros del modelo

# 初始化模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    
    def three():
        return (normal((num_inputs, num_hiddens)),
               normal((num_hiddens, num_hiddens)),
               torch.zeros(num_hiddens, device=device))
    W_xi, W_hi, b_i = three()  # 输入参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    
    # 输出层
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    # 附加梯度
    params = [W_xi, W_hi, b_i,
              W_xf, W_hf, b_f,
              W_xo, W_ho, b_o,
              W_xc, W_hc, b_c, W_hq, b_q]
    
    for param in params:
        param.requires_grad_(True)
    
    return params

Inicialice los parámetros del modelo LSTM . Primero, las dimensiones de entrada y salida se calculan en función del tamaño del vocabulario y el número de unidades ocultas. A continuación, se define una función normal para generar números aleatorios que obedecen a la distribución normal estándar y multiplicar los números aleatorios por 0,01 para controlar el rango de valores inicial de los parámetros.

A continuación, se define una función interna tres para generar tres parámetros. Cada parámetro es una tupla que contiene la matriz de ponderación y el vector de sesgo entre las unidades de entrada y ocultas. Aquí se utiliza el método de inicialización normal para generar la matriz de peso y el vector de sesgo se inicializa a todos ceros .

Luego, llame a la función interna tres para inicializar los parámetros de entrada, puerta de olvido, puerta de salida y celdas de memoria candidatas, respectivamente .

A continuación, se utiliza la función normal para inicializar los parámetros de la capa de salida, incluida la matriz de peso y el vector de polarización de la unidad oculta y la salida.

Finalmente, coloque todos los parámetros en una lista y establezca require_grad en True, lo que indica que es necesario calcular el gradiente de los parámetros.

Devuelve la lista de parámetros.

2.4 Inicializar estado oculto

# 初始化隐藏状态和记忆元
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
           torch.zeros((batch_size, num_hiddens), device=device))

Inicialice los estados ocultos y las celdas de memoria de la red de memoria a corto plazo (LSTM). Los parámetros de entrada incluyen el tamaño del lote (batch_size), el número de unidades ocultas (num_hiddens) y el dispositivo informático (device).

La función devuelve una tupla que contiene dos tensores, que representan el estado oculto y el elemento de memoria . Estos tensores tienen tamaño (batch_size, num_hiddens) y se inicializan completamente a ceros.

2.5 Definir la estructura LSTM

# 定义 LSTM 主体结构
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i,
     W_xf, W_hf, b_f,
     W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hq, b_q] = params
    
    (H, C) = state
    outputs = []
    
    # 准备开始进行前向传播
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

Entrenamiento de tres modelos

3.1 Código de prueba

# 测试代码
X = torch.arange(10).reshape((2, 5))
num_hiddens = 512
net = dltools.RNNModelScratch(len(vocab), num_hiddens, dltools.try_gpu(), get_lstm_params, init_lstm_state, lstm)
state = net.begin_state(X.shape[0], dltools.try_gpu())
Y, new_state = net(X.to(dltools.try_gpu()), state)

3.2 Formación en línea

# 训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, dltools.try_gpu()
num_epochs, lr = 500, 1
model = dltools.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

Cuatro implementaciones de pytorch

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = dltools.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

Supongo que te gusta

Origin blog.csdn.net/March_A/article/details/132841428
Recomendado
Clasificación