"Hands-on deep learning"-57 long short-term memory network LSTM

Study notes for Mu Li's (Mushen) "Dive into Deep Learning" course, recording the learning process. Please buy the book for the detailed content.

Bilibili video link
Open-source tutorial link

Long Short-Term Memory Network (LSTM)

For a long time, latent variable models have struggled with the problems of preserving long-term information and skipping over short-term inputs. One of the earliest approaches to address this was the long short-term memory (LSTM) network.

The design of the LSTM network was inspired by the logic gates of a computer.


Gated memory cell:
Forget gate: shrinks the value toward 0
Input gate: decides whether to ignore the input data
Output gate: decides whether to use the hidden state


The three gates are each computed by a fully connected layer with a sigmoid activation function:

$$I_t = \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i)$$
$$F_t = \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f)$$
$$O_t = \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o)$$

The candidate memory cell is computed like the three gates above, but uses tanh as the activation function, so its values lie in the range (-1, 1):

$$\tilde{C}_t = \tanh(X_t W_{xc} + H_{t-1} W_{hc} + b_c)$$
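As a quick sanity check, here is a minimal sketch (toy shapes and random weights chosen only for illustration, not the book's code) showing that sigmoid keeps the gate values in (0, 1) while tanh keeps the candidate memory cell in (-1, 1):

import torch

# Toy sizes (assumed for illustration): batch of 2, 4 input features, 3 hidden units
X, H = torch.randn(2, 4), torch.randn(2, 3)
W_xi, W_hi, b_i = torch.randn(4, 3), torch.randn(3, 3), torch.zeros(3)
W_xc, W_hc, b_c = torch.randn(4, 3), torch.randn(3, 3), torch.zeros(3)

I = torch.sigmoid(X @ W_xi + H @ W_hi + b_i)       # input gate: every entry in (0, 1)
C_tilda = torch.tanh(X @ W_xc + H @ W_hc + b_c)    # candidate memory cell: every entry in (-1, 1)
print(I.min().item(), I.max().item())              # both between 0 and 1
print(C_tilda.min().item(), C_tilda.max().item())  # both between -1 and 1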

Memory cell: the input gate $I_t$ controls how much of the new data in $\tilde{C}_t$ is used, while the forget gate $F_t$ controls how much of the past memory cell $C_{t-1}$ is kept.

If the forget gate is always 1 and the input gate is always 0, the past memory cell $C_{t-1}$ is preserved over time and passed on to the current time step. This design was introduced to mitigate the vanishing gradient problem and to better capture long-range dependencies in sequences.

$$C_t = F_t \odot C_{t-1} + I_t \odot \tilde{C}_t$$
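To make the "forget gate always 1, input gate always 0" case concrete, here is a tiny numeric sketch (arbitrary values, purely illustrative): the memory cell is copied forward unchanged no matter what the candidate memory is:

import torch

C = torch.tensor([0.5, -2.0, 3.0])    # some past memory cell C_{t-1} (arbitrary values)
F, I = torch.ones(3), torch.zeros(3)  # forget gate fixed at 1, input gate fixed at 0
for _ in range(100):                  # run the update for 100 time steps
    C_tilda = torch.randn(3)          # whatever the candidate memory happens to be
    C = F * C + I * C_tilda           # the memory cell update from the equation above
print(C)                              # still tensor([ 0.5000, -2.0000,  3.0000])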

Finally, we define how to compute the hidden state $H_t$; this is where the output gate comes into play. In an LSTM network, the hidden state is simply a gated version of the tanh of the memory cell, which ensures that the values of $H_t$ always lie in the interval (-1, 1).

When the output gate is close to 1, we effectively pass all the memory information on to the prediction part; when it is close to 0, all the information is kept inside the memory cell and the hidden state is not updated.

Some of the literature treats the memory cell as a special kind of hidden state: it has the same shape as the hidden state and is designed to record additional information.

$$H_t = O_t \odot \tanh(C_t)$$
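A similarly small sketch (arbitrary values, purely illustrative) shows how the output gate controls the hidden state: near 1 it passes the tanh of the memory through, near 0 it keeps the hidden state close to zero while the memory stays internal:

import torch

C = torch.tensor([0.5, -2.0, 3.0])                                 # current memory cell C_t
O_open, O_closed = torch.full((3,), 0.99), torch.full((3,), 0.01)  # output gate near 1 vs. near 0
print(O_open * torch.tanh(C))    # close to tanh(C_t), so the hidden state stays in (-1, 1)
print(O_closed * torch.tanh(C))  # close to 0: the memory is kept internal, the hidden state barely changes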

Summary

Long short-term memory networks are typical latent variable autoregressive models with important state controls. However, because of the long-range dependencies in sequences, training LSTM networks and other sequence models (such as gated recurrent units) is quite costly. Transformers are a more advanced alternative.

LSTMs can alleviate exploding and vanishing gradients.

Only the hidden state $H_t$ is passed to the output layer (to compute $Y$); the memory cell $C_t$ is entirely internal information.


Hands-on implementation

Long Short-Term Memory Network - LSTM

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # input gate parameters
    W_xf, W_hf, b_f = three()  # forget gate parameters
    W_xo, W_ho, b_o = three()  # output gate parameters
    W_xc, W_hc, b_c = three()  # candidate memory cell parameters

    # output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    # attach gradients
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    
    return params
# initialization of H and C (hidden state and memory cell)
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
def lstm(inputs, state, params):

    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    
    outputs = []
    for X in inputs:

        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i) # input gate
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f) # forget gate
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o) # output gate

        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c) # candidate memory cell

        C = F * C + I * C_tilda # memory cell

        H = O * torch.tanh(C) # hidden state

        Y = (H @ W_hq) + b_q # output
        outputs.append(Y)
        
    return torch.cat(outputs, dim=0), (H, C)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
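A quick shape check can confirm what the scratch model returns. This is a small sketch that assumes the d2l.RNNModelScratch interface from the book's RNN chapter (begin_state(batch_size, device) and calling the model on a batch of integer tokens); it is not part of the book's code:

X = torch.arange(10, device=device).reshape((2, 5))  # dummy batch: 2 sequences, 5 time steps
state = model.begin_state(X.shape[0], device)        # initial (H, C), both all zeros
Y, (H, C) = model(X, state)
print(Y.shape)           # (num_steps * batch_size, vocab_size) = (10, len(vocab))
print(H.shape, C.shape)  # each (2, 256): one hidden state and one memory cell per sequence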


Concise implementation

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
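After training, the model can generate text from a prefix. A minimal usage sketch, assuming the book's d2l.predict_ch8 helper (prefix, number of characters to generate, model, vocab, device):

print(d2l.predict_ch8('time traveller', 50, model, vocab, device))  # generate 50 characters after the prefix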



Origin blog.csdn.net/cjw838982809/article/details/132579946