TensorFlow笔记:dynamic_rnn

通常的,使用RNN的时候,我们需要指定num_step,也就是TensorFlow的roll step步数,但是对于变长的文本来说,指定num_step就不可避免的需要进行padding操作,TensorFlow使用了dynamic_padding方法实现自动padding,但是这还不够,因为在跑一遍RNN/LSTM之后,还是需要对padding部分的内容进行删除,我称之为“反padding”,无可避免的,我们就需要指定mask矩阵。很麻烦!而使用dynamic_rnn可以跳过padding部分的计算,减少计算量。

outputs, last_states = tf.nn.dynamic_rnn( cell=cell, dtype=tf.float32, sequence_length=x_lengths, inputs=x),其中cell是RNN节点,比如tf.contrib.rnn.BasicLSTMCel,x_lengths是每个文本的长度,x是0-padding以后的数据。

dynamic_rnn函数有两个输出,outputs, last_states。由于好奇这两个输出到底保存了什么信息,做了如下实验。

假设RNN的输入:
1)batch_size=2,一个batch里面有2个句子。
2)num_step=2,即最大的句子长度为2。
3)embedding_size=1,词向量长度为1。
4) rnn_size =64,神经元个数为64个。
4)batch中的两个句子,一个长度为2,一个长度为1。
如图所示:
这里写图片描述

dynamic_rnn的输出:
1)outputs的shape是{ ([64个元素],[64个元素]), ([64个元素],[64个元素]) },即shape=(2,2,64)。第一个([64个元素],[64个元素])是example1得到的y1,y2。第二个([64个元素],[64个元素])是example2得到的y1,y2,由于example2长度为1,所以第二个[64个元素]全为0。
2)last_states是由(c,h)组成的tuple,shape是( [64个元素],[64个元素] ),即shape=(2,64),直接输出两个样本更新后的结果,而不是每个样本结果都输出。与outputs不同,outputs是输出每个样本的y值,因为要与label求loss。example2的last_states将2步的输出重复第1步的输出。

可见,对于example2,TensorFlow对于1以后的padding就不计算了,其last_states将重复第1步的last_states至第2步,而outputs中超过1步的结果将会被置零,节省了不少的计算开销。

此外,为了计算loss,需要将outputs变成batch_size行,每行是一个样本所有y值的拼接,即(y1y2)/(64个元素64个元素),代码为output = tf.reshape(outputs, [-1, batch_size])。

实验代码如下:

import tensorflow as tf
import numpy as np
# 创建输入数据
# batch_size=2,time_step=2,embedding_size=1,rnn_size=64
X = np.random.randn(2, 2, 1)

# 第二个example长度为1
X[1,1:] = 0
X_lengths = [2, 1]

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=64)

outputs, last_states = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float64, sequence_length=X_lengths, inputs=X)
output = tf.reshape(outputs, [-1, 2])

result = tf.contrib.learn.run_n({"outputs": outputs, "last_states": last_states}, n=1, feed_dict=None)

print(result[0])

assert result[0]["outputs"].shape == (2, 2, 64)

# 第二个example中的outputs超过1步(即第2步)的值应该为0
assert (result[0]["outputs"][1,1,:] == np.zeros(cell.output_size)).all()

结果如下:

‘outputs’:
[
[
[ -7.62082799e-02, -3.36184830e-02,
-6.67680635e-02,
5.45699044e-02, -1.87592767e-02, -6.99167346e-02,
4.65241463e-02, -6.59350614e-02, -7.56088775e-02,
3.15361531e-02, -5.38594563e-02, -4.96236942e-02,
6.34708473e-02, -5.73262117e-02, -6.15531976e-02,
6.96234575e-02, 8.11132388e-02, -1.33483670e-02,
-4.48100227e-02, 6.37322535e-02, -4.52916426e-02,
-3.30371851e-02, 6.74386787e-02, 4.08092220e-02,
-6.68296276e-02, 3.54094145e-02, 5.90271594e-02,
3.82535754e-04, 5.08161802e-02, 2.08302154e-02,
6.87943488e-02, 5.25004456e-02, 7.89765067e-02,
-3.41330182e-02, -5.55523254e-02, 1.29398858e-02,
-2.17985403e-02, 6.10304037e-02, -3.59220962e-02,
6.92200297e-02, 5.97105693e-02, -5.20775731e-02,
-5.51627134e-02, 5.55891104e-02, 4.48433581e-02,
-7.62232848e-02, -4.91314930e-03, -1.78160669e-03,
3.69225463e-02, 1.12623152e-02, -7.56351335e-02,
-1.74817723e-02, 5.41501901e-02, 2.11139104e-02,
-8.06466491e-02, 5.20345667e-02, 1.66710886e-02,
-1.07074857e-02, -7.79945259e-02, 3.75235269e-04,
6.76436688e-02, -6.92465729e-02, 5.64057434e-02,
-2.89222157e-02],
[ 8.11550722e-02, 7.43116863e-02, 1.58173168e-01,
-1.40629439e-03, 4.36200221e-02, 1.07655607e-01,
-5.83007193e-02, 1.18635527e-01, 1.69504168e-01,
1.68610484e-02, 1.23847933e-01, 1.33690329e-01,
-1.87358599e-01, 1.63267544e-01, 1.19774295e-01,
-1.15172968e-01, -1.94442963e-01, -9.24493738e-03,
1.80302586e-02, -5.83937865e-02, 1.12031539e-01,
-2.52955876e-02, -6.67900834e-02, -9.48188536e-02,
9.15420487e-02, -1.24323235e-02, -1.33816180e-02,
-2.60406333e-02, -7.13527744e-02, -5.56514676e-02,
-2.75451882e-02, -1.12596856e-01, -1.12186973e-01,
1.02415459e-02, 6.18014176e-02, -2.71080141e-02,
3.78162038e-02, -1.76214842e-01, 8.75928814e-02,
-9.88144928e-02, -1.98723068e-01, 6.77737318e-02,
5.08736006e-02, -1.10722762e-01, -1.35978996e-01,
1.04906836e-01, 6.80351005e-02, 3.80700550e-02,
-1.24677339e-01, 2.79494677e-02, 7.53350869e-02,
4.04438897e-03, -8.88436089e-02, -5.09099258e-02,
1.79098775e-01, -5.71230146e-02, -2.34761260e-02,
6.44917070e-02, 2.55942849e-02, 4.01924562e-02,
-1.15922286e-01, 9.25515881e-02, -1.08269939e-01,
6.30401597e-02]
],
[
[ -7.95361546e-03, -3.50323075e-03, -6.96525594e-03,
5.68995549e-03, -1.95431787e-03, -7.29476511e-03,
4.84972775e-03, -6.87810536e-03, -7.89081980e-03,
3.28609339e-03, -5.61573469e-03, -5.17333733e-03,
6.62034891e-03, -5.97796273e-03, -6.41981856e-03,
7.26406760e-03, 8.46769610e-03, -1.39053642e-03,
-4.67080409e-03, 6.64768811e-03, -4.72107375e-03,
-3.44261228e-03, 7.03542422e-03, 4.25329666e-03,
-6.97169730e-03, 3.69000621e-03, 6.15573997e-03,
3.98474715e-05, 5.29786528e-03, 2.17012434e-03,
7.17729097e-03, 5.47377395e-03, 8.24370014e-03,
-3.55688907e-03, -5.79259882e-03, 1.34797903e-03,
-2.27103686e-03, 6.36515775e-03, -3.74347783e-03,
7.22184259e-03, 6.22717756e-03, -5.42960529e-03,
-5.75189085e-03, 5.79644233e-03, 4.67428345e-03,
-7.95518750e-03, -5.11790397e-04, -1.85584202e-04,
3.84782832e-03, 1.17320675e-03, -7.89357036e-03,
-1.82120126e-03, 5.64610709e-03, 2.19968876e-03,
-8.41877646e-03, 5.42511341e-03, 1.73673068e-03,
-1.11540512e-03, -8.14078337e-03, 3.90870043e-05,
7.05687415e-03, -7.22462067e-03, 5.88177293e-03,
-3.01356178e-03],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00]
]
]
shape=(2,2,64)

‘last_states’:
[
[ 8.11550722e-02, 7.43116863e-02, 1.58173168e-01,
-1.40629439e-03, 4.36200221e-02, 1.07655607e-01,
-5.83007193e-02, 1.18635527e-01, 1.69504168e-01,
1.68610484e-02, 1.23847933e-01, 1.33690329e-01,
-1.87358599e-01, 1.63267544e-01, 1.19774295e-01,
-1.15172968e-01, -1.94442963e-01, -9.24493738e-03,
1.80302586e-02, -5.83937865e-02, 1.12031539e-01,
-2.52955876e-02, -6.67900834e-02, -9.48188536e-02,
9.15420487e-02, -1.24323235e-02, -1.33816180e-02,
-2.60406333e-02, -7.13527744e-02, -5.56514676e-02,
-2.75451882e-02, -1.12596856e-01, -1.12186973e-01,
1.02415459e-02, 6.18014176e-02, -2.71080141e-02,
3.78162038e-02, -1.76214842e-01, 8.75928814e-02,
-9.88144928e-02, -1.98723068e-01, 6.77737318e-02,
5.08736006e-02, -1.10722762e-01, -1.35978996e-01,
1.04906836e-01, 6.80351005e-02, 3.80700550e-02,
-1.24677339e-01, 2.79494677e-02, 7.53350869e-02,
4.04438897e-03, -8.88436089e-02, -5.09099258e-02,
1.79098775e-01, -5.71230146e-02, -2.34761260e-02,
6.44917070e-02, 2.55942849e-02, 4.01924562e-02,
-1.15922286e-01, 9.25515881e-02, -1.08269939e-01,
6.30401597e-02],
[ -7.95361546e-03, -3.50323075e-03, -6.96525594e-03,
5.68995549e-03, -1.95431787e-03, -7.29476511e-03,
4.84972775e-03, -6.87810536e-03, -7.89081980e-03,
3.28609339e-03, -5.61573469e-03, -5.17333733e-03,
6.62034891e-03, -5.97796273e-03, -6.41981856e-03,
7.26406760e-03, 8.46769610e-03, -1.39053642e-03,
-4.67080409e-03, 6.64768811e-03, -4.72107375e-03,
-3.44261228e-03, 7.03542422e-03, 4.25329666e-03,
-6.97169730e-03, 3.69000621e-03, 6.15573997e-03,
3.98474715e-05, 5.29786528e-03, 2.17012434e-03,
7.17729097e-03, 5.47377395e-03, 8.24370014e-03,
-3.55688907e-03, -5.79259882e-03, 1.34797903e-03,
-2.27103686e-03, 6.36515775e-03, -3.74347783e-03,
7.22184259e-03, 6.22717756e-03, -5.42960529e-03,
-5.75189085e-03, 5.79644233e-03, 4.67428345e-03,
-7.95518750e-03, -5.11790397e-04, -1.85584202e-04,
3.84782832e-03, 1.17320675e-03, -7.89357036e-03,
-1.82120126e-03, 5.64610709e-03, 2.19968876e-03,
-8.41877646e-03, 5.42511341e-03, 1.73673068e-03,
-1.11540512e-03, -8.14078337e-03, 3.90870043e-05,
7.05687415e-03, -7.22462067e-03, 5.88177293e-03,
-3.01356178e-03]
shape=(2,64)

猜你喜欢

转载自blog.csdn.net/qq_23142123/article/details/78486303