PyTorch RNN 输入维度,输出维度

版权声明:转载请注明出处及原文地址。 https://blog.csdn.net/zl1085372438/article/details/87164924

batch_first = True

输入:batch, seq_len,  input_size
h0:num_layers * num_directions, batch, hidden_size

输出:batch,seq_len, num_directions * hidden_size
hn:num_layers * num_directions, batch, hidden_size

batch_first = false

输入:seq_len, batch,  input_size
h0:num_layers * num_directions, batch, hidden_size

输出:seq_len, batch, num_directions * hidden_size
hn:num_layers * num_directions, batch, hidden_size

猜你喜欢

转载自blog.csdn.net/zl1085372438/article/details/87164924
今日推荐