TensorFlow: output and state in RNN

Original article:

https://xdrush.github.io/2018/02/12/RNN%E5%8E%9F%E7%90%86%E8%AF%A6%E8%A7%A3%E4%BB%A5%E5%8F%8Atensorflow%E4%B8%AD%E7%9A%84RNN%E5%AE%9E%E7%8E%B0/

About output and state

Personally, I think one of the hardest things to understand about RNNs is output versus state: output is what the network emits, and state is its internal condition. In TensorFlow, dynamic_rnn, static_rnn, bidirectional_dynamic_rnn, and static_bidirectional_rnn all return an (outputs, last_states) tuple. Note that last_states is the final state, whereas outputs holds the output at every time step. When working on RNN-related tasks in TensorFlow, you cannot make progress until this distinction is clear.
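As a quick orientation, here is a minimal TF 1.x sketch (not from the original post; the variable names are illustrative) showing the shapes of the two return values:

import numpy as np
import tensorflow as tf

# outputs holds one vector per time step; last_state holds only the final state.
batch_size, max_time, input_dim, num_units = 2, 5, 4, 10
inputs = tf.constant(np.random.randn(batch_size, max_time, input_dim))
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=num_units)
outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float64)
print(outputs.get_shape())     # (2, 5, 10): [batch, time, num_units]
print(last_state.get_shape())  # (2, 10): [batch, num_units]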

output and state mean different things in the RNN and its variants, and the values they represent differ as well. Let's look at what output and state mean in a few of the most basic RNN cells and their variants:

BasicRNNCell
The basic RNN structure is shown below:
[Figure: the basic RNN]
In the basic RNN structure, the output can be regarded as equal to the hidden-state value. Let's look at the outputs and last_states values produced by the following code:

import numpy as np
import tensorflow as tf

def dynamic_rnn_test():
    BATCH_SIZE = 2
    EMBEDDING_DIM = 4
    # Two samples, each a 5-step sequence of 4-dimensional inputs.
    X = np.random.randn(BATCH_SIZE, 5, EMBEDDING_DIM)
    X_lengths = [5, 5]
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=10)
    outputs, last_states = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float64,
                                             sequence_length=X_lengths, inputs=X)
    result = tf.contrib.learn.run_n({"outputs": outputs, "last_states": last_states},
                                    n=1, feed_dict=None)
    print(result[0]["outputs"])
    print("--------------------------------")
    print(result[0]["last_states"])

The above code prints:

## outputs
[[[-0.26285439 -0.73199998 0.67373167 -0.42019807 -0.76447828 -0.15671307 0.19419611 -0.06485997 -0.59310542 0.41760793]
[-0.51952513 0.61765864 0.54485767 0.35961272 0.09553398 0.68890209 -0.46678386 0.34405317 0.8904701 -0.04432281]
[ 0.96647506 -0.50980204 0.55754585 0.93328233 0.57254379 0.6663917 -0.40768854 0.86358991 -0.58068622 -0.72018298]
[ 0.3345003 -0.09220678 0.69535521 -0.01648253 -0.21293752 -0.12114425 0.14904557 0.59020341 0.3342177 0.25945014]
[ 0.05128395 0.86625483 0.28549682 0.76454802 0.44757274 0.691485 0.00960586 0.23504622 0.75175537 -0.33478982]]
[[-0.27505881 0.78801392 -0.92769186 0.38675853 0.31331528 -0.79453833 0.77526593 -0.34045865 0.52494778 0.08722081]
[ 0.21659185 -0.05254756 -0.46941906 -0.70990551 0.82241305 0.7653751 -0.75469825 -0.65669409 -0.68308972 -0.54132448]
[ 0.6928769 -0.80066683 0.02133818 -0.66396161 -0.48229484 -0.80333658 0.66119584 0.79458079 -0.73295564 -0.65123496]
[-0.84663 0.26150571 -0.35573722 -0.88728337 -0.70946976 -0.59880986 0.95380342 0.63640031 0.14041671 -0.74008235]
[-0.49611388 -0.6615701 -0.91717102 -0.7921021 0.19823286 -0.52368639 0.73433595 -0.42381531 -0.22037713 -0.6572696 ]]]
--------------------------------
## last_states
[[ 0.05128395 0.86625483 0.28549682 0.76454802 0.44757274 0.691485 0.00960586 0.23504622 0.75175537 -0.33478982]
[-0.49611388 -0.6615701 -0.91717102 -0.7921021 0.19823286 -0.52368639 0.73433595 -0.42381531 -0.22037713 -0.6572696 ]]

Compare outputs[0][4] (the first sample's output at the last time step) with last_states[0] (the first sample's final state), and outputs[1][4] (the second sample's output at the last time step) with last_states[1] (the second sample's final state): they are equal! This confirms the statement above.
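The same check can be done programmatically; here is a small sketch reusing the result dict from the code above:

# For BasicRNNCell, the output at the last time step equals the final state.
out = result[0]["outputs"]       # shape (2, 5, 10): all time steps
last = result[0]["last_states"]  # shape (2, 10): final state only
print(np.allclose(out[:, -1, :], last))  # True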

BasicLSTMCell
The LSTM differs somewhat from the basic RNN (see section 1.3): it introduces gates (input, forget, and output) and an extra cell state, so the LSTM's output differs from BasicRNNCell's. Let's look at the basic usage of BasicLSTMCell through an example:

BATCH_SIZE = 2
EMBEDDING_DIM = 4
X = np.random.randn(BATCH_SIZE, 5, EMBEDDING_DIM)
X_lengths = [5, 5]
## Use an LSTM cell this time
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=10)
outputs, last_states = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float64,
                                         sequence_length=X_lengths, inputs=X)
result = tf.contrib.learn.run_n({"outputs": outputs, "last_states": last_states},
                                n=1, feed_dict=None)
print(result[0]["outputs"])
print("--------------------------------")
print(result[0]["last_states"])

Running the above produces:

[[[-0.02559858 0.06167649 0.1645186 0.05837131 0.01109476 -0.02938359 0.11438462 -0.0655154 0.05643917 -0.17126968]
[ 0.03442813 0.01858283 0.19324815 -0.08899226 0.03568535 -0.04120232 0.1620414 -0.12981834 0.13737953 -0.10649411]
[ 0.021796 -0.0292876 0.04972559 -0.04365079 0.06611464 -0.0123974 0.04471634 -0.09371935 0.07161399 -0.00129043]
[-0.06856613 0.08481594 0.08859627 -0.07172004 -0.02254162 0.04920269 0.06426967 -0.07178349 0.06880909 -0.03122769]
[-0.05681122 0.1265717 0.08145183 -0.10992898 -0.04531312 0.08419307 0.05815578 -0.03600487 0.06829341 -0.00815202]]
[[-0.03048013 -0.05028687 0.04530328 -0.01116215 -0.00322128 -0.0376331 0.05989264 -0.1386925 -0.02739475 -0.0416665 ]
[-0.07246373 0.00922893 -0.02089626 0.12696067 0.05484725 -0.05276134 0.02418303 -0.0003094 -0.04619291 -0.02940275]
[-0.06912543 0.06466857 0.22031627 -0.07334317 -0.03599558 0.01374829 0.12909539 -0.1685715 0.05465224 -0.19901284]
[-0.0769867 0.05043309 0.08731908 0.00185187 0.00557504 0.007338 0.0641817 -0.0849491 0.0245508 -0.07668919]
[-0.01582939 0.00979516 -0.02073626 0.09953952 0.10595823 -0.0135512 -0.12155518 0.04029387 0.00712342 0.02277357]]]
--------------------------------
LSTMStateTuple(
c=array([[-0.13211159, 0.26529373, 0.18125151, -0.19673843, -0.10883727, 0.16908338, 0.10463188, -0.08444297, 0.17317917, -0.01578971], [-0.03322975, 0.02126845, -0.04260041, 0.19423348, 0.22194511, -0.03170695, -0.19370151, 0.10526997, 0.0245572 , 0.05014028]]),
h=array([[-0.05681122, 0.1265717 , 0.08145183, -0.10992898, -0.04531312, 0.08419307, 0.05815578, -0.03600487, 0.06829341, -0.00815202], [-0.01582939, 0.00979516, -0.02073626, 0.09953952, 0.10595823, -0.0135512 , -0.12155518, 0.04029387, 0.00712342, 0.02277357]]))

From the results above we can see that, just as with BasicRNNCell, the outputs returned by BasicLSTMCell holds the output at every time step (which here is again the hidden-state value at each step; a more typical setup takes outputs and then passes it through a fully connected layer and a softmax layer for a classification task). What differs is last_states: BasicLSTMCell returns an LSTMStateTuple, a tuple containing two elements, c and h, where c is the cell's internal state at the final time step and h is the hidden-state value at the final time step.
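As a hedged sketch (reusing last_states from the code above; NUM_CLASSES is a hypothetical parameter, not from the original post), the tuple can be unpacked and the final hidden state fed into a classifier like this:

# Unpack the LSTMStateTuple: c is the cell state, h the hidden state,
# both taken at the final time step, shape (BATCH_SIZE, num_units).
c, h = last_states.c, last_states.h
# h coincides with the last-step output: outputs[:, -1, :] == h.
# A common follow-up: a fully connected layer + softmax for classification.
NUM_CLASSES = 3  # hypothetical number of classes
logits = tf.layers.dense(h, NUM_CLASSES)  # shape (BATCH_SIZE, NUM_CLASSES)
probs = tf.nn.softmax(logits)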

GRUCell
From the GRU principles in section 1.4, the GRU's outputs behaves the same as LSTM's and BasicRNNCell's, while last_states behaves like BasicRNNCell's: it contains only the hidden-state value at the final time step. Again, an example:

BATCH_SIZE = 2
EMBEDDING_DIM = 4
X = np.random.randn(BATCH_SIZE, 5, EMBEDDING_DIM)
X_lengths = [5, 5]
## Use a GRU cell this time
cell = tf.nn.rnn_cell.GRUCell(num_units=10)
outputs, last_states = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float64,
                                         sequence_length=X_lengths, inputs=X)
result = tf.contrib.learn.run_n({"outputs": outputs, "last_states": last_states},
                                n=1, feed_dict=None)
print(result[0]["outputs"])
print("--------------------------------")
print(result[0]["last_states"])

The output is:

[[[ 0.2818741 0.03127117 -0.10587379 0.04028049 0.10053002 0.15848186 -0.18849411 0.27622443 0.38123248 -0.13761087]
[ 0.20203697 0.27380701 0.20594786 0.32964536 -0.03476539 0.0324929 -0.17276558 0.23946512 0.25474486 -0.03569277]
[ 0.09995877 0.03133022 0.03788231 0.33481101 0.05394468 0.17044128 -0.22957891 0.07784969 0.12172921 -0.11151596]
[-0.01079724 0.34425545 0.36282874 0.51701521 -0.13545613 0.20845521 -0.16279659 0.08200397 -0.07883915 -0.0671937 ]
[ 0.04381321 0.10883886 0.37020907 0.42074759 0.14924879 0.07081199 -0.20527748 -0.0342331 -0.01571459 0.01904762]]
[[ 0.08714771 0.49216403 0.23638074 0.54007724 -0.12808233 0.05203507 0.04589614 0.20300933 0.00669649 -0.08931576]
[ 0.15230049 0.31014089 0.25244098 0.44602376 -0.04282711 0.13599053 -0.01098503 0.14189271 0.04150135 -0.06910757]
[ 0.3996701 0.10472691 0.21537184 0.39543418 0.22428281 0.07584328 -0.20120173 0.10623939 0.26915325 -0.09094824]
[ 0.38323232 0.09812629 0.04226342 0.37831236 0.27365562 0.20740802 -0.24894298 0.1094313 0.2308372 -0.12473171]
[ 0.27563199 0.01112365 0.06366856 0.41799209 0.45473254 0.27676832 -0.34215252 0.0085023 0.23020847 -0.23767658]]]
--------------------------------
[[ 0.04381321 0.10883886 0.37020907 0.42074759 0.14924879 0.07081199 -0.20527748 -0.0342331 -0.01571459 0.01904762]
[ 0.27563199 0.01112365 0.06366856 0.41799209 0.45473254 0.27676832 -0.34215252 0.0085023 0.23020847 -0.23767658]]
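As with BasicRNNCell, comparing the last row of each sample's outputs above with last_states shows they are identical: GRUCell's final state is simply the hidden state at the last valid time step. The same np.allclose check from the BasicRNNCell section (a sketch reusing the result dict) verifies it:

print(np.allclose(result[0]["outputs"][:, -1, :], result[0]["last_states"]))  # True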


Reposted from blog.csdn.net/qq_25987491/article/details/80778645