Una introducción a los algoritmos de memoria a largo plazo.
El algoritmo LSTM (memoria larga a corto plazo) es un algoritmo de red neuronal recurrente (RNN) común que se utiliza para procesar datos de secuencia y funciona muy bien al procesar datos de series de tiempo.
La idea principal del algoritmo LSTM es agregar una unidad de memoria basada en RNN, que puede ayudar a la red a recordar estados pasados y actualizarlos cuando sea necesario. Al mismo tiempo, LSTM también controla el flujo de información a través de tres puertas: puerta de olvido, puerta de entrada y puerta de salida . Estas puertas permiten a la red elegir cuándo olvidar estados antiguos, aceptar nuevas entradas y generar el estado actual.
Implementación del segundo algoritmo.
2.1 Paquete de guía
# 从零实现
import torch
from torch import nn
import dltools
2.2 Importar datos de entrenamiento
batch_size, num_steps = 32, 35
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps=num_steps)
2.3 Inicializar los parámetros del modelo
# 初始化模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
num_inputs = num_outputs = vocab_size
def normal(shape):
return torch.randn(size=shape, device=device) * 0.01
def three():
return (normal((num_inputs, num_hiddens)),
normal((num_hiddens, num_hiddens)),
torch.zeros(num_hiddens, device=device))
W_xi, W_hi, b_i = three() # 输入参数
W_xf, W_hf, b_f = three() # 遗忘门参数
W_xo, W_ho, b_o = three() # 输出门参数
W_xc, W_hc, b_c = three() # 候选记忆元参数
# 输出层
W_hq = normal((num_hiddens, num_outputs))
b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
params = [W_xi, W_hi, b_i,
W_xf, W_hf, b_f,
W_xo, W_ho, b_o,
W_xc, W_hc, b_c, W_hq, b_q]
for param in params:
param.requires_grad_(True)
return params
Inicialice los parámetros del modelo LSTM . Primero, las dimensiones de entrada y salida se calculan en función del tamaño del vocabulario y el número de unidades ocultas. A continuación, se define una función normal para generar números aleatorios que obedecen a la distribución normal estándar y multiplicar los números aleatorios por 0,01 para controlar el rango de valores inicial de los parámetros.
A continuación, se define una función interna tres para generar tres parámetros. Cada parámetro es una tupla que contiene la matriz de ponderación y el vector de sesgo entre las unidades de entrada y ocultas. Aquí se utiliza el método de inicialización normal para generar la matriz de peso y el vector de sesgo se inicializa a todos ceros .
Luego, llame a la función interna tres para inicializar los parámetros de entrada, puerta de olvido, puerta de salida y celdas de memoria candidatas, respectivamente .
A continuación, se utiliza la función normal para inicializar los parámetros de la capa de salida, incluida la matriz de peso y el vector de polarización de la unidad oculta y la salida.
Finalmente, coloque todos los parámetros en una lista y establezca require_grad en True, lo que indica que es necesario calcular el gradiente de los parámetros.
Devuelve la lista de parámetros.
2.4 Inicializar estado oculto
# 初始化隐藏状态和记忆元
def init_lstm_state(batch_size, num_hiddens, device):
return (torch.zeros((batch_size, num_hiddens), device=device),
torch.zeros((batch_size, num_hiddens), device=device))
Inicialice los estados ocultos y las celdas de memoria de la red de memoria a corto plazo (LSTM). Los parámetros de entrada incluyen el tamaño del lote (batch_size), el número de unidades ocultas (num_hiddens) y el dispositivo informático (device).
La función devuelve una tupla que contiene dos tensores, que representan el estado oculto y el elemento de memoria . Estos tensores tienen tamaño (batch_size, num_hiddens) y se inicializan completamente a ceros.
2.5 Definir la estructura LSTM
# 定义 LSTM 主体结构
def lstm(inputs, state, params):
[W_xi, W_hi, b_i,
W_xf, W_hf, b_f,
W_xo, W_ho, b_o,
W_xc, W_hc, b_c, W_hq, b_q] = params
(H, C) = state
outputs = []
# 准备开始进行前向传播
for X in inputs:
I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
C = F * C + I * C_tilda
H = O * torch.tanh(C)
Y = (H @ W_hq) + b_q
outputs.append(Y)
return torch.cat(outputs, dim=0), (H, C)
Entrenamiento de tres modelos
3.1 Código de prueba
# 测试代码
X = torch.arange(10).reshape((2, 5))
num_hiddens = 512
net = dltools.RNNModelScratch(len(vocab), num_hiddens, dltools.try_gpu(), get_lstm_params, init_lstm_state, lstm)
state = net.begin_state(X.shape[0], dltools.try_gpu())
Y, new_state = net(X.to(dltools.try_gpu()), state)
3.2 Formación en línea
# 训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, dltools.try_gpu()
num_epochs, lr = 500, 1
model = dltools.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
Cuatro implementaciones de pytorch
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = dltools.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)