Customizing an RNN Decoder with tf.nn.raw_rnn

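When implementing an RNN AutoEncoder in TensorFlow, the Decoder must use its own output from the previous timestep, y_{t-1}, as its input at every timestep. The ordinary tf.nn.dynamic_rnn cannot do this, because it requires a full sequence as input and computes the corresponding output for every timestep of that sequence.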

# Typical usage of dynamic_rnn (TensorFlow 1.x)
# seqs is the batch of input sequences, lengths holds the length of each sequence
stacked_rnn = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(N) for N in [64, 128]], state_is_tuple=True)
outputs, states = tf.nn.dynamic_rnn(stacked_rnn, seqs, sequence_length=lengths, dtype=tf.float32)

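The solution is tf.nn.raw_rnn, which gives much finer control over the RNN loop.
In short, the goal of this article is:

Use raw_rnn so that the RNN takes only its own output from the previous timestep as input.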
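
To make the difference concrete, here is a minimal plain-Python sketch (not TensorFlow code). The scalar cell `rnn_step` and its weights are hypothetical, chosen only for illustration. `run_on_sequence` mimics the dynamic_rnn style, where every input must be known up front, while `run_autoregressive` feeds each output back in as the next input.

```python
import math


def rnn_step(x, h, w_x=0.5, w_h=0.8):
    """Toy scalar RNN cell (hypothetical weights); the new hidden value is also the output."""
    return math.tanh(w_x * x + w_h * h)


def run_on_sequence(inputs, h0=0.0):
    """dynamic_rnn style: every step reads its input from a given sequence."""
    h, outputs = h0, []
    for x in inputs:
        h = rnn_step(x, h)
        outputs.append(h)
    return outputs


def run_autoregressive(first_input, num_steps, h0=0.0):
    """Decoder style: each step consumes the previous step's output."""
    h, x, outputs = h0, first_input, []
    for _ in range(num_steps):
        h = rnn_step(x, h)
        outputs.append(h)
        x = h  # y_{t-1} becomes the input at step t
    return outputs
```

Reproducing the autoregressive result with `run_on_sequence` would require knowing the outputs before running the loop, which is exactly what a sequence-driven API cannot provide.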


tf.nn.raw_rnn

According to the TensorFlow API, the function signature is:

tf.nn.raw_rnn(
    cell,
    loop_fn,
    parallel_iterations=None,
    swap_memory=False,
    scope=None
)

What it actually does internally can be expressed in pseudocode as follows:

time = tf.constant(0, dtype=tf.int32)
(finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
    time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
while not all(finished):
    (output, cell_state) = cell(next_input, state)
    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
        time=time + 1, cell_output=output, cell_state=cell_state,
        loop_state=loop_state)
    # Emit zeros and copy forward state for minibatch entries that are finished.
    state = tf.where(finished, state, next_state)
    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
    emit_ta = emit_ta.write(time, emit)
    # If any new minibatch entries are marked as finished, mark these.
    finished = tf.logical_or(finished, next_finished)
    time += 1
return (emit_ta, state, loop_state)

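In plainer terms:

  1. Call loop_fn to compute the initial values at time=0
  2. Enter the loop, which ends only when every sample in the batch is "finished"
  3. Inside the loop, cell computes the new output and state
  4. Inside the loop, loop_fn processes output and state, and decides this step's final output emit and state next_state, as well as the next step's input next_input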

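In other words, cell alone would be enough to provide the RNN functionality, but loop_fn allows further processing at every step.
Note that if cell is made of several cells (tuple-typed, e.g. a MultiRNNCell), the final state is also a tuple with one entry per layer, while emit_ta follows the emit structure (cell.output_size when loop_fn returns emit_output = None at time 0).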
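
The loop described above can be imitated in plain Python to see how finished entries are handled. This is only a sketch of the control flow, not the TensorFlow implementation: `raw_rnn_sketch`, `toy_cell`, `feedback_loop_fn`, and `lengths` are hypothetical names, and every tensor is replaced by a list with one scalar per batch entry.

```python
def raw_rnn_sketch(cell, loop_fn):
    """Plain-Python imitation of the raw_rnn control flow for a batch of scalars."""
    time = 0
    finished, next_input, state, _, loop_state = loop_fn(time, None, None, None)
    emitted = []  # stands in for emit_ta
    while not all(finished):
        output, cell_state = cell(next_input, state)
        next_finished, next_input, next_state, emit, loop_state = loop_fn(
            time + 1, output, cell_state, loop_state)
        # Finished entries keep their old state and emit zeros.
        state = [s if f else ns for f, s, ns in zip(finished, state, next_state)]
        emit = [0.0 if f else e for f, e in zip(finished, emit)]
        emitted.append(emit)
        finished = [f or nf for f, nf in zip(finished, next_finished)]
        time += 1
    return emitted, state, loop_state


def toy_cell(inputs, state):
    # Hypothetical cell: new state = state + input; the output equals the new state.
    new_state = [s + x for s, x in zip(state, inputs)]
    return new_state, new_state


lengths = [2, 3]  # number of steps to run for each batch entry


def feedback_loop_fn(time, cell_output, cell_state, loop_state):
    finished = [time >= n for n in lengths]
    if cell_output is None:  # time == 0: provide the initial input and state
        return finished, [1.0, 1.0], [0.0, 0.0], None, None
    # Feed the previous output back in as the next input.
    return finished, cell_output, cell_state, cell_output, None


outputs, final_state, _ = raw_rnn_sketch(toy_cell, feedback_loop_fn)
```

With these toy definitions the first batch entry stops after two steps, so its third emitted value is zero and its state is carried forward unchanged, while the second entry keeps running.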

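In addition, TensorFlow stores the LSTM state in an LSTMStateTuple, which contains c (the cell state) and h (the output). This h is exactly the value at the last valid timestep of the output returned by dynamic_rnn; see the API documentation for more details.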


Implementing tf.nn.dynamic_rnn with tf.nn.raw_rnn

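Because dynamic_rnn takes a sequence as input, loop_fn must set next_input at every timestep by reading directly from the input sequence. The cell output is not needed to build the next input, so it is simply emitted unchanged: emit_output = cell_output.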

inputs = tf.placeholder(shape=(max_time, batch_size, input_depth),
                        dtype=tf.float32)
sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)

def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output  # == None for time == 0
    if cell_output is None:  # time == 0
        next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
        next_cell_state = cell_state
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()

Implementing the Decoder with tf.nn.raw_rnn

A Decoder is usually paired with an Encoder, so an Encoder is implemented first; it returns its outputs and hidden states.

class encoder_naive(object):
    '''naive implementation using dynamic_rnn'''
    def __init__(self, hidden_units):
        self.name = 'naive_lstm_encoder'
        self.hidden_units = hidden_units # list
        with tf.variable_scope(self.name):
            encoder_rnn_layers = [tf.nn.rnn_cell.LSTMCell(size, use_peepholes=False) for size in self.hidden_units]
            self.encoder_multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(encoder_rnn_layers, state_is_tuple=True)
    def __call__(self, x, lens):
        with tf.variable_scope(self.name):
            encoder_outputs, encoder_last_state = tf.nn.dynamic_rnn(cell=self.encoder_multi_rnn_cell, 
                time_major=True,
                inputs=x, sequence_length=lens, dtype=tf.float32)

        return encoder_outputs, encoder_last_state

    @property
    def vars(self):
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)

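For such an Encoder, the Decoder is designed so that its first layer takes the Encoder's hidden state at the last timestep as its initial state, while the remaining layers start from zero states. In the special case where the Encoder and the Decoder each have only one layer, the Decoder inherits the Encoder's entire hidden state.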
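
The initial-state rule can be sketched in isolation with plain Python. `StateTuple` is a stand-in for tf.nn.rnn_cell.LSTMStateTuple (assumed here only to share the field names c and h), and `zero_state` and `build_decoder_initial_state` are hypothetical helpers that use nested lists instead of tensors.

```python
from collections import namedtuple

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple (assumption: same field names c and h).
StateTuple = namedtuple("StateTuple", ["c", "h"])


def zero_state(batch_size, units):
    """All-zero state of shape [batch_size, units] for both c and h."""
    return StateTuple(
        c=[[0.0] * units for _ in range(batch_size)],
        h=[[0.0] * units for _ in range(batch_size)],
    )


def build_decoder_initial_state(encoder_state, decoder_hidden_units, batch_size):
    """First decoder layer inherits the encoder state; deeper layers start at zero."""
    states = [encoder_state]
    for units in decoder_hidden_units[1:]:
        states.append(zero_state(batch_size, units))
    return tuple(states)
```

With a single-layer decoder the loop body never runs, so the returned tuple holds only the inherited encoder state, which matches the special case described above.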

class dynamic_decoder(object):
    '''
        Implementation using raw_rnn.
        Referring to https://github.com/ematvey/tensorflow-seq2seq-tutorials/blob/master/2-seq2seq-advanced.ipynb
    '''
    def __init__(self, hidden_units, output_depth):
        self.name = 'dynamic_lstm_decoder'
        self.hidden_units = hidden_units
        self.output_depth = output_depth

        # Output Projection
        with tf.variable_scope(self.name):
            self.W = tf.Variable(tf.random_uniform([self.hidden_units[-1], self.output_depth], -1.0, 1.0), dtype=tf.float32)
            self.b = tf.Variable(tf.zeros([self.output_depth]), dtype=tf.float32)

            self.decoder_rnn_layers = [tf.nn.rnn_cell.LSTMCell(size, use_peepholes=False) for size in self.hidden_units]
            self.decoder_multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(self.decoder_rnn_layers)

    def __call__(self, encoder_final_state, lens):
        # for now just assume that 'lens' is available 
        # 'encoder_final_state' should be either a tuple of LSTMStateTuple or a LSTMStateTuple, element size [bz, H]

        # batch_size, _ = encoder_final_state.c.shape
        batch_size, _ = tf.unstack(tf.shape(encoder_final_state.c))
        print(batch_size)  # debug output: this is a symbolic batch-size tensor
        eos_time_slice = tf.ones([batch_size, self.output_depth], dtype=tf.float32, name='EOS')
        pad_time_slice = tf.zeros([batch_size, self.output_depth], dtype=tf.float32, name='PAD')
        
        def loop_fn_initial():
            # used when time = 0
            initial_elements_finished = (lens <= 0) # all False
            initial_input = eos_time_slice
            
            # Key link: the encoder's last state becomes the initial state of the decoder's first layer
            initial_cell_state = []
            initial_cell_state.append(encoder_final_state) # [bz, H1]
            for i in range(1, len(self.hidden_units)):
                initial_cell_state.append(self.decoder_rnn_layers[i].zero_state(batch_size, dtype=tf.float32))
            
            return (initial_elements_finished, 
                    initial_input,
                    tuple(initial_cell_state),
                    None, None)

        def loop_fn_transition(time, cell_output, cell_state, loop_state):
            def get_next_input():
                return tf.add(tf.matmul(cell_output, self.W), self.b)
            elements_finished = (lens <= time)
            finished = tf.reduce_all(elements_finished)
            inputs = tf.cond(finished, lambda:pad_time_slice, get_next_input)
            states = cell_state
            outputs = cell_output
            loop_state = None 
            return (elements_finished, 
                    inputs,
                    states,
                    outputs,
                    loop_state)

        def loop_fn(time, cell_output, cell_state, loop_state):
            if cell_state is None:
                assert cell_state is None and cell_output is None
                return loop_fn_initial()
            else:
                return loop_fn_transition(time, cell_output, cell_state, loop_state)
        
        # core codes
        with tf.variable_scope(self.name):
            decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(self.decoder_multi_rnn_cell, loop_fn)
            decoder_outputs = decoder_outputs_ta.stack()

        return decoder_outputs, decoder_final_state

    @property
    def vars(self):
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)

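Since the last Encoder layer and the first Decoder layer may have different numbers of units, the hidden state must pass through a matrix transformation (for example, a fully connected layer without an activation function) before it is handed over. The final Encoder-Decoder structure is shown below; a loss function, such as MSE, can then be defined between decoder_y and logits.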

# hyper-parameters
max_time = 100
batch_size = 128
input_depth = 6
output_depth = 3
encoder_hidden_units = [32,64,128]
decoder_hidden_units = [50, 100]

encoder_x = tf.placeholder(shape=(max_time, batch_size, input_depth), dtype=tf.float32, name='encoder_inputs') # time-major
lens = tf.placeholder(shape=(batch_size, ), dtype=tf.int32, name='lengths')
decoder_y = tf.placeholder(shape=(max_time+1, batch_size, output_depth), dtype=tf.float32, name='decoder_outputs') # time-major

encoder = encoder_naive(hidden_units=encoder_hidden_units)
decoder = dynamic_decoder(hidden_units=decoder_hidden_units, output_depth=output_depth)

W_proj = tf.Variable(tf.random_uniform([encoder_hidden_units[-1], decoder_hidden_units[0]], -1.0, 1.0), dtype=tf.float32)
b_proj = tf.Variable(tf.zeros([decoder_hidden_units[0]]), dtype=tf.float32)

# encoder-decoder
encoder_outputs, encoder_state = encoder(encoder_x, lens)
initial_state = tf.nn.rnn_cell.LSTMStateTuple(
  c=tf.add(tf.matmul(encoder_state[-1].c, W_proj), b_proj),
  h=tf.add(tf.matmul(encoder_state[-1].h, W_proj), b_proj)
)
decoder_outputs, decoder_state = decoder(initial_state, lens + 1)
logits = tf.layers.dense(decoder_outputs, output_depth)
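
As one possible form of the MSE loss mentioned above, the following plain-Python sketch averages the squared error over valid timesteps only. Masking out the padded steps is an assumption that the original text does not specify, and `masked_mse` is a hypothetical helper that works on nested lists rather than tensors.

```python
def masked_mse(targets, predictions, lengths):
    """Mean squared error over valid steps only.

    targets, predictions: nested lists shaped [time][batch][depth] (time-major).
    lengths: number of valid timesteps for each batch entry.
    """
    total, count = 0.0, 0
    for t, (target_step, pred_step) in enumerate(zip(targets, predictions)):
        for b, (target_vec, pred_vec) in enumerate(zip(target_step, pred_step)):
            if t >= lengths[b]:
                continue  # padded step: excluded from the loss (assumption)
            for y, y_hat in zip(target_vec, pred_vec):
                total += (y - y_hat) ** 2
                count += 1
    return total / count if count else 0.0
```

In a TensorFlow graph the same idea would typically be expressed with a mask built from the sequence lengths and multiplied into the squared error before averaging.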

References

[Paper] Semi-supervised Sequence Learning
[Paper] Sequence to Sequence Learning with Neural Networks
[API] TensorFlow dynamic_rnn
[API] TensorFlow raw_rnn
[GitHub] tensorflow-seq2seq-tutorials

Reposted from blog.csdn.net/songbinxu/article/details/83858097