Introduction to Deep Learning (61): Recurrent Neural Networks - Long Short-Term Memory (LSTM)
Foreword
The core content comes from blog link 1 and blog link 2; please support the original authors.
This article is a personal record, written to avoid forgetting.
Recurrent neural networks - long short-term memory (LSTM)
Courseware
Long short-term memory network
Forget gate: shrinks values toward 0
Input gate: decides whether to ignore the input data
Output gate: decides whether to use the hidden state
Gates
Candidate memory cell
Memory cell
Hidden state
Summary
$$\begin{aligned}
\mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
\mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),\\
\tilde{\mathbf{C}}_t &= \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),\\
\mathbf{C}_t &= \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t,\\
\mathbf{H}_t &= \mathbf{O}_t \odot \tanh(\mathbf{C}_t).
\end{aligned}$$
Textbook
Latent variable models have long suffered from two problems: preserving long-term information and skipping short-term input. One of the earliest approaches to these problems was the long short-term memory (LSTM) network [1]. It shares many properties with the gated recurrent unit (GRU). Interestingly, the design of the LSTM is slightly more complex than that of the GRU, yet the LSTM was proposed almost 20 years earlier.
1 Gated memory cell
The design of the LSTM is arguably inspired by the logic gates of a computer. The LSTM introduces a memory cell, or simply a cell. Some literature regards the memory cell as a special type of hidden state: it has the same shape as the hidden state and is designed to record additional information. To control the memory cell, we need several gates. One gate reads entries out of the cell; we call it the output gate. Another gate decides when to read data into the cell; we call it the input gate. Finally, we need a mechanism to reset the contents of the cell, which is handled by the forget gate. The motivation is the same as for the gated recurrent unit (GRU): a dedicated mechanism decides when to remember and when to ignore inputs to the hidden state. Let's see how this works in practice.
1.1 Input gate, forget gate and output gate
Just as in the gated recurrent unit, the input at the current time step and the hidden state of the previous time step are fed into the gates of the LSTM, as shown in the figure. They are processed by three fully connected layers with sigmoid activation functions to compute the values of the input, forget, and output gates. Therefore, the values of all three gates lie in the range $(0, 1)$.
Let's refine the mathematical description of the LSTM. Suppose there are $h$ hidden units, the batch size is $n$, and the number of inputs is $d$. Then the input is $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ and the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$. Correspondingly, the gates at time step $t$ are defined as follows: the input gate is $\mathbf{I}_t \in \mathbb{R}^{n \times h}$, the forget gate is $\mathbf{F}_t \in \mathbb{R}^{n \times h}$, and the output gate is $\mathbf{O}_t \in \mathbb{R}^{n \times h}$. They are computed as:
$$\begin{aligned}
\mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
\mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),
\end{aligned}$$
where $\mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h}$ are weight parameters and $\mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h}$ are bias parameters.
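To make the shapes concrete, here is a minimal sketch of the three gate equations in PyTorch. The sizes `n`, `d`, `h` and the helper `gate_params` are hypothetical choices for illustration; each gate gets its own $(d \times h)$ input weight, $(h \times h)$ recurrent weight, and $(1 \times h)$ bias, which broadcasts over the batch.

```python
import torch

# Hypothetical sizes for illustration: batch size n, number of inputs d, hidden units h.
n, d, h = 2, 5, 3
torch.manual_seed(0)

X_t = torch.randn(n, d)      # input at time step t, shape (n, d)
H_prev = torch.randn(n, h)   # hidden state H_{t-1}, shape (n, h)

def gate_params():
    """Hypothetical helper: return (W_x, W_h, b) with shapes (d, h), (h, h), (1, h)."""
    return torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)

(W_xi, W_hi, b_i), (W_xf, W_hf, b_f), (W_xo, W_ho, b_o) = (
    gate_params(), gate_params(), gate_params())

I_t = torch.sigmoid(X_t @ W_xi + H_prev @ W_hi + b_i)  # input gate
F_t = torch.sigmoid(X_t @ W_xf + H_prev @ W_hf + b_f)  # forget gate
O_t = torch.sigmoid(X_t @ W_xo + H_prev @ W_ho + b_o)  # output gate

print(I_t.shape)  # torch.Size([2, 3])
```

All three gates have the same shape $(n, h)$ as the hidden state, so they can be multiplied elementwise with it later.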
1.2 Candidate memory cell
Before specifying how the various gates are used, let us first introduce the candidate memory cell $\tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h}$. Its computation is similar to that of the three gates described above, but it uses the tanh function as the activation function, whose range is $(-1, 1)$. At time step $t$ it is computed as:
$$\tilde{\mathbf{C}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),$$
where $\mathbf{W}_{xc} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hc} \in \mathbb{R}^{h \times h}$ are weight parameters and $\mathbf{b}_c \in \mathbb{R}^{1 \times h}$ is a bias parameter.
The computation of the candidate memory cell is shown in the figure.
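Why tanh for the candidate but sigmoid for the gates? A gate acts as a soft switch, so a nonnegative multiplier in $(0, 1)$ is what we want; the candidate carries content that may need to push the memory up or down, so it needs a sign. The small pure-Python sketch below (scalar values chosen arbitrarily) feeds the same pre-activations through both functions.

```python
import math

def sigmoid(z):
    """Logistic sigmoid, as used by the input, forget, and output gates."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical pre-activation values; the same numbers go through both activations.
pre_activations = [-3.0, -0.5, 0.0, 0.5, 3.0]
gate_values = [sigmoid(z) for z in pre_activations]         # always in (0, 1)
candidate_values = [math.tanh(z) for z in pre_activations]  # in (-1, 1), keeps the sign of z

for z, g, c in zip(pre_activations, gate_values, candidate_values):
    print(f"z={z:+.1f}  sigmoid={g:.3f}  tanh={c:+.3f}")
```

Negative pre-activations give small positive gate values but negative candidate values, which is what lets $\tilde{\mathbf{C}}_t$ decrease as well as increase the memory cell.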
1.3 Memory cell
In a gated recurrent unit (GRU), there is a mechanism to control input and forgetting (or skipping). Similarly, the LSTM has two gates for this purpose: the input gate $\mathbf{I}_t$ controls how much new data we take in from $\tilde{\mathbf{C}}_t$, and the forget gate $\mathbf{F}_t$ controls how much of the old memory cell content $\mathbf{C}_{t-1} \in \mathbb{R}^{n \times h}$ we retain. Using elementwise multiplication, we get:
$$\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.$$
If the forget gate is always 1 and the input gate is always 0, the past memory cell $\mathbf{C}_{t-1}$ is saved over time and passed on to the current time step. This design was introduced to alleviate the vanishing gradient problem and to better capture long-range dependencies within sequences.
This gives the data flow for computing the memory cell, shown in the figure.
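The retention argument can be checked with scalars. The sketch below, a hypothetical helper with constant gate values chosen for illustration, iterates $C_t = F \cdot C_{t-1} + I \cdot \tilde{C}$ for 100 steps. Treating the gates as constants, the derivative of $C_T$ with respect to $C_0$ along this direct path is the product of the forget gates, so a forget gate of 1 keeps both the value and that gradient factor intact, while 0.9 shrinks them geometrically.

```python
def run_memory_cell(c0, forget, inp, candidate, steps):
    """Iterate C_t = F * C_{t-1} + I * C~_t with constant scalar gates (illustrative only)."""
    c = c0
    for _ in range(steps):
        c = forget * c + inp * candidate
    return c

# Forget gate 1, input gate 0: the initial memory passes through unchanged.
kept = run_memory_cell(c0=0.8, forget=1.0, inp=0.0, candidate=0.5, steps=100)
# Forget gate 0.9, input gate 0: the memory decays like 0.8 * 0.9**100.
decayed = run_memory_cell(c0=0.8, forget=0.9, inp=0.0, candidate=0.5, steps=100)
print(kept, decayed)
```

In a real LSTM the gates depend on $\mathbf{X}_t$ and $\mathbf{H}_{t-1}$, so this is only the idealized case the text describes.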
1.4 Hidden state
Finally, we need to define how to compute the hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$, and this is where the output gate comes into play. In the LSTM, the hidden state is simply a gated version of the tanh of the memory cell. This ensures that the values of $\mathbf{H}_t$ always lie in the interval $(-1, 1)$:
$$\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).$$
Whenever the output gate is close to 1, we effectively pass all of the memory information through to the predictor; whenever the output gate is close to 0, we retain the information only within the memory cell and perform no further processing.
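A scalar sketch of the two regimes, with a hypothetical memory-cell value of 5.0: the output gate scales the squashed memory, so the hidden state stays within $(-1, 1)$ no matter how large the memory cell grows, and a nearly closed gate hides the memory from the rest of the network without changing it.

```python
import math

def hidden_state(o, c):
    """H_t = O_t * tanh(C_t) for a single scalar unit (illustrative only)."""
    return o * math.tanh(c)

c = 5.0  # hypothetical, fairly large memory-cell value; it is never modified below
print(hidden_state(0.99, c))  # output gate open: H is close to tanh(5)
print(hidden_state(0.01, c))  # output gate nearly closed: H is close to 0
```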
The figure below illustrates the data flow.
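All six equations fit in one step function. Because the three gates and the candidate memory cell each apply an affine map to the same $\mathbf{X}_t$ and $\mathbf{H}_{t-1}$, a common implementation trick is to stack their weights and compute all four pre-activations with a single matrix product. The sketch below assumes the stacking order $(\mathbf{I}, \mathbf{F}, \mathbf{O}, \tilde{\mathbf{C}})$, which is our own choice (PyTorch's `nn.LSTM` uses a different order), and checks the fused step against the equations written out one by one. `lstm_step_fused` is a hypothetical name and the sizes are arbitrary.

```python
import torch

def lstm_step_fused(X, H, C, W_x, W_h, b):
    """One LSTM step with the four affine maps fused into one matmul.

    W_x: (d, 4h), W_h: (h, 4h), b: (4h,), stacked in the order I, F, O, C~.
    """
    Z = X @ W_x + H @ W_h + b               # (n, 4h) pre-activations
    Z_i, Z_f, Z_o, Z_c = Z.chunk(4, dim=1)  # split back into four (n, h) blocks
    I, F, O = torch.sigmoid(Z_i), torch.sigmoid(Z_f), torch.sigmoid(Z_o)
    C_tilde = torch.tanh(Z_c)               # candidate memory cell
    C_new = F * C + I * C_tilde             # memory cell update
    H_new = O * torch.tanh(C_new)           # hidden state
    return H_new, C_new

torch.manual_seed(0)
n, d, h = 2, 4, 3
X, H, C = torch.randn(n, d), torch.randn(n, h), torch.randn(n, h)
# Separate parameters per gate, in the order I, F, O, C~.
Ws_x = [torch.randn(d, h) for _ in range(4)]
Ws_h = [torch.randn(h, h) for _ in range(4)]
bs = [torch.randn(h) for _ in range(4)]

H_fused, C_fused = lstm_step_fused(
    X, H, C, torch.cat(Ws_x, dim=1), torch.cat(Ws_h, dim=1), torch.cat(bs))

# Reference: the equations evaluated one by one.
I = torch.sigmoid(X @ Ws_x[0] + H @ Ws_h[0] + bs[0])
F = torch.sigmoid(X @ Ws_x[1] + H @ Ws_h[1] + bs[1])
O = torch.sigmoid(X @ Ws_x[2] + H @ Ws_h[2] + bs[2])
C_tilde = torch.tanh(X @ Ws_x[3] + H @ Ws_h[3] + bs[3])
C_ref = F * C + I * C_tilde
H_ref = O * torch.tanh(C_ref)
print(torch.allclose(H_fused, H_ref, atol=1e-5), torch.allclose(C_fused, C_ref, atol=1e-5))
```

Fusing does not change the math; it only replaces eight small matrix products with two larger ones, which tends to be faster on GPUs.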
2 Implementation from scratch
Now, we implement the LSTM network from scratch. We start by loading the Time Machine dataset.
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
2.1 Initialize model parameters
Next, we define and initialize the model parameters. As before, the hyperparameter num_hiddens defines the number of hidden units. We initialize the weights from a Gaussian distribution with standard deviation 0.01 and set the biases to 0.
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # input gate parameters
    W_xf, W_hf, b_f = three()  # forget gate parameters
    W_xo, W_ho, b_o = three()  # output gate parameters
    W_xc, W_hc, b_c = three()  # candidate memory cell parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
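A quick way to see what get_lstm_params allocates is to count the scalars. Each of the four blocks (three gates plus the candidate memory cell) has $d \cdot h + h \cdot h + h$ parameters, and the output layer adds $h \cdot q + q$. The sketch below uses a hypothetical helper and assumes a vocabulary size of 28 purely for illustration.

```python
def count_scratch_lstm_params(vocab_size, num_hiddens):
    """Hypothetical helper: count the scalars created by get_lstm_params."""
    d = q = vocab_size                # num_inputs = num_outputs = vocab_size
    h = num_hiddens
    per_block = d * h + h * h + h     # W_x*, W_h*, b_* for one gate or the candidate cell
    recurrent = 4 * per_block         # input, forget, output gates + candidate memory cell
    output_layer = h * q + q          # W_hq and b_q
    return recurrent + output_layer

# Vocabulary size 28 is an assumption for illustration.
print(count_scratch_lstm_params(28, 256))  # 299036
```

Compared with an RNN of the same width, the recurrent part is four times larger, which is one reason LSTMs are more expensive to train. PyTorch's `nn.LSTM` keeps two bias vectors per gate (input-hidden and hidden-hidden), so its recurrent count is $4(d h + h^2 + 2h)$ instead.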
2.2 Define the model
In the state initialization function, the LSTM must return an additional memory cell alongside the hidden state, with value 0 and shape (batch size, number of hidden units). This gives the following state initialization.
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
The actual model is defined as discussed above: three gates and one additional memory cell. Note that only the hidden state is passed to the output layer; the memory cell $\mathbf{C}_t$ does not participate directly in the output computation.
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)  # input gate
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)  # forget gate
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)  # output gate
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)  # candidate memory cell
        C = F * C + I * C_tilda  # memory cell update
        H = O * torch.tanh(C)  # hidden state
        Y = (H @ W_hq) + b_q  # output layer sees only H, not C
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
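The loop `for X in inputs` assumes that the first dimension of `inputs` is time. To our understanding, d2l.RNNModelScratch produces this layout by transposing a minibatch of token indices and one-hot encoding it; the sketch below reproduces that step with plain PyTorch (sizes are hypothetical) and mimics how the per-step outputs are concatenated along dim 0.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: token indices arrive with shape (batch_size, num_steps).
batch_size, num_steps, vocab_size = 2, 5, 7
tokens = torch.randint(0, vocab_size, (batch_size, num_steps))

# Transpose to (num_steps, batch_size), then one-hot encode:
# iterating over `inputs` now yields one X_t of shape (batch_size, vocab_size) per step.
inputs = F.one_hot(tokens.T, vocab_size).type(torch.float32)
print(inputs.shape)  # torch.Size([5, 2, 7])

# Stand-in for the per-step outputs Y of the lstm function (values are irrelevant here).
outputs = [torch.zeros(X.shape[0], vocab_size) for X in inputs]
Y = torch.cat(outputs, dim=0)  # rows ordered by time step, then by batch example
print(Y.shape)  # torch.Size([10, 7])
```

So the model's output has num_steps * batch_size rows, and row t * batch_size + b belongs to example b at time step t.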
2.3 Training and Prediction
Let's train an LSTM by instantiating the RNNModelScratch class introduced in the section on implementing RNNs from scratch, just as we did in the GRU section.
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
output
perplexity 1.4, 22462.5 tokens/sec on cuda:0
time traveller for the brow henint it aneles a overrecured aback
travellerifilby freenotin s dof nous be and the filing and
3 Concise implementation
Using the high-level API, we can directly instantiate the LSTM model. The high-level API encapsulates all the configuration details described above. This code runs much faster because it uses compiled operators instead of Python to handle many of the details explained earlier.
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
output:
perplexity 1.1, 330289.2 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby
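The high-level layer computes the same function as the from-scratch code. As a sketch, assuming PyTorch's documented parameter layout (`weight_ih_l0` and `weight_hh_l0` stack the per-gate weights in the order i, f, g, o, where g is the candidate cell, and there are two bias vectors `bias_ih_l0` and `bias_hh_l0`), we can copy from-scratch-style parameters into `nn.LSTM` and check that the hidden states agree. The sizes are hypothetical and the output layer is omitted.

```python
import torch
from torch import nn

torch.manual_seed(0)
d, h, n, T = 4, 3, 2, 5  # input size, hidden units, batch size, time steps (hypothetical)

# From-scratch-style parameters: X @ W_x* + H @ W_h* + b_*.
W_xi, W_xf, W_xc, W_xo = (torch.randn(d, h) for _ in range(4))
W_hi, W_hf, W_hc, W_ho = (torch.randn(h, h) for _ in range(4))
b_i, b_f, b_c, b_o = (torch.randn(h) for _ in range(4))

def scratch_lstm(inputs, H, C):
    """Same recurrence as the lstm function above, returning hidden states only."""
    outputs = []
    for X in inputs:
        I = torch.sigmoid(X @ W_xi + H @ W_hi + b_i)
        F = torch.sigmoid(X @ W_xf + H @ W_hf + b_f)
        O = torch.sigmoid(X @ W_xo + H @ W_ho + b_o)
        C_tilde = torch.tanh(X @ W_xc + H @ W_hc + b_c)
        C = F * C + I * C_tilde
        H = O * torch.tanh(C)
        outputs.append(H)
    return torch.stack(outputs), (H, C)

# nn.LSTM stores weights as (4h, in_features) in gate order (i, f, g, o)
# and multiplies by the transpose, hence the .T below.
layer = nn.LSTM(d, h)
with torch.no_grad():
    layer.weight_ih_l0.copy_(torch.cat([W_xi.T, W_xf.T, W_xc.T, W_xo.T]))
    layer.weight_hh_l0.copy_(torch.cat([W_hi.T, W_hf.T, W_hc.T, W_ho.T]))
    layer.bias_ih_l0.copy_(torch.cat([b_i, b_f, b_c, b_o]))
    layer.bias_hh_l0.zero_()  # second bias set to 0 so only one bias is active

inputs = torch.randn(T, n, d)  # (num_steps, batch_size, input_size)
H0, C0 = torch.zeros(n, h), torch.zeros(n, h)
out_scratch, (H_s, C_s) = scratch_lstm(inputs, H0, C0)
with torch.no_grad():
    # nn.LSTM expects states of shape (num_layers, batch_size, num_hiddens).
    out_nn, (h_n, c_n) = layer(inputs, (H0.unsqueeze(0), C0.unsqueeze(0)))

print(out_nn.shape, h_n.shape)  # torch.Size([5, 2, 3]) torch.Size([1, 2, 3])
print(torch.allclose(out_scratch, out_nn, atol=1e-5))
```

The speedup reported above therefore comes from the compiled, fused implementation rather than from a different model.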
LSTMs are the prototypical latent variable autoregressive models with nontrivial state control. Many variants have been proposed over the years, e.g., multiple layers, residual connections, and different types of regularization. However, training LSTMs and other sequence models such as GRUs is quite costly because of the long-range dependencies of sequences. Later we will cover more advanced alternatives such as the Transformer.
4 Summary
- LSTMs have three types of gates: input gates, forget gates, and output gates.
- The hidden-layer output of an LSTM consists of the hidden state and the memory cell. Only the hidden state is passed to the output layer; the memory cell is entirely internal.
- LSTMs can alleviate vanishing and exploding gradients.
References
[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.