RNN, LSTM, and GRU: A Detailed Look at Their Data Formats

[Figure: RNN diagram]
Let's start with an RNN diagram to get an intuitive picture.
In PyTorch, the input to all three networks (RNN, LSTM, and GRU) has the format **[seq_len, batch_size, input_dim]** (the default layout, with batch_first=False):

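  • seq_len: the length of each input sentence (the number of time steps), shown as n_steps in the figure
  • batch_size: the number of input sentences, shown as batch in the figure
  • input_dim: the dimension of the vector representing each input word; in the figure, n_input is 3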
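The three dimensions above can be made concrete with a small sketch. It assumes a toy setup that is not from the original post: 3 sentences of 4 token ids each, turned into 3-dimensional vectors (matching n_input = 3 in the figure) by an `nn.Embedding` layer used only for illustration. Token data is usually stored one sentence per row, i.e. [batch_size, seq_len], so the embedded tensor is [batch_size, seq_len, input_dim] and has to be transposed into the default [seq_len, batch_size, input_dim] layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy data: 3 sentences (batch_size), each 4 tokens long (seq_len),
# stored one sentence per row as token ids from a 10-word vocabulary.
token_ids = torch.tensor([
    [1, 4, 2, 7],
    [3, 3, 9, 0],
    [5, 8, 6, 2],
])  # [batch_size, seq_len] = [3, 4]

# Illustrative embedding layer: every token id becomes a 3-dim vector (input_dim = 3).
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)
embedded = embedding(token_ids)  # [batch_size, seq_len, input_dim] = [3, 4, 3]

# Swap the first two axes to get the default RNN/LSTM/GRU input layout.
x = embedded.transpose(0, 1)  # [seq_len, batch_size, input_dim] = [4, 3, 3]
print(x.shape)  # torch.Size([4, 3, 3])

# x[t] now holds the t-th word of every sentence in the batch.
rnn = nn.RNN(input_size=3, hidden_size=5)
out, hn = rnn(x)
print(out.shape, hn.shape)  # torch.Size([4, 3, 5]) torch.Size([1, 3, 5])
```

Each time step of the recurrence consumes one slice x[t], which is why seq_len comes first in the default layout.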

RNN

import torch
import torch.nn as nn

seq_len, batch_size, input_dim, hidden_dim, num_layers = 10, 3, 50, 20, 1  # example sizes
rnn = nn.RNN(input_dim, hidden_dim, num_layers, bidirectional=False, batch_first=False)
x = torch.randn(seq_len, batch_size, input_dim)
out, ht = rnn(x)
# out: [seq_len, batch_size, num_directions * hidden_dim]
# ht: [num_layers * num_directions, batch_size, hidden_dim]

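For a unidirectional network, out[-1] equals ht[-1]: the final hidden state of the last layer is exactly the output at the last time step.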
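The equality between out[-1] and ht[-1] holds only under certain conditions. A minimal sketch, with sizes chosen purely for illustration: for a 2-layer unidirectional RNN, out[-1] matches ht[-1] (the last layer, not ht[0]); for a bidirectional RNN, out[-1] concatenates both directions, so only its forward half matches the forward final state, while the backward final state appears in out[0].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_dim, hidden_dim = 10, 3, 50, 20  # example sizes
x = torch.randn(seq_len, batch_size, input_dim)

# Unidirectional, 2 layers: out holds the LAST layer's hidden state at every step.
rnn = nn.RNN(input_dim, hidden_dim, num_layers=2)
out, ht = rnn(x)
print(torch.allclose(out[-1], ht[-1]))  # True

# Bidirectional, 1 layer: out is [seq_len, batch_size, 2 * hidden_dim] and
# ht[0] / ht[1] are the forward / backward final states.
birnn = nn.RNN(input_dim, hidden_dim, bidirectional=True)
out_b, ht_b = birnn(x)
# Forward direction ends at the last time step.
print(torch.allclose(out_b[-1, :, :hidden_dim], ht_b[0]))  # True
# Backward direction reads the sequence in reverse, so it ends at time step 0.
print(torch.allclose(out_b[0, :, hidden_dim:], ht_b[1]))   # True
```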

LSTM

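import torch
import torch.nn as nn

# Input sequence length seq_len = 10, batch size = 3, input dimension = 50
seq_len, batch_size, input_dim = 10, 3, 50
hidden_dim, num_layers = 20, 1  # example sizes
lstm = nn.LSTM(input_dim, hidden_dim, num_layers, bidirectional=False, batch_first=False)
x = torch.randn(seq_len, batch_size, input_dim)
out, (hn, cn) = lstm(x)  # uses the default all-zero initial hidden and cell states
# out: [seq_len, batch_size, num_directions * hidden_dim]
# hn, cn: [num_layers * num_directions, batch_size, hidden_dim]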

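For a unidirectional LSTM, out[-1, :, :] equals hn[-1, :, :]: the last layer's final hidden state is the output at the last time step. The cell state cn does not appear in out.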
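The comment in the LSTM call says that PyTorch uses all-zero initial states when none are passed. A short sketch to check this, assuming example sizes (num_layers=2 is chosen here for illustration): explicit zero tensors h0 and c0, each shaped [num_layers * num_directions, batch_size, hidden_dim], give the same result as the default call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_dim = 10, 3, 50  # sizes from the LSTM comment
hidden_dim, num_layers = 20, 2              # example values (assumption)

lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
x = torch.randn(seq_len, batch_size, input_dim)

# Default call: no initial states passed.
out, (hn, cn) = lstm(x)

# Explicit all-zero initial states with shape [num_layers * num_directions, batch_size, hidden_dim].
h0 = torch.zeros(num_layers, batch_size, hidden_dim)
c0 = torch.zeros(num_layers, batch_size, hidden_dim)
out2, (hn2, cn2) = lstm(x, (h0, c0))

print(torch.allclose(out, out2), torch.allclose(hn, hn2), torch.allclose(cn, cn2))
# True True True
print(torch.allclose(out[-1], hn[-1]))  # True: out carries only h of the last layer
print(hn.shape, cn.shape)  # torch.Size([2, 3, 20]) torch.Size([2, 3, 20])
```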

GRU

GRU is used almost exactly like RNN: it returns out and hn, with no separate cell state.

import torch
import torch.nn as nn

seq_len, batch_size, input_dim, hidden_dim, num_layers = 10, 3, 50, 20, 1  # example sizes
gru = nn.GRU(input_dim, hidden_dim, num_layers, bidirectional=False, batch_first=False)
x = torch.randn(seq_len, batch_size, input_dim)
out, hn = gru(x)
# out: [seq_len, batch_size, num_directions * hidden_dim]
# hn: [num_layers * num_directions, batch_size, hidden_dim]
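Because hn is the hidden state after the last time step, it can be passed back in as the initial state of a later call. This sketch, with assumed example sizes, runs a GRU over a sequence in two chunks, feeding the first chunk's hn into the second call, and compares the result with a single call over the whole sequence. The small tolerance allows for floating-point differences between the two computation paths.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_dim, hidden_dim = 6, 10, 200, 20  # example sizes
gru = nn.GRU(input_dim, hidden_dim)
x = torch.randn(seq_len, batch_size, input_dim)

# Whole sequence in one call (initial state defaults to zeros).
out_full, hn_full = gru(x)

# Same sequence in two chunks: the hn from the first call becomes h0 of the second.
out_a, hn_a = gru(x[:3])
out_b, hn_b = gru(x[3:], hn_a)

print(torch.allclose(torch.cat([out_a, out_b]), out_full, atol=1e-5))  # True
print(torch.allclose(hn_b, hn_full, atol=1e-5))  # True
```

This is the usual way to process a long sequence piece by piece while keeping the recurrent state.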

Code example

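For all three networks, the output out has shape [6, 10, 20], i.e. [seq_len, batch_size, hidden_dim] (num_directions * hidden_dim, with num_directions = 1).
hn has shape [1, 10, 20], i.e. [num_layers * num_directions, batch_size, hidden_dim]; here num_layers and num_directions are both 1. The constructors accept further parameters, such as bidirectional and batch_first, which this post does not go into in detail; see the official PyTorch documentation.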
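As a brief sketch of the batch_first parameter, using the sizes from this example (seq_len=6, batch_size=10, input_dim=200, hidden_dim=20): with batch_first=True the input and out put batch_size first, but hn and cn keep the [num_layers * num_directions, batch_size, hidden_dim] layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, input_dim, hidden_dim = 10, 6, 200, 20

# batch_first=True: input and output are [batch_size, seq_len, ...].
lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
x = torch.randn(batch_size, seq_len, input_dim)  # [batch_size, seq_len, input_dim]
out, (hn, cn) = lstm(x)
print(out.shape)  # torch.Size([10, 6, 20]) -> [batch_size, seq_len, hidden_dim]

# The hidden and cell states are NOT affected by batch_first.
print(hn.shape, cn.shape)  # torch.Size([1, 10, 20]) torch.Size([1, 10, 20])

# The last time step is now indexed on dim 1 instead of dim 0.
print(torch.allclose(out[:, -1, :], hn[-1]))  # True
```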

import torch
import torch.nn as nn

# Input data
x = torch.randn(6, 10, 200)  # [seq_len, batch_size, input_dim]

# Define the three networks
rnn = nn.RNN(200, 20, 1)    # [input_dim, hidden_size, num_layers]
lstm = nn.LSTM(200, 20, 1)  # [input_dim, hidden_size, num_layers]
gru = nn.GRU(200, 20, 1)    # [input_dim, hidden_size, num_layers]

# rnn
out, hn = rnn(x)
print(out.shape, hn.shape)
# torch.Size([6, 10, 20]) torch.Size([1, 10, 20])

# lstm
out, (hn, cn) = lstm(x)
print(out.shape, hn.shape, cn.shape)
# torch.Size([6, 10, 20]) torch.Size([1, 10, 20]) torch.Size([1, 10, 20])

# gru
out, hn = gru(x)
print(out.shape, hn.shape)
# torch.Size([6, 10, 20]) torch.Size([1, 10, 20])
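To check the shape rules for other configurations, here is a sketch with a hypothetical helper `expected_shapes` (written for this illustration, not part of PyTorch) that computes the shapes of out and hn from the formulas in the comments. It is compared against all three modules with num_layers=2 and bidirectional=True, which are example values.

```python
import torch
import torch.nn as nn

def expected_shapes(seq_len, batch_size, hidden_dim, num_layers=1, bidirectional=False):
    """Hypothetical helper: (out, hn) shapes for batch_first=False."""
    num_directions = 2 if bidirectional else 1
    out_shape = (seq_len, batch_size, num_directions * hidden_dim)
    hn_shape = (num_layers * num_directions, batch_size, hidden_dim)
    return out_shape, hn_shape

x = torch.randn(6, 10, 200)  # [seq_len, batch_size, input_dim]
results = {}
for name, cls in [("rnn", nn.RNN), ("lstm", nn.LSTM), ("gru", nn.GRU)]:
    net = cls(200, 20, num_layers=2, bidirectional=True)
    out, h = net(x)
    hn = h[0] if name == "lstm" else h  # LSTM returns (hn, cn); keep hn
    results[name] = (tuple(out.shape), tuple(hn.shape))
    print(name, results[name])

print(expected_shapes(6, 10, 20, num_layers=2, bidirectional=True))
# ((6, 10, 40), (4, 10, 20))
```

All three networks share the same shape rules; the only structural difference is that LSTM returns the extra cell state cn alongside hn.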


Reposted from blog.csdn.net/lyj223061/article/details/108213385