Introducción al aprendizaje profundo (61) Red neuronal recurrente - Red de memoria a largo plazo a corto plazo LSTM

prefacio

El contenido principal proviene del enlace 1 del blog. Enlace 2 del blog. Espero que puedas apoyar mucho al autor.
Este artículo se usa para registros para evitar el olvido.

Red neuronal recurrente - red de memoria a corto plazo LSTM

cursos

red de memoria a corto plazo largo

Forget Gate: Disminuye el valor hacia 0
Input Gate: Decide si ignorar los datos de entrada
Output Gate: Decide si usar el estado oculto

Puerta

inserte la descripción de la imagen aquí

unidad de memoria candidata

inserte la descripción de la imagen aquí

unidad de memoria

inserte la descripción de la imagen aquí

estado oculto

inserte la descripción de la imagen aquí

Resumir

yo t = σ ( X t W xi + H t - 1 W hola + bi ) , F t = σ ( X t W xf + H t - 1 W hf + bf ) , O t = σ ( X t W xo + H t − 1 W ho + bo ) , \begin{alineado}\begin{alineado} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{ H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W} _{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{ X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),\\ \end{alineado}\ fin {alineado}ItFtOt=s ( XtWx yo+Ht 1Whola+byo) ,=s ( XtWx f+Ht 1Wh f+bf) ,=s ( XtWxo _+Ht 1Whola _+bo) ,
C ~ t = tanh ( X t W xc + H t - 1 W hc + antes de Cristo ) , C t = F t ⊙ C t - 1 + yo t ⊙ C ~ t . H t = O t ⊙ tanh ⁡ ( C t ) . \tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_ {hc} + \mathbf{b}_c),\\ \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \ tilde{\mathbf{C}}_t.\\ \mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).C~t=sospechoso ( X.)tWx c+Ht 1Wh c+bdo) ,Ct=FtCt 1+ItC~t.Ht=Otsospechoso ( _t) .
inserte la descripción de la imagen aquí

Libro de texto

Durante mucho tiempo, el modelo de variable latente tiene el problema de la conservación de la información a largo plazo y la falta de entrada a corto plazo. Una de las primeras aproximaciones a este problema fue la memoria a corto plazo (LSTM). Tiene muchas de las mismas propiedades que una unidad recurrente cerrada. Curiosamente, el diseño de la red de memoria a corto plazo a largo plazo es un poco más complicado que el de la unidad recurrente cerrada, pero nació casi 20 años antes que la unidad recurrente cerrada.

1 celda de memoria cerrada

Se puede decir que el diseño de la red de memoria a corto plazo está inspirado en las puertas lógicas de la computadora. Se introducen las redes de memoria a corto plazo largo 记忆元(memory cell), o simplemente se les llama 单元(cell). Algunas literaturas creen que las celdas de memoria son un tipo especial de estado oculto, tienen la misma forma que el estado oculto y su propósito de diseño es registrar información adicional. Para controlar las celdas de memoria, necesitamos muchas puertas. Una de las puertas se usa para generar entradas desde la celda, que llamaremos 输出门(output gate). Se usa otra puerta para decidir cuándo leer datos en la celda, a la que llamaremos 输入门(input gate). También necesitamos un mecanismo para restablecer el contenido de la celda, que es administrado 遗忘门(forget gate)por la misma motivación de diseño que la celda recurrente cerrada, que puede decidir cuándo memorizar o ignorar la entrada en el estado oculto a través de un mecanismo dedicado. Veamos cómo funciona esto en la práctica.

1.1 Puerta de entrada, puerta de olvido y puerta de salida

Al igual que en la unidad recurrente cerrada, la entrada del paso de tiempo actual y el estado oculto del paso de tiempo anterior se introducen en las puertas de la red LSTM como datos, como se muestra en la figura. Son procesados ​​por tres capas totalmente conectadas con funciones de activación sigmoide para calcular los valores de las puertas de entrada, olvido y salida. Por lo tanto, los valores de las tres puertas están en ( 0 , 1 ) (0, 1)( 0 ,1 ) dentro del rango.
inserte la descripción de la imagen aquí
Refinemos la expresión matemática de la red de memoria a corto plazo. Supongamos que hayhhh unidades ocultas, tamaño de lotennn , el número de entrada esddre . Por lo tanto, la entrada esX t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d}XtRn × d , el estado oculto del paso de tiempo anterior esH t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}Ht 1Rn × h . En consecuencia, el paso de tiempottLa puerta de t se define como sigue: La puerta de entrada esI t ∈ R n × h \mathbf{I}_t \in \mathbb{R}^{n \times h}ItRn × h , la puerta de olvido esF t ∈ R n × h \mathbf{F}_t \in \mathbb{R}^{n \times h}FtRn × h , la puerta de salida esO t ∈ R n × h \mathbf{O}_t \in \mathbb{R}^{n \times h}OtRPara n × h , las siguientes formas funcionales son:
I t = σ ( X t W xi + H t − 1 W hi + bi ) , F t = σ ( X t W xf + H t − 1 W hf + bf ) O t = σ ( X t W xo + H t − 1 W ho + bo ) , \begin{split}\begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \ mathbf{ W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma (\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf {O }_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b }_o ), \end{alineado}\end{dividido}ItFtOt=s ( XtWx yo+Ht 1Whola+byo) ,=s ( XtWx f+Ht 1Wh f+bf) ,=s ( XtWxo _+Ht 1Whola _+bo) ,
其中W xi , W xf , W xo ∈ R d × h \mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R} ^{d\veces h}Wx yo,Wx f,Wxo _Rd × hW hi , W hf , Who ho ∈ R h × h \mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb {R}^{h\veces h}Whola,Wh f,Whola _Rh × h是权重parámetro,bi , bf , bo ∈ R 1 × h \mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h }byo,bf,boR1 × h es el parámetro de sesgo.

1.2 Elementos de memoria candidatos

Dado que no se ha especificado la operación de varias puertas, primero introduzcamos 候选记忆元(candidate memory cell) C ~ t ∈ R n × h \tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h}C~tRn × h . Su cálculo es similar al cálculo de las tres puertas descritas anteriormente, pero utiliza la función tanh como función de activación, y el rango de valores de la función es( − 1 , 1 ) (-1, 1)( -1 , _1 ) . Lo siguiente deriva la ecuación en el paso de tiempo:
C ~ t = tanh ( X t W xc + H t − 1 W hc + bc ) , \tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{ X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),C~t=sospechoso ( X.)tWx c+Ht 1Wh c+bdo) ,
其中W xc ∈ R d × h \mathbf{W}_{xc} \in \mathbb{R}^{d \times h}Wx cRd × hW hc ∈ R h × h \mathbf{W}_{hc} \in \mathbb{R}^{h \times h}Wh cRh × h es el parámetro de peso,bc ∈ R 1 × h \mathbf{b}_c \in \mathbb{R}^{1 \times h}bdoR1 × h es el parámetro de sesgo.

Las celdas de memoria candidatas se muestran en la figura.
inserte la descripción de la imagen aquí

1.3 Memorias

En una unidad recurrente cerrada, hay un mecanismo para controlar la entrada y el olvido (u omisión). De manera similar, en las redes LSTM, también hay dos puertas para este propósito: Puerta de entrada I t \mathbf{I}_tItEl control toma cuánto de C ~ t \tilde{\mathbf{C}}_tC~tnuevos datos y la puerta de olvido F t \mathbf{F}_tFtControla cuántos elementos de la memoria pasada C t − 1 ∈ R n × h \mathbf{C}_{t-1} \in \mathbb{R}^{n \times h}Ct 1REl contenido de n × h . Usando la multiplicación por elementos, obtenemos:
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t . \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t - 1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.Ct=FtCt 1+ItC~t.Si
la puerta de olvido siempre es 1 y la puerta de entrada siempre es 0, entonces el elemento de memoria pasadaC t − 1 \mathbf{C}_{t-1}Ct 1se guardará con el tiempo y pasará al paso de tiempo actual. Este diseño se presenta para aliviar el problema del gradiente de fuga y capturar mejor las dependencias de larga distancia en las secuencias.

De esta forma, obtenemos el diagrama de flujo del cálculo del elemento de memoria, como se muestra en la figura
inserte la descripción de la imagen aquí

1.4 Estado oculto

Finalmente, necesitamos definir cómo calcular el estado oculto H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h}HtRn × h , y aquí es donde entra en juego la puerta de salida. En las redes LSTM, es simplemente una versión cerrada del tanh del elemento de memoria. Esto asegura queH t \mathbf{H}_tHtEl valor de está siempre en el intervalo ( − 1 , 1 ) (-1, 1)( -1 , _1 )内:
H t = O t ⊙ tanh ⁡ ( C t ) . \mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).Ht=Otsospechoso ( _t) Siempre que la puerta de salida esté cerca de
1, podemos pasar efectivamente toda la información de la memoria a la parte de predicción, mientras que para la puerta de salida cerca de 0, solo mantenemos toda la información dentro de la celda de memoria sin actualizar el estado oculto.

La siguiente figura proporciona una representación gráfica del flujo de datos
inserte la descripción de la imagen aquí

2 Implementación desde cero

Ahora, implementamos la red LSTM desde cero. Comenzamos cargando el conjunto de datos de Time Machine.

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

2.1 Inicializar parámetros del modelo

A continuación, necesitamos definir e inicializar los parámetros del modelo. Como se mencionó anteriormente, los hiperparámetros num_hiddensdefinen el número de unidades ocultas. Inicializamos los pesos de acuerdo con una distribución gaussiana con desviación estándar de 0,01 y establecemos el término de sesgo en 0.

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

2.2 Definir el modelo

En la función de inicialización, el estado oculto de la red de memoria a corto plazo a largo plazo debe devolver un 额外elemento de memoria, cuyo valor es 0 y la forma es (批量大小,隐藏单元数). Por lo tanto, obtenemos la siguiente inicialización de estado.

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

La definición del modelo real es la misma que discutimos antes: se proporcionan tres puertas y una celda de memoria adicional. Tenga en cuenta que solo el estado oculto se pasa a la capa de salida y el elemento de memoria C t \mathbf{C}_tCtNo está directamente involucrado en los cálculos de salida.

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

2.3 Entrenamiento y predicción

RNNModelScratchEntrenemos una red LSTM instanciando la clase presentada en la sección Implementación de RNN , tal como lo hicimos en la sección GRU.

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

producción

perplexity 1.4, 22462.5 tokens/sec on cuda:0
time traveller for the brow henint it aneles a overrecured aback
travellerifilby freenotin s dof nous be and the filing and

3 Implementación concisa

Usando la API de alto nivel, podemos instanciar directamente el modelo LSTM. La API de alto nivel encapsula todos los detalles de configuración descritos anteriormente. Este código se ejecuta mucho más rápido porque usa operadores compilados en lugar de Python para manejar muchos de los detalles explicados anteriormente.

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

producción:

perplexity 1.1, 330289.2 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

Las redes de memoria a corto plazo largo son modelos autorregresivos de variable latente típicos con importantes controles de estado. Se han propuesto muchas variantes a lo largo de los años, por ejemplo, multicapas, conexiones residuales, diferentes tipos de regularización. Sin embargo, el costo de entrenar redes LSTM y otros modelos de secuencia, como unidades recurrentes cerradas, es bastante alto debido a las dependencias de largo alcance de las secuencias. A continuación, cubriremos modelos alternativos más avanzados como el Transformer.

4 Resumen

  • Las redes LSTM tienen tres tipos de puertas: puertas de entrada, puertas de olvido y puertas de salida.

  • La salida de la capa oculta de la red LSTM consta de "estados ocultos" y "elementos de memoria". Solo el estado oculto se pasa a la capa de salida, mientras que la memoria es información completamente interna.

  • Las redes de memoria a corto plazo pueden aliviar la desaparición de gradientes y la explosión de gradientes.

referencias

[1] Hochreiter, S. y Schmidhuber, J. (1997). Memoria a corto plazo. Computación neuronal, 9(8), 1735-1780.

Supongo que te gusta

Origin blog.csdn.net/qq_52358603/article/details/128376487
Recomendado
Clasificación