前言
本系列主要主要是记录下Tensorflow在RNN实现这一块的相关代码,不做详细解释,主要是翻译加笔记。
RNNCell
在Tensorflow中,定义了一个RNNCell的抽象类,具体的所有不同类型的RNN Cell都是基于这个类的,所以就首先讲一下这个,下面是基本的代码:
class RNNCell(object):
def __call__(self, inputs, state, scope=None):
raise NotImplementedError("Abstract method")
@property
def state_size(self):
raise NotImplementedError("Abstract method")
@property
def output_size(self):
raise NotImplementedError("Abstract method")
def zero_state(self, batch_size, dtype):
state_size = self.state_size
if nest.is_sequence(state_size):
state_size_flat = nest.flatten(state_size)
zeros_flat = [
array_ops.zeros(
array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
dtype=dtype)
for s in state_size_flat]
for s, z in zip(state_size_flat, zeros_flat):
z.set_shape(_state_size_with_prefix(s, prefix=[None]))
zeros = nest.pack_sequence_as(structure=state_size,
flat_sequence=zeros_flat)
else:
zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
return zeros
在Tensorflow中,Cell的定义不同于其他资料当中的定义,在其他的文档中Cell(下文指代为L-Cell)被看做是一个能够产生Single Scalar输出的对象,而在这里则是一个包含一系列L-Cell的水平数组。
具体到RNNCell,RNNCell是一个包含一个State(状态)并且能够执行一些处理输入矩阵的对象。RNNCell将输入的矩阵(Input Matrix),运算输出一个包含”self.output”列的输出矩阵(Ouput Matrix)。如果定义了“self.state_size”这个属性,并且取值为一个整数,那么RNNCell则会同时输出一个状态矩阵(State Matrix),包含“self.state_size”列。而如果“self.state_size”定义为一个整数的Tuple,,那么则是输出对应长度的状态矩阵的Tuple,Tuple中的每一个状态矩阵长度还是和之前的一样,包含“self.state_size”列。
在Tensorflow中,将会基于整个RNNCell实现一系列常用的RNNCell,比如LSTM和GRU,并且将会支持包含Dropout等在内的特性,同时也支持构建多层的RNN网络。
RNNCell基本结构
RNNCell有一些基本的属性需要设置:
state_size: 说明这个Cell使用的State的大小
output_size: 这个RNNCell最后生成结果的大小
对于每一个RNNCell的具体实现类,都必须要实现__call__这个方法:
每一个具体的RNN类必须实现的方法:
def __call__(self, inputs, state, scope=None):
这个方法是RNNCell的核心方法,其需要的属性有:
inputs: 这个需要输入一个形状为[batch_size,input_size]的2D Tensor,batch_size是你训练中指定的batch的大小,而input_size则是输入数据的维度
state: state就是你rnn网络中rnn cell的状态,比如说如果你的rnn定义包含了N个单元(也就是你的self.state_size是个整数N),那么在你每次执行RNN网络时就应该给一个[batch_size,self.state_size]形状的2D Tensor来表示当前RNN网络的状态,而如果你的self.state_size是一个元祖,那么给定的状态也应该是一个Tuple,每个Tuple里的状态表示和之前的方式一样,只要注意好不同的self.state_size取值就好
而RNN Cell经过一系列的工作后,将会输出如下的东西:
output:对应的你的batch的大小和输出大小的结果,形状是[batch_size x self.output_size]
state:根据你的self.state_size的不同,输出一个更新后的RNN状态,或者一个Tuple的状态,格式对应输入的state
同时RNNCell还定义了一个非抽象的方法,那就是生成初始化状态的方法,比较简单就不多说了:
def zero_state(self, batch_size, dtype):
BasicRNNCell
下面介绍完了RNNCell的定义,我们来看一个最原始的RNN的实现,就是不涉及到LSTM,GRU的那种。这种RNNCell被称作BasicRNNCell,代码很简短:
class BasicRNNCell(RNNCell):
"""The most basic RNN cell."""
def __init__(self, num_units, input_size=None, activation=tanh):
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
self._num_units = num_units
self._activation = activation
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
"""Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
output = self._activation(_linear([inputs, state], self._num_units, True))
return output, output
在最基本的RNN实现当中,RNN在时间t的输出,就是其在时间t的状态
output = new_state = activation(W * input + U * state + B)
这个计算就直接在__call__中计算完成了,这个函数比较简单,但是他具体如何计算则调用了一个方法,不在类中,那么我们看看这个函数先:
_linear([inputs, state], self._num_units, True)
对应函数介绍,_liner的功能就是你给了一个或一系列的Tensor(A,B,C.....),他给你计算一个W1*A+W2*B.....+Bias的结果存在,比如输入[input,state],那么这个方法就是计算W * input + U * state:
def _linear(args, output_size, bias, bias_start=0.0, scope=None):
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
Args:
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
output_size: int, second dimension of W[i].
bias: boolean, whether to add a bias term or not.
bias_start: starting value to initialize the bias; 0 by default.
scope: VariableScope for the created subgraph; defaults to "Linear".
Returns:
A 2D Tensor with shape [batch x output_size] equal to
sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
Raises:
ValueError: if some of the arguments has unspecified or wrong shape.
到此,关于Tensorflow里面RNNCell的基本结构,以及BasicRNNCell的源码分析结束。
以上,MebiuW