A summary of nn.LSTM in PyTorch (the official documentation covers the parameters and provides examples)

Refer to the official PyTorch documentation:
https://pytorch.org/docs/master/nn.html#torch.nn.LSTM

First, the schematic. Figure 1: Introduction to the internal principles of LSTM.

Figure 2: Introduction of the key parameters

The key parameters are as follows:
input_size: the number of input features.
hidden_size: the number of features in the hidden state.
num_layers: the number of stacked LSTM layers, i.e. how many LSTMs are stacked on top of each other in the model. The default is 1.
bias: whether to use bias weights. The default is True.
batch_first: the default is False; if True, the input and output tensors are shaped (batch, seq, feature), i.e. [batch_size, time_step, input_size] = [batch size, sequence length, number of features].
dropout: the default is 0; if non-zero, dropout with this rate is applied to the output of each LSTM layer except the last.
bidirectional: whether the LSTM is bidirectional. The default is False.
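As a minimal sketch of how these arguments fit together (the specific values here are illustrative, not from the original post):

import torch
import torch.nn as nn

# Illustrative configuration exercising the parameters listed above
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2,
               bias=True, batch_first=True, dropout=0.5,
               bidirectional=True)
x = torch.randn(3, 5, 10)   # (batch, seq, feature) because batch_first=True
out, (hn, cn) = lstm(x)
print(out.shape)            # torch.Size([3, 5, 40]) = (batch, seq, num_directions * hidden_size)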

Figure 3: The input and output interface

The following considers the unidirectional case unless noted otherwise.
The call signature is lstm(input, (h_0, c_0)).
input is a tensor of shape (seq_len, batch, input_size). batch_first defaults to False; if it is True, the first two dimensions are swapped.
h_0 is a tensor of shape (num_layers * num_directions, batch, hidden_size); it contains the initial hidden state for each element in the batch. For a bidirectional LSTM, num_directions = 2; otherwise num_directions = 1.
c_0 is a tensor of shape (num_layers * num_directions, batch, hidden_size); it contains the initial cell state for each element in the batch.
If h_0 and c_0 are not provided, they default to zeros.
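A small sketch verifying that the state tuple is optional and that the default initial states are zeros (shapes as described above):

import torch
import torch.nn as nn

lstm = nn.LSTM(10, 20, 2)          # unidirectional, so num_directions = 1
x = torch.randn(5, 3, 10)          # (seq_len, batch, input_size)
out1, _ = lstm(x)                  # h_0 and c_0 are omitted, so they default to zeros
h0 = torch.zeros(2, 3, 20)         # (num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(2, 3, 20)
out2, _ = lstm(x, (h0, c0))        # passing explicit zero states
print(torch.allclose(out1, out2))  # True: both calls are equivalent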


The output format is (output, (h_n, c_n)).
output is a tensor of shape (seq_len, batch, num_directions * hidden_size); it contains the output features h_t from the last layer of the LSTM for every time step t.
h_n is a tensor of shape (num_layers * num_directions, batch, hidden_size); it contains the hidden state for t = seq_len (i.e., the end of the sequence).
c_n is a tensor of shape (num_layers * num_directions, batch, hidden_size); it contains the cell state for t = seq_len (i.e., the end of the sequence).
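One relation worth checking: for a unidirectional LSTM, output at the last time step comes from the last layer, so it coincides with the last-layer slice of h_n. A quick sketch:

import torch
import torch.nn as nn

lstm = nn.LSTM(10, 20, 2)
x = torch.randn(5, 3, 10)
output, (h_n, c_n) = lstm(x)
print(output.shape)                         # torch.Size([5, 3, 20]): (seq_len, batch, num_directions * hidden_size)
print(h_n.shape)                            # torch.Size([2, 3, 20]): (num_layers * num_directions, batch, hidden_size)
print(torch.allclose(output[-1], h_n[-1]))  # True: last time step of output == last layer's final hidden state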

Figure 4: Variable explanation and examples

Going straight to the example:

rnn = nn.LSTM(10, 20, 2)        # input feature size 10, hidden size 20, 2 stacked LSTM layers (defaults to 1 if omitted)
input = torch.randn(5, 3, 10)   # (seq_len, batch, input_size): sequence length 5, batch size 3, input dimension 10
h0 = torch.randn(2, 3, 20)      # h_0: (num_layers * num_directions, batch, hidden_size), num_layers = 2, batch = 3, hidden_size = 20
c0 = torch.randn(2, 3, 20)      # c_0: same shape as h_0
output, (hn, cn) = rnn(input, (h0, c0))
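Continuing the example above, the shapes work out as follows (unidirectional, so num_directions = 1):

print(output.shape)  # torch.Size([5, 3, 20]): (seq_len, batch, num_directions * hidden_size)
print(hn.shape)      # torch.Size([2, 3, 20]): (num_layers * num_directions, batch, hidden_size)
print(cn.shape)      # torch.Size([2, 3, 20]): same shape as hn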

Origin: blog.csdn.net/weixin_41545780/article/details/89890440