版权声明:微信公众号:数据挖掘与机器学习进阶之路。本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013230189/article/details/82804316
tf.nn.static_rnn
Aliases:
- tf.contrib.rnn.static_rnn
- tf.nn.static_rnn
使用指定的RNN神经元创建循环神经网络
tf.nn.static_rnn(
cell,
inputs,
initial_state=None,
dtype=None,
sequence_length=None,
scope=None
)
参数说明:
- cell:用于神经网络的RNN神经元,如BasicRNNCell,BasicLSTMCell
- inputs:一个长度为T的list,list中的每个元素为一个Tensor,Tensor形如:[batch_size,input_size]
- initial_state:RNN的初始状态,如果cell.state_size是一个整数,则它必须是适当类型和形如[batch_size,cell.state_size]的张量。如cell.state_size是一个元组,那么它应该是一个张量元组,对于cell.state_size中的s,应该是具有形如[batch_size,s]的张量的元组。
- dtype:初始状态和预期输出的数据类型。可选参数。
- sequence_length:指定每个输入的序列的长度。大小为batch_size的向量。
- scope:变量范围
返回值:
一个(outputs,state)对
outputs:一个长度为T的list,list中的每个元素是每个输入对应的输出。例如一个时间步对应一个输出。
state:最终的状态
代码实例:
import tensorflow as tf
x=tf.Variable(tf.random_normal([2,4,3])) #[batch_size,timesteps,embedding_dim]
x=tf.unstack(x,axis=1) #按时间步展开
n_neurons = 5 #输出神经元数量
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell,x, dtype=tf.float32)
print(len(output_seqs)) #四个时间步
print(output_seqs[0]) #每个时间步输出一个张量
print(output_seqs[1]) #每个时间步输出一个张量
print(states) #隐藏状态
输出如下:
4 Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32) Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32) Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32) |