Tensorflow RNN源代码解析笔记1:RNNCell的基本实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/MebiuW/article/details/60780813

前言

本系列主要主要是记录下Tensorflow在RNN实现这一块的相关代码,不做详细解释,主要是翻译加笔记。

RNNCell

在Tensorflow中,定义了一个RNNCell的抽象类,具体的所有不同类型的RNN Cell都是基于这个类的,所以就首先讲一下这个,下面是基本的代码:

class RNNCell(object):
  def __call__(self, inputs, state, scope=None):
    raise NotImplementedError("Abstract method")

  @property
  def state_size(self):
    raise NotImplementedError("Abstract method")

  @property
  def output_size(self):
    raise NotImplementedError("Abstract method")

  def zero_state(self, batch_size, dtype):
    state_size = self.state_size
    if nest.is_sequence(state_size):
      state_size_flat = nest.flatten(state_size)
      zeros_flat = [
          array_ops.zeros(
              array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
              dtype=dtype)
          for s in state_size_flat]
      for s, z in zip(state_size_flat, zeros_flat):
        z.set_shape(_state_size_with_prefix(s, prefix=[None]))
      zeros = nest.pack_sequence_as(structure=state_size,
                                    flat_sequence=zeros_flat)
    else:
      zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
      zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
      zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))

    return zeros

在Tensorflow中,Cell的定义不同于其他资料当中的定义,在其他的文档中Cell(下文指代为L-Cell)被看做是一个能够产生Single Scalar输出的对象,而在这里则是一个包含一系列L-Cell的水平数组。

具体到RNNCell,RNNCell是一个包含一个State(状态)并且能够执行一些处理输入矩阵的对象。RNNCell将输入的矩阵(Input Matrix),运算输出一个包含”self.output”列的输出矩阵(Ouput Matrix)。如果定义了“self.state_size”这个属性,并且取值为一个整数,那么RNNCell则会同时输出一个状态矩阵(State Matrix),包含“self.state_size”列。而如果“self.state_size”定义为一个整数的Tuple,,那么则是输出对应长度的状态矩阵的Tuple,Tuple中的每一个状态矩阵长度还是和之前的一样,包含“self.state_size”列。

在Tensorflow中,将会基于整个RNNCell实现一系列常用的RNNCell,比如LSTM和GRU,并且将会支持包含Dropout等在内的特性,同时也支持构建多层的RNN网络。

RNNCell基本结构

RNNCell有一些基本的属性需要设置:

state_size: 说明这个Cell使用的State的大小
output_size: 这个RNNCell最后生成结果的大小

对于每一个RNNCell的具体实现类,都必须要实现__call__这个方法:

每一个具体的RNN类必须实现的方法:
def __call__(self, inputs, state, scope=None):

这个方法是RNNCell的核心方法,其需要的属性有:

inputs: 这个需要输入一个形状为[batch_size,input_size]的2D Tensor,batch_size是你训练中指定的batch的大小,而input_size则是输入数据的维度

state: state就是你rnn网络中rnn cell的状态,比如说如果你的rnn定义包含了N个单元(也就是你的self.state_size是个整数N),那么在你每次执行RNN网络时就应该给一个[batch_size,self.state_size]形状的2D Tensor来表示当前RNN网络的状态,而如果你的self.state_size是一个元祖,那么给定的状态也应该是一个Tuple,每个Tuple里的状态表示和之前的方式一样,只要注意好不同的self.state_size取值就好

而RNN Cell经过一系列的工作后,将会输出如下的东西:

output:对应的你的batch的大小和输出大小的结果,形状是[batch_size x self.output_size]

state:根据你的self.state_size的不同,输出一个更新后的RNN状态,或者一个Tuple的状态,格式对应输入的state

同时RNNCell还定义了一个非抽象的方法,那就是生成初始化状态的方法,比较简单就不多说了:

 def zero_state(self, batch_size, dtype):

BasicRNNCell

这里写图片描述
下面介绍完了RNNCell的定义,我们来看一个最原始的RNN的实现,就是不涉及到LSTM,GRU的那种。这种RNNCell被称作BasicRNNCell,代码很简短:

class BasicRNNCell(RNNCell):
  """The most basic RNN cell."""

  def __init__(self, num_units, input_size=None, activation=tanh):
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
    with vs.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
      output = self._activation(_linear([inputs, state], self._num_units, True))
    return output, output

在最基本的RNN实现当中,RNN在时间t的输出,就是其在时间t的状态

output = new_state = activation(W * input + U * state + B)

这个计算就直接在__call__中计算完成了,这个函数比较简单,但是他具体如何计算则调用了一个方法,不在类中,那么我们看看这个函数先:

_linear([inputs, state], self._num_units, True)

对应函数介绍,_liner的功能就是你给了一个或一系列的Tensor(A,B,C.....),他给你计算一个W1*A+W2*B.....+Bias的结果存在,比如输入[input,state],那么这个方法就是计算W * input + U * state:
def _linear(args, output_size, bias, bias_start=0.0, scope=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_start: starting value to initialize the bias; 0 by default.
    scope: VariableScope for the created subgraph; defaults to "Linear".

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.

到此,关于Tensorflow里面RNNCell的基本结构,以及BasicRNNCell的源码分析结束。
以上,MebiuW

猜你喜欢

转载自blog.csdn.net/MebiuW/article/details/60780813