Introduction to Deep Learning (61): Recurrent Neural Networks - Long Short-Term Memory Network (LSTM)

Foreword

The core content comes from blog link 1 and blog link 2; please give the original authors your support.
This article is kept as notes to prevent forgetting.

Recurrent neural networks: the long short-term memory network (LSTM)

Courseware

Long short-term memory network

Forget gate: pushes values toward 0
Input gate: decides whether to take the input data into account or ignore it
Output gate: decides whether to use the hidden state

Gates

(Figure: the gates of the LSTM)

Candidate memory cell

(Figure: the candidate memory cell)

Memory cell

(Figure: the memory cell)

Hidden state

(Figure: the hidden state)

Summary

$$
\begin{aligned}
\mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
\mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),\\
\tilde{\mathbf{C}}_t &= \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),\\
\mathbf{C}_t &= \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t,\\
\mathbf{H}_t &= \mathbf{O}_t \odot \tanh(\mathbf{C}_t).
\end{aligned}
$$
(Figure: the full LSTM computation)

Textbook

For a long time, latent variable models have faced the problem of preserving long-term information while skipping short-term inputs. One of the earliest approaches to address it was the long short-term memory network (LSTM). It shares many of the properties of the gated recurrent unit (GRU). Interestingly, the design of the LSTM is slightly more complicated than that of the GRU, yet it predates the GRU by almost twenty years.

1 Gated memory cell

It can be said that the design of the long short-term memory network was inspired by the logic gates of a computer. The LSTM introduces a memory cell (记忆元), or cell (单元) for short. Some literature regards the memory cell as a special type of hidden state: it has the same shape as the hidden state, and its purpose is to record additional information. To control the memory cell, we need a number of gates. One gate is used to read entries out of the cell; we call it the output gate (输出门). Another gate is used to decide when to read data into the cell; we call it the input gate (输入门). Finally, we need a mechanism to reset the contents of the cell, governed by the forget gate (遗忘门). The design motivation is the same as for the gated recurrent unit: a dedicated mechanism decides when to remember an input in the hidden state and when to ignore it. Let's see how this works in practice.

1.1 Input gate, forget gate and output gate

Just like in the gated recurrent unit, the input of the current time step and the hidden state of the previous time step are fed into the gates of the LSTM network, as shown in the figure. They are processed by three fully connected layers with sigmoid activation functions to compute the values of the input gate, the forget gate, and the output gate. Therefore, the values of all three gates lie in the range $(0, 1)$.
(Figure: computing the input gate, forget gate, and output gate in an LSTM)
Let's refine the mathematical expression of the long short-term memory network. Suppose there are $h$ hidden units, the batch size is $n$, and the number of inputs is $d$. Then the input is $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ and the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$. Correspondingly, the gates at time step $t$ are defined as follows: the input gate is $\mathbf{I}_t \in \mathbb{R}^{n \times h}$, the forget gate is $\mathbf{F}_t \in \mathbb{R}^{n \times h}$, and the output gate is $\mathbf{O}_t \in \mathbb{R}^{n \times h}$. They are computed as follows:

$$
\begin{aligned}
\mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
\mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),
\end{aligned}
$$

where $\mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h}$ are weight parameters, and $\mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h}$ are bias parameters.
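As a quick shape check, here is a minimal sketch (my own illustration, not part of the original text) that computes the input gate for random data with $n=2$, $d=8$, $h=4$; the tensor names simply mirror the symbols above.

import torch

n, d, h = 2, 8, 4                       # batch size, number of inputs, hidden units
X_t = torch.randn(n, d)                 # input of the current time step
H_prev = torch.randn(n, h)              # hidden state of the previous time step
W_xi, W_hi, b_i = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

I_t = torch.sigmoid(X_t @ W_xi + H_prev @ W_hi + b_i)   # input gate
print(I_t.shape)                                         # torch.Size([2, 4])
print(float(I_t.min()) > 0, float(I_t.max()) < 1)        # True True: sigmoid keeps values in (0, 1)

The forget gate and output gate follow the same pattern, each with its own weight matrices and bias.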

1.2 Candidate memory cell

Since we have not yet specified how the various gates operate, let us first introduce the candidate memory cell (候选记忆元) $\tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h}$. Its computation is similar to that of the three gates described above, but it uses the tanh function as the activation function, whose values lie in $(-1, 1)$. This leads to the following equation at time step $t$:

$$
\tilde{\mathbf{C}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),
$$

where $\mathbf{W}_{xc} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hc} \in \mathbb{R}^{h \times h}$ are weight parameters, and $\mathbf{b}_c \in \mathbb{R}^{1 \times h}$ is a bias parameter.

The candidate memory cells are shown in the figure.
(Figure: computing the candidate memory cell)

1.3 Memory cell

In the gated recurrent unit, there is a mechanism to control input and forgetting (or skipping). Similarly, the LSTM network has two dedicated gates for this purpose: the input gate $\mathbf{I}_t$ controls how much new data from $\tilde{\mathbf{C}}_t$ is taken in, and the forget gate $\mathbf{F}_t$ controls how much of the content of the past memory cell $\mathbf{C}_{t-1} \in \mathbb{R}^{n \times h}$ is retained. Using element-wise multiplication, we get:
$$
\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.
$$

If the forget gate is always 1 and the input gate is always 0, the past memory cell $\mathbf{C}_{t-1}$ will be retained over time and passed to the current time step. This design was introduced to alleviate the vanishing gradient problem and to better capture long-range dependencies in sequences.
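This extreme case can be checked with a small sketch (illustrative only, the values are random): with the forget gate fixed at 1 and the input gate fixed at 0, the update leaves the past memory cell untouched.

import torch

C_prev = torch.randn(2, 4)              # past memory cell C_{t-1}
C_tilde = torch.randn(2, 4)             # candidate memory cell
F_t = torch.ones(2, 4)                  # forget gate stuck at 1
I_t = torch.zeros(2, 4)                 # input gate stuck at 0

C_t = F_t * C_prev + I_t * C_tilde      # element-wise update
print(torch.allclose(C_t, C_prev))      # True: the old memory passes through unchanged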

In this way, we arrive at the flow chart for computing the memory cell, as shown in the figure.
(Figure: computing the memory cell)

1.4 Hidden state

Finally, we need to define how to compute the hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$, and this is where the output gate comes into play. In the LSTM network, it is simply a gated version of the tanh of the memory cell. This ensures that the values of $\mathbf{H}_t$ are always in the interval $(-1, 1)$:

$$
\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).
$$
Whenever the output gate is close to 1, we effectively pass all the memory information on to the prediction part, while for an output gate close to 0 we keep all the information inside the memory cell and do not update the hidden state.
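A similar sketch (again just an illustration with random values) shows the two extremes of the output gate: when it is close to 0 the hidden state carries no information out of the cell, and in any case the tanh keeps the hidden state in $(-1, 1)$.

import torch

C_t = torch.randn(2, 4)                 # current memory cell
O_closed = torch.zeros(2, 4)            # output gate fully closed
O_open = torch.ones(2, 4)               # output gate fully open

H_closed = O_closed * torch.tanh(C_t)
H_open = O_open * torch.tanh(C_t)
print(float(H_closed.abs().max()))                        # 0.0: nothing leaves the memory cell this step
print(float(H_open.min()) > -1, float(H_open.max()) < 1)  # True True: values stay in (-1, 1)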

The figure below provides a graphical representation of the data flow
(Figure: computing the hidden state)

2 Implementation from scratch

Now, we implement the LSTM network from scratch. We start by loading the Time Machine dataset.

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

2.1 Initialize model parameters

Next, we need to define and initialize the model parameters. As before, the hyperparameter num_hiddens defines the number of hidden units. We initialize the weights from a Gaussian distribution with standard deviation 0.01 and set the bias terms to 0.

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # Input gate parameters
    W_xf, W_hf, b_f = three()  # Forget gate parameters
    W_xo, W_ho, b_o = three()  # Output gate parameters
    W_xc, W_hc, b_c = three()  # Candidate memory cell parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
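As a sanity check (illustrative, using a small made-up vocabulary size rather than the real one), the parameter list should contain 14 tensors: three per gate and per candidate memory cell, plus the output layer weight and bias.

demo_params = get_lstm_params(vocab_size=28, num_hiddens=16, device=torch.device('cpu'))
print(len(demo_params))          # 14
print(demo_params[0].shape)      # torch.Size([28, 16])  -> W_xi
print(demo_params[-2].shape)     # torch.Size([16, 28])  -> W_hq
print(demo_params[-1].shape)     # torch.Size([28])      -> b_q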

2.2 Define the model

In the initialization function, the hidden state of the long short-term memory network needs to return an additional memory cell, with value 0 and shape (batch size, number of hidden units). Thus we get the following state initialization.

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
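A quick check (my own, not from the original text) that the initial state is a pair of zero tensors of the stated shape:

demo_state = init_lstm_state(batch_size=2, num_hiddens=16, device=torch.device('cpu'))
print(len(demo_state), demo_state[0].shape, demo_state[1].shape)  # 2 torch.Size([2, 16]) torch.Size([2, 16])
print(demo_state[0].sum().item(), demo_state[1].sum().item())     # 0.0 0.0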

The definition of the actual model is just as we discussed before: it provides three gates and an additional memory cell. Note that only the hidden state is passed to the output layer; the memory cell $\mathbf{C}_t$ does not participate directly in the output computation.

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
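Before training, we can push some random one-hot inputs through the function to verify the shapes (an illustrative check with made-up sizes; during training the wrapper class takes care of one-hot encoding the raw token indices).

num_steps, bsz, vsz, hid = 5, 2, 28, 16
demo_params = get_lstm_params(vsz, hid, torch.device('cpu'))
demo_state = init_lstm_state(bsz, hid, torch.device('cpu'))
X = torch.nn.functional.one_hot(
    torch.randint(0, vsz, (num_steps, bsz)), vsz).float()
Y, (H, C) = lstm(X, demo_state, demo_params)
print(Y.shape, H.shape, C.shape)
# torch.Size([10, 28]) torch.Size([2, 16]) torch.Size([2, 16])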

2.3 Training and Prediction

Let's train an LSTM network by instantiating the RNNModelScratch class introduced in the section on implementing RNNs from scratch, just as we did in the GRU section.

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

Output:

perplexity 1.4, 22462.5 tokens/sec on cuda:0
time traveller for the brow henint it aneles a overrecured aback
travellerifilby freenotin s dof nous be and the filing and

3 Concise implementation

Using the high-level API, we can directly instantiate the LSTM model. The high-level API encapsulates all the configuration details described above. This code runs much faster because it uses compiled operators instead of Python to handle many of the details explained earlier.

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

Output:

perplexity 1.1, 330289.2 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby
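For reference, PyTorch's nn.LSTM can also be used on its own, independent of d2l; the sketch below (with made-up sizes) shows that its state is the same (H, C) pair discussed above.

import torch
from torch import nn

demo_lstm = nn.LSTM(input_size=28, hidden_size=16)   # expects input of shape (num_steps, batch_size, input_size)
X = torch.randn(5, 2, 28)                            # 5 time steps, batch of 2
output, (H_n, C_n) = demo_lstm(X)                    # the default initial state is all zeros
print(output.shape, H_n.shape, C_n.shape)
# torch.Size([5, 2, 16]) torch.Size([1, 2, 16]) torch.Size([1, 2, 16])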

The long short-term memory network is the prototypical latent variable autoregressive model with nontrivial state control. Many variants of it have been proposed over the years, e.g., multiple layers, residual connections, and different types of regularization. However, training LSTM networks and other sequence models (such as gated recurrent units) is quite costly because of the long-range dependencies of sequences. Later we will cover more advanced alternative models, such as the Transformer.

4 Summary

  • LSTM networks have three types of gates: input gates, forget gates, and output gates.

  • The hidden layer output of the LSTM network consists of the "hidden state" and the "memory cell". Only the hidden state is passed to the output layer; the memory cell is entirely internal.

  • Long short-term memory networks can alleviate vanishing and exploding gradients.

References

[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.

Origin blog.csdn.net/qq_52358603/article/details/128376487