TensorFlow----实现 RNN 的基本单元: RNNCell

RNNCell是TensorFlow中发RNN基本单元。

本身是一个抽象类,拥有两个子类,一个是BasicRNNCell,另一个是BasicLSTMCell。

 (注:RNNCell:是抽象类不能进行实例化,可以使用它的子类 BasicRNNCell 或BasicLSTMCell 进行实例化,得到 cell )

RNNCell的三个要点

  1. 类方法call (实现单步循环)
  2. 类属性 state_size
  3. 类属性 output_size 

call方法,所有RNNCell的子类都会实现一个call函数。利用call函数可以实现RNN的单步计算。

对于一个已经实例化好的基本单元cell 调用形式为: 

 (output, next_ state) = cell.call(input, state) 

RNNCell的类属性state_size和output_size分别规定了隐层的大小和输出向量的大小。

通常是以batch形式输入数据,input的形状为(batch_size,input_size),

调用call函数时对应的隐层的形状是(batch_size,state_size),

输出的形状是(batch_size,output_size)。


在TensorFlow中定义一个基本的RNN单元的方法为:

import tensorflow as tf
rnn_cell= tf.nn.rnn_cell.BasicRNNCell(num units=128)
print(rnn_cell.state_size ) #打出 state_size 着一下,应有 state_size = 128

在TensorFlow中定义一个LSTM基本单元的方法为:

import tensorflow as t f
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num units=128)
print(lstm_cell.state_size)   #state_size = LSTMStateTuple(c=l28, h=128)

LSTM可以看做有h和C两个隐层。在TensorFlow中LSTM基本单元的state_size由两部分组成,一部分是c,另一部分是h。

具体使用时,可以通过state.h以及state.c进行访问,下面是一个示例代码:

import tensorflow as i::f
import numpy as np
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

#通过zero_state方法得到一个全0的初始化状态
hO = lstm_cell.zero_state(32, np.float32) 

#调用call方法实现单步计算
output, hl = lstm_cell.call(inputs, hO)

#查看h1的状态
print(hl.h) # shape=(32, 128)
print(hl.c) # shape=(32, 128)

堆叠RNN : MultiRNNCell

单层RNN能力有限,需要多层RNN。

将x输入到第一层RNN后得到隐层状态h,这个隐层状态相当于第二层RNN的输入,第二层RNN的隐层状态又相当于第三层RNN的输入,以此类推。三层RNN串联

在TensorFlow中,使用tf.nn.rnn_cell.MultiRNNCell函数对RNN进行堆叠,代码如下:

import tensorflow as tf
import numpy as np

#每次调用这个函数返回一个BasicRNNCell
def get_a_cell():
    return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

#用tf.nn.rnn _cell_MultiRNNCell创建三层RNN
cell= tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

#得到的cell实际也是RNNCell的子类
#它的state_size是(128,128,128)代表3个隐层状态,每个隐层状态是128
print(cell.state_size)  # (128, 128, 128)

#使用对应的call函数
inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch size
#通过zero_state方法得到一个全0的初始状态
hO = cell.zero_state(32, np.float32)

output, hl = cell.call(inputs, hO) 
print(hl) # tuple 中合有 3 个 32xl28 的向噩

猜你喜欢

转载自blog.csdn.net/DeepOscar/article/details/81121828
今日推荐