Using tf.nn.dynamic_rnn in TensorFlow: understanding outputs and state

tf.nn.dynamic_rnn

Let's illustrate TensorFlow's dynamic_rnn with a small example. Suppose the input to your RNN has shape [2, 3, 4], where 2 is the batch_size, 3 is the maximum text length (usually called num_steps or seq_length), and 4 is the embedding_size. Assume the second text has length only 2, and the remaining step is filled in by zero-padding. dynamic_rnn returns two values, outputs and last_states. outputs has shape [batch_size, num_steps, hidden_size], i.e. [2, 3, 2] here since we use a hidden size of 2 below, and holds the hidden state output at every time step of the run. last_states is a tuple (c, h) whose elements each have shape [batch_size, hidden_size], i.e. [2, 2].
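As a minimal sketch (TF 1.x API; cell and inputs stand in for your own cell and input tensor), the call and the shapes it returns look like this:

outputs, last_states = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float64)
# outputs: [batch_size, num_steps, hidden_size], one hidden state per time step
# last_states: for an LSTM cell, an LSTMStateTuple (c, h), each [batch_size, hidden_size]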

So far nothing unusual, but dynamic_rnn takes an extra parameter, sequence_length, which specifies the valid length of each example. In the example above we set sequence_length to [3, 2], meaning the first example has a valid length of 3 and the second a valid length of 2. With this parameter, TensorFlow does not compute the second example past step 2: its entry in last_states is simply carried forward from step 2 to step 3, and its entries in outputs beyond step 2 are zeroed.

Let's look at this concretely in code and output:


import tensorflow as tf
import numpy as np

# Generate a batch of 2 examples: max sequence length 3, embedding size 4
X = np.random.randn(2, 3, 4)

# The second example has length 2; zero out its padded final step
X[1, 2:] = 0
X_lengths = [3, 2]
print(X)
# cell = tf.contrib.rnn.BasicLSTMCell(num_units=64, state_is_tuple=True)
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=2, state_is_tuple=True)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    o, s = sess.run([outputs, last_states])
    print('output\n', o)
    print('last_o\n', o[:, -1, :])  # last time step of outputs; zero for padded examples

    print('--------------------')
    print('s\n', s)
    print('s.c\n', s.c)  # s.c is the LSTM cell state; we rarely need to print it
    print('s.h\n', s.h)  # s.h is the final hidden state, i.e. the last valid output
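As a quick sanity check (this check is not in the original post), we can verify with NumPy, after the session block, that s.h holds each example's last valid output:

# For each example i, s.h[i] should equal o[i, length - 1, :]
for i, length in enumerate(X_lengths):
    print(np.allclose(s.h[i], o[i, length - 1, :]))  # expect: True, True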

Running the code, we can observe the following (exact values vary between runs, since both X and the LSTM weights are random):

First, in the printed X, the last row of the second example is all zeros, confirming that its valid length is 2 and the trailing zeros are padding.

The LSTM then produces the two results, outputs and state.

In outputs, the final time step of the second example is all zeros, because no computation happens past step 2. But s.h in the state is not zero: for each example it equals the last non-zero time step in outputs, and comparing the corresponding numbers confirms they are identical. So if you use dynamic_rnn and only need the final output of each (variable-length) example, just read s.h directly.
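If you would rather gather the last valid output from outputs inside the graph instead of reading s.h, a minimal sketch with tf.gather_nd (reusing X_lengths from the example above) could look like this:

# Hypothetical alternative: index each example's last valid time step;
# the result should equal last_states.h.
lengths = tf.constant(X_lengths, dtype=tf.int32)
batch_range = tf.range(tf.shape(outputs)[0])            # [0, 1]
indices = tf.stack([batch_range, lengths - 1], axis=1)  # [[0, 2], [1, 1]]
last_valid = tf.gather_nd(outputs, indices)             # [batch_size, hidden_size]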





