LSTM in Pytorch

[PyTorch] rnn,lstm,gru中输入输出维度 - 简书
LSTM神经网络输入输出究竟是怎样的? - 知乎
pytorch文档
在这里插入图片描述
可以把上面的每一列看成是有厚度的(如 128 维).

为什么输入可以是任意长度?

在这里插入图片描述
以 RNN 为例, 使用的网络是同一个网络, 只是在不同的时间步接收输入而已.

即 RNN 在一个时间步只接收一个输入 x t x_t . 所以我们在下面 pytorch 设置 LSTM 网络参数的时候不用设置 time_step, 只需要设置输入向量的维度就可以.

注意:不要被展开图给迷惑了.

设置网络参数

torch.nn.LSTM( input_size, hidden_size, num_layers )
      输入特征的维度 ‘num_units’

接收输入

Inputs: input, (h_0, c_0)
   ‘三维’‘三维’‘三维’

  • input of shape (seq_len, batch, input_size) — batch 指 一个 batch 所含的序列个数, LSTM 一次处理一个 batch 的所有序列
  • h_0 of shape (num_layers * num_directions, batch, hidden_size)
  • c_0 of shape (num_layers * num_directions, batch, hidden_size)

输出

Outputs: output, (h_n, c_n)

  • output of shape (seq_len, batch, num_directions * hidden_size)
  • h_n of shape (num_layers * num_directions, batch, hidden_size)
  • c_n of shape (num_layers * num_directions, batch, hidden_size)

Example

>>> rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2) # 它才不管你的输入长度, 任意长度都可以
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> print(output.size())
torch.Size([5, 3, 20])
表示输出 3 个 batch, 每个 batch 形状为 [5, 20]

发布了108 篇原创文章 · 获赞 7 · 访问量 4420

猜你喜欢

转载自blog.csdn.net/weixin_44795555/article/details/102894908