"Aprendizaje profundo práctico" -57 red de memoria a corto plazo LSTM

La versión de Mushen de las notas de estudio "Aprendizaje profundo a mano" registra el proceso de aprendizaje; compre libros para obtener contenido detallado.

Enlace de video de la estación B
Enlace de tutorial de código abierto

Red de memoria a corto plazo (LSTM)

Durante mucho tiempo, los modelos de variables ocultas han tenido el problema de preservar la información a largo plazo y faltar entradas a corto plazo. Una de las primeras aproximaciones a este problema fueron las redes de memoria a corto plazo.

El diseño de la red LSTM se inspiró en las puertas lógicas de una computadora.

inserte la descripción de la imagen aquí

Elemento de memoria cerrado:
Olvidar puerta: disminuir el valor hacia 0
Puerta de entrada: decidir si ignorar los datos de entrada
Puerta de salida: decidir si usar el estado oculto

inserte la descripción de la imagen aquí

por 3 con sigmoide sigmoideProcesamiento de capa completamente conectado de la función de activación s i g m o i d :

inserte la descripción de la imagen aquí

Unidades de memoria candidatas, similares a las 3 puertas descritas anteriormente, pero usan la función tanh como función de activación y el rango de valores de la función es (-1, 1):

inserte la descripción de la imagen aquí

Unidad de memoria: puerta de entrada I t I_tItControla cuánto usar desde C ~ t \tilde{C}_tC~tnuevos datos, mientras que la puerta de olvido F i F_iFyoControla cuántos recuerdos pasados ​​C t − 1 C_{t-1} conservarCt - 1Contenido.

Si la puerta de olvido es siempre 1 y la puerta de entrada es siempre 0, entonces la celda de memoria pasada C t − 1 C_{t-1}Ct - 1se guardará con el tiempo y se pasará al paso de tiempo actual. Este diseño se introduce para mitigar el problema del gradiente que desaparece y capturar mejor las dependencias de larga distancia en secuencias.

inserte la descripción de la imagen aquí

Finalmente defina el estado oculto H t H_thtEl cálculo de , que es donde entra en juego la puerta de salida. En las redes LSTM, es simplemente una versión cerrada del tanh del elemento de memoria. Esto asegura que H t H_thtEl valor de siempre está en el intervalo (-1, 1).

Mientras la puerta de salida esté cerca de 1, podemos pasar efectivamente toda la información de la memoria a la parte de predicción, mientras que para la puerta de salida cerca de 0, solo mantenemos toda la información dentro de la celda de memoria sin actualizar el estado oculto.

Parte de la literatura considera que las células de memoria son un tipo especial de estado oculto, que tienen la misma forma que el estado oculto y están diseñadas para registrar información adicional.

inserte la descripción de la imagen aquí

Resumir

Las redes de memoria a corto plazo son modelos autorregresivos de variables latentes típicos con importantes controles de estado. Sin embargo, debido a la dependencia de largo alcance de las secuencias, el costo de entrenar redes LSTM y otros modelos de secuencia (como unidades recurrentes cerradas) es bastante alto. Transformer es su modelo alternativo avanzado.

LSTM puede aliviar la explosión y desaparición del gradiente.

Sólo el estado oculto se pasa a la capa de salida (Y), mientras que la memoria es información enteramente interna.

inserte la descripción de la imagen aquí

aprendizaje práctico

Red de memoria a largo plazo - LSTM

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数

    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    
    return params
# H 和 C 的初始化
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
def lstm(inputs, state, params):

    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    
    outputs = []
    for X in inputs:

        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i) # 输入门
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f) # 遗忘门
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o) # 输出门
        
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c) # 候选记忆单元
        
        C = F * C + I * C_tilda # 记忆元
        
        H = O * torch.tanh(C) # 隐状态

        Y = (H @ W_hq) + b_q # 输出
        outputs.append(Y)
        
    return torch.cat(outputs, dim=0), (H, C)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

inserte la descripción de la imagen aquí

Implementación concisa

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/cjw838982809/article/details/132579946
Recomendado
Clasificación