tensorflow学习之static_rnn使用详解

版权声明:微信公众号:数据挖掘与机器学习进阶之路。本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013230189/article/details/82804316

tf.nn.static_rnn

Aliases:

  1. tf.contrib.rnn.static_rnn
  2. tf.nn.static_rnn

使用指定的RNN神经元创建循环神经网络

tf.nn.static_rnn(

    cell,

    inputs,

    initial_state=None,

    dtype=None,

    sequence_length=None,

    scope=None

)

参数说明:

  • cell:用于神经网络的RNN神经元,如BasicRNNCell,BasicLSTMCell
  • inputs:一个长度为T的list,list中的每个元素为一个Tensor,Tensor形如:[batch_size,input_size]
  • initial_state:RNN的初始状态,如果cell.state_size是一个整数,则它必须是适当类型和形如[batch_size,cell.state_size]的张量。如cell.state_size是一个元组,那么它应该是一个张量元组,对于cell.state_size中的s,应该是具有形如[batch_size,s]的张量的元组。
  • dtype:初始状态和预期输出的数据类型。可选参数。
  • sequence_length:指定每个输入的序列的长度。大小为batch_size的向量。
  • scope:变量范围

返回值:

一个(outputs,state)对

outputs:一个长度为T的list,list中的每个元素是每个输入对应的输出。例如一个时间步对应一个输出。

state:最终的状态

代码实例:

import tensorflow as tf



x=tf.Variable(tf.random_normal([2,4,3])) #[batch_size,timesteps,embedding_dim]

x=tf.unstack(x,axis=1) #按时间步展开

n_neurons = 5 #输出神经元数量



basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)

output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell,x, dtype=tf.float32)



print(len(output_seqs)) #四个时间步

print(output_seqs[0]) #每个时间步输出一个张量

print(output_seqs[1]) #每个时间步输出一个张量

print(states) #隐藏状态

输出如下:

4

Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32)

Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32)

Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32)

 

猜你喜欢

转载自blog.csdn.net/u013230189/article/details/82804316