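Customizing an RNN Decoder with tf.nn.raw_rnn
When implementing an RNN AutoEncoder in TensorFlow, the Decoder must take its own output from the previous timestep as its input at every timestep. The ordinary tf.nn.dynamic_rnn cannot do this: it requires a complete sequence as input, and it computes one output for each timestep of that sequence.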
import tensorflow as tf  # TensorFlow 1.x API

# Typical usage of dynamic_rnn
# seqs is the batch of sequences, lengths holds the length of each sequence
stacked_rnn = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(N) for N in [64, 128]], state_is_tuple=True)
outputs, states = tf.nn.dynamic_rnn(stacked_rnn, seqs, lengths, dtype=tf.float32)
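The solution is tf.nn.raw_rnn, which gives much more freedom in controlling the RNN loop.
In short, the goal of this article is:
Use raw_rnn so that the RNN receives only its own output from the previous timestep as input.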
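To make the difference concrete, here is a minimal pure-Python sketch (not TensorFlow). The scalar `toy_cell` and both helper functions are hypothetical and exist only for illustration. The first helper consumes a sequence that is known in advance, as dynamic_rnn does. The second feeds each output back as the next input, which is the behaviour the Decoder needs.

```python
import math

def toy_cell(x, state):
    """Hypothetical scalar RNN cell: returns (output, new_state)."""
    new_state = math.tanh(0.5 * x + 0.5 * state)
    return new_state, new_state

def run_over_sequence(inputs, state=0.0):
    """dynamic_rnn style: each step consumes the next element of a given sequence."""
    outputs = []
    for x in inputs:
        out, state = toy_cell(x, state)
        outputs.append(out)
    return outputs

def run_with_feedback(first_input, steps, state=0.0):
    """Decoder style: each step consumes the output of the previous step."""
    outputs = []
    x = first_input
    for _ in range(steps):
        out, state = toy_cell(x, state)
        outputs.append(out)
        x = out  # the next input only becomes known here
    return outputs

seq_outputs = run_over_sequence([1.0, 0.0, -1.0])
fb_outputs = run_with_feedback(first_input=1.0, steps=3)
```

In the feedback version the inputs cannot be prepared as a sequence beforehand, because each one depends on a value computed inside the loop.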
tf.nn.raw_rnn
According to the TensorFlow API documentation, the function signature is:
tf.nn.raw_rnn(
    cell,
    loop_fn,
    parallel_iterations=None,
    swap_memory=False,
    scope=None
)
What it does internally can be expressed in pseudocode as follows:
time = tf.constant(0, dtype=tf.int32)
(finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
    time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
while not all(finished):
    (output, cell_state) = cell(next_input, state)
    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
        time=time + 1, cell_output=output, cell_state=cell_state,
        loop_state=loop_state)
    # Emit zeros and copy forward state for minibatch entries that are finished.
    state = tf.where(finished, state, next_state)
    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
    emit_ta = emit_ta.write(time, emit)
    # If any new minibatch entries are marked as finished, mark these.
    finished = tf.logical_or(finished, next_finished)
    time += 1
return (emit_ta, state, loop_state)
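In plainer terms:
- loop_fn is called once with time=0 to produce the initial values.
- The loop then runs until every sample in the batch is marked "finished".
- Inside the loop, cell computes the new output and state.
- Inside the loop, loop_fn processes that output and state, and decides the value emitted for this step (emit), the state to carry forward (next_state), and the input for the next step (next_input).
In other words, the cell alone would be enough to perform the RNN computation, but loop_fn makes it possible to post-process every step.
Note that if the cell is a stack of several cells (tuple-typed state), the final state is a tuple as well, with one entry per layer. Likewise, emit_ta follows the structure of whatever loop_fn emits.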
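The control flow described above can be imitated in plain Python. The following is only a sketch under simplifying assumptions: a batch is a list of floats, and `toy_raw_rnn`, `cell`, `loop_fn` and `lengths` are hypothetical names, not TensorFlow objects. It shows how entries that are already finished emit zeros and keep their state.

```python
def toy_raw_rnn(cell, loop_fn):
    """Plain-Python imitation of the raw_rnn loop; a batch is a list of floats."""
    time = 0
    finished, next_input, state, _, loop_state = loop_fn(time, None, None, None)
    emits = []
    while not all(finished):
        output, cell_state = cell(next_input, state)
        next_finished, next_input, next_state, emit, loop_state = loop_fn(
            time + 1, output, cell_state, loop_state)
        # Finished entries keep their old state and emit zeros.
        state = [s if f else ns for f, s, ns in zip(finished, state, next_state)]
        emits.append([0.0 if f else e for f, e in zip(finished, emit)])
        finished = [f or nf for f, nf in zip(finished, next_finished)]
        time += 1
    return emits, state, loop_state

def cell(inputs, state):
    """Hypothetical cell: adds the input to the state and outputs the new state."""
    new_state = [s + x for s, x in zip(state, inputs)]
    return new_state, new_state

lengths = [3, 1]  # number of valid steps for each batch entry

def loop_fn(time, cell_output, cell_state, loop_state):
    finished = [time >= n for n in lengths]
    if cell_output is None:          # time == 0: supply the initial state
        next_state = [0.0, 0.0]
    else:
        next_state = cell_state
    next_input = [1.0, 1.0]          # constant input, for simplicity
    return finished, next_input, next_state, cell_output, loop_state

emits, final_state, _ = toy_raw_rnn(cell, loop_fn)
```

The second entry finishes after one step, so its later emits are zero and its state stays frozen, while the first entry keeps running for three steps.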
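In addition, TensorFlow stores the LSTM state in an LSTMStateTuple, which contains c (the cell state) and h (the output). For each sequence, h is exactly the value that the output returned by dynamic_rnn takes at that sequence's last valid timestep. See the API documentation for more details.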
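The relationship between h and the outputs can be checked with a toy example. This is a sketch only: the namedtuple is a stand-in that mirrors the (c, h) fields of TensorFlow's LSTMStateTuple, and the scalar `toy_lstm_step`, with all weights fixed to 1, is a hypothetical simplification of a real LSTM cell. Steps beyond the valid length produce zero output and carry the state over, which is how dynamic_rnn treats sequence_length.

```python
import math
from collections import namedtuple

# Stand-in mirroring the (c, h) fields of TensorFlow's LSTMStateTuple.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def toy_lstm_step(x, state):
    """Scalar LSTM step with all weights set to 1 and biases to 0 (illustrative)."""
    z = x + state.h
    i, f, o = sigmoid(z), sigmoid(z), sigmoid(z)
    c = f * state.c + i * math.tanh(z)
    h = o * math.tanh(c)
    return h, LSTMStateTuple(c=c, h=h)  # the step output is h itself

def run(sequence, length):
    state = LSTMStateTuple(c=0.0, h=0.0)
    outputs = []
    for t, x in enumerate(sequence):
        if t < length:
            out, state = toy_lstm_step(x, state)
        else:
            out = 0.0  # past the valid length: zero output, state carried over
        outputs.append(out)
    return outputs, state

outputs, final_state = run([0.5, -1.0, 2.0, 0.0, 0.0], length=3)
```

Here `final_state.h` equals `outputs[2]`, the output at the last valid step, and not `outputs[-1]`, which is padding.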
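Implementing tf.nn.dynamic_rnn with tf.nn.raw_rnn
dynamic_rnn takes a sequence as input, so at each timestep loop_fn sets next_input by reading directly from the input sequence. The cell output needs no further processing, so emit_output is simply set to cell_output.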
import tensorflow as tf  # TensorFlow 1.x API

# max_time, batch_size, input_depth and num_units are assumed to be defined beforehand
inputs = tf.placeholder(shape=(max_time, batch_size, input_depth),
                        dtype=tf.float32)
sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)

def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output  # == None for time == 0
    if cell_output is None:  # time == 0
        next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
        next_cell_state = cell_state
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    # Read the next input only while some sequence is still running
    next_input = tf.cond(finished,
                         lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
                         lambda: inputs_ta.read(time))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()
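One plausible reading of the tf.cond in loop_fn is that it guards the final call. loop_fn is invoked for time = 0 up to and including the longest sequence length, and at that last call there is nothing left to read from inputs_ta. The pure-Python sketch below imitates that logic. The toy `inputs` values and the helper `next_input_at` are hypothetical.

```python
max_time = 3
# Time-major toy data: 3 timesteps, batch of 2, one value per entry.
inputs = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
sequence_length = [3, 2]

def next_input_at(time):
    elements_finished = [time >= n for n in sequence_length]
    finished = all(elements_finished)         # plays the role of tf.reduce_all
    if finished:
        return elements_finished, [0.0, 0.0]  # nothing left to read, feed zeros
    return elements_finished, inputs[time]    # plays the role of inputs_ta.read(time)

# loop_fn is called for time = 0, 1, ..., max_time
calls = [next_input_at(t) for t in range(max_time + 1)]
```

Without the guard, the last call would try to read index `max_time`, which lies one past the end of the input.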
Implementing the Decoder with tf.nn.raw_rnn
A Decoder is usually paired with an Encoder, so we first implement an Encoder that returns its outputs and its hidden state.
class encoder_naive(object):
    '''naive implementation using dynamic_rnn'''
    def __init__(self, hidden_units):
        self.name = 'naive_lstm_encoder'
        self.hidden_units = hidden_units  # list
        with tf.variable_scope(self.name):
            encoder_rnn_layers = [tf.nn.rnn_cell.LSTMCell(size, use_peepholes=False) for size in self.hidden_units]
            self.encoder_multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(encoder_rnn_layers, state_is_tuple=True)

    def __call__(self, x, lens):
        with tf.variable_scope(self.name):
            encoder_outputs, encoder_last_state = tf.nn.dynamic_rnn(cell=self.encoder_multi_rnn_cell,
                                                                    time_major=True,
                                                                    inputs=x, sequence_length=lens, dtype=tf.float32)
        return encoder_outputs, encoder_last_state

    @property
    def vars(self):
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
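The Decoder is designed to match this Encoder: its first layer takes the Encoder's hidden state at the last timestep as its initial state, and the remaining layers start from zero states. In the special case where both the Encoder and the Decoder have a single layer, the Decoder inherits the Encoder's entire hidden state.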
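The initial-state rule just described can be sketched in plain Python before reading the TensorFlow class. Everything here is illustrative: the namedtuple stands in for TensorFlow's LSTMStateTuple, and `zero_state` and `build_decoder_initial_state` are hypothetical helpers operating on nested lists.

```python
from collections import namedtuple

# Stand-in mirroring the (c, h) fields of TensorFlow's LSTMStateTuple.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def zero_state(batch_size, units):
    """All-zero (c, h) pair shaped [batch_size][units]."""
    return LSTMStateTuple(
        c=[[0.0] * units for _ in range(batch_size)],
        h=[[0.0] * units for _ in range(batch_size)])

def build_decoder_initial_state(encoder_last_state, hidden_units, batch_size):
    """First layer inherits the encoder state; the remaining layers start from zeros."""
    states = [encoder_last_state]
    for units in hidden_units[1:]:
        states.append(zero_state(batch_size, units))
    return tuple(states)

# Batch of 1, encoder state with 2 units.
encoder_last = LSTMStateTuple(c=[[0.3, -0.1]], h=[[0.2, 0.5]])
two_layers = build_decoder_initial_state(encoder_last, hidden_units=[2, 4], batch_size=1)
one_layer = build_decoder_initial_state(encoder_last, hidden_units=[2], batch_size=1)
```

With a single layer the result is a one-element tuple holding only the encoder state, which corresponds to the special case mentioned above.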
class dynamic_decoder(object):
    '''
    Implementation using raw_rnn.
    Referring to https://github.com/ematvey/tensorflow-seq2seq-tutorials/blob/master/2-seq2seq-advanced.ipynb
    '''
    def __init__(self, hidden_units, output_depth):
        self.name = 'dynamic_lstm_decoder'
        self.hidden_units = hidden_units
        self.output_depth = output_depth
        # Output projection: maps the cell output back to the input space
        with tf.variable_scope(self.name):
            self.W = tf.Variable(tf.random_uniform([self.hidden_units[-1], self.output_depth], -1.0, 1.0), dtype=tf.float32)
            self.b = tf.Variable(tf.zeros([self.output_depth]), dtype=tf.float32)
            self.decoder_rnn_layers = [tf.nn.rnn_cell.LSTMCell(size, use_peepholes=False) for size in self.hidden_units]
            self.decoder_multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(self.decoder_rnn_layers)

    def __call__(self, encoder_final_state, lens):
        # for now just assume that 'lens' is available
        # 'encoder_final_state' is expected to be a single LSTMStateTuple (its .c is read below), element size [bz, H]
        # batch_size, _ = encoder_final_state.c.shape
        batch_size, _ = tf.unstack(tf.shape(encoder_final_state.c))
        print(batch_size)
        eos_time_slice = tf.ones([batch_size, self.output_depth], dtype=tf.float32, name='EOS')
        pad_time_slice = tf.zeros([batch_size, self.output_depth], dtype=tf.float32, name='PAD')

        def loop_fn_initial():
            # used when time = 0
            initial_elements_finished = (lens <= 0)  # all False
            initial_input = eos_time_slice
            # here's the key link between the encoder's last state and the decoder's initial state
            initial_cell_state = []
            initial_cell_state.append(encoder_final_state)  # [bz, H1]
            for i in range(1, len(self.hidden_units)):
                initial_cell_state.append(self.decoder_rnn_layers[i].zero_state(batch_size, dtype=tf.float32))
            return (initial_elements_finished,
                    initial_input,
                    tuple(initial_cell_state),
                    None, None)

        def loop_fn_transition(time, cell_output, cell_state, loop_state):
            def get_next_input():
                # project the previous output and feed it back as the next input
                return tf.add(tf.matmul(cell_output, self.W), self.b)
            elements_finished = (lens <= time)
            finished = tf.reduce_all(elements_finished)
            inputs = tf.cond(finished, lambda: pad_time_slice, get_next_input)
            states = cell_state
            outputs = cell_output
            loop_state = None
            return (elements_finished,
                    inputs,
                    states,
                    outputs,
                    loop_state)

        def loop_fn(time, cell_output, cell_state, loop_state):
            if cell_state is None:
                assert cell_state is None and cell_output is None
                return loop_fn_initial()
            else:
                return loop_fn_transition(time, cell_output, cell_state, loop_state)

        # core codes
        with tf.variable_scope(self.name):
            decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(self.decoder_multi_rnn_cell, loop_fn)
            decoder_outputs = decoder_outputs_ta.stack()
        return decoder_outputs, decoder_final_state

    @property
    def vars(self):
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
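The feedback mechanism of this Decoder can be followed step by step in a small pure-Python sketch. It is not the TensorFlow implementation. `toy_cell` is a hypothetical scalar cell standing in for the MultiRNNCell, and the scalars `W` and `B` play the role of self.W and self.b. The first input is an all-ones EOS value, every later input is the projection of the previous output, and a zero PAD value is fed once every sequence has finished.

```python
import math

W, B = 0.5, 0.1      # hypothetical projection parameters (self.W, self.b in the class)
EOS, PAD = 1.0, 0.0  # first input and padding input, as in the decoder

def toy_cell(x, state):
    """Hypothetical scalar cell standing in for the MultiRNNCell."""
    h = math.tanh(x + state)
    return h, h

def toy_decode(initial_state, lens):
    batch = len(lens)
    state = list(initial_state)
    inputs = [EOS] * batch
    finished = [n <= 0 for n in lens]
    outputs, time = [], 0
    while not all(finished):
        step = [toy_cell(x, s) for x, s in zip(inputs, state)]
        # Finished entries emit zeros and keep their state.
        outputs.append([0.0 if f else out for f, (out, _) in zip(finished, step)])
        state = [s if f else new for f, s, (_, new) in zip(finished, state, step)]
        time += 1
        next_finished = [n <= time for n in lens]
        if all(next_finished):
            inputs = [PAD] * batch
        else:
            # The next input is the projection of this step's output.
            inputs = [W * out + B for out, _ in step]
        finished = [f or nf for f, nf in zip(finished, next_finished)]
    return outputs, state

outputs, final_state = toy_decode(initial_state=[0.0, 0.0], lens=[2, 1])
```

The number of decoded steps equals the longest entry in `lens`, and the shorter sequence contributes zeros after its own length is reached.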
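Because the last layer of the Encoder and the first layer of the Decoder may have different numbers of units, the hidden state has to pass through a linear transformation (for example, a fully connected layer without an activation function) before it is handed over. The final Encoder-Decoder structure is shown below. What remains is to define a loss function between decoder_y and logits, for example MSE.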
# hyper-parameters
max_time = 100
batch_size = 128
input_depth = 6
output_depth = 3
encoder_hidden_units = [32, 64, 128]
decoder_hidden_units = [50, 100]
encoder_x = tf.placeholder(shape=(max_time, batch_size, input_depth), dtype=tf.float32, name='encoder_inputs') # time-major
lens = tf.placeholder(shape=(batch_size, ), dtype=tf.int32, name='lengths')
decoder_y = tf.placeholder(shape=(max_time+1, batch_size, output_depth), dtype=tf.float32, name='decoder_outputs') # time-major
encoder = encoder_naive(hidden_units=encoder_hidden_units)
decoder = dynamic_decoder(hidden_units=decoder_hidden_units, output_depth=output_depth)
W_proj = tf.Variable(tf.random_uniform([encoder_hidden_units[-1], decoder_hidden_units[0]], -1.0, 1.0), dtype=tf.float32)
b_proj = tf.Variable(tf.zeros([decoder_hidden_units[0]]), dtype=tf.float32)
# encoder-decoder
encoder_outputs, encoder_state = encoder(encoder_x, lens)
initial_state = tf.nn.rnn_cell.LSTMStateTuple(
    c=tf.add(tf.matmul(encoder_state[-1].c, W_proj), b_proj),
    h=tf.add(tf.matmul(encoder_state[-1].h, W_proj), b_proj)
)
decoder_outputs, decoder_state = decoder(initial_state, lens + 1)
logits = tf.layers.dense(decoder_outputs, output_depth)
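As a rough illustration of the MSE loss mentioned above, the following pure-Python sketch computes a mean squared error over time-major nested lists shaped [time][batch][depth]. Skipping the padded steps of each sequence is an assumption made here, since the text above does not say how padding should be treated. `masked_mse` and the toy values are hypothetical.

```python
def masked_mse(targets, predictions, lengths):
    """Mean squared error over valid steps only (masking is an assumption).

    targets and predictions are time-major nested lists: [time][batch][depth].
    """
    total, count = 0.0, 0
    for t, (target_step, pred_step) in enumerate(zip(targets, predictions)):
        for b, (target_vec, pred_vec) in enumerate(zip(target_step, pred_step)):
            if t >= lengths[b]:
                continue  # padded step for this sequence: ignore it
            for y, p in zip(target_vec, pred_vec):
                total += (y - p) ** 2
                count += 1
    return total / count if count else 0.0

# 2 timesteps, batch of 2, depth 1
targets = [[[1.0], [2.0]], [[3.0], [0.0]]]
predictions = [[[0.0], [2.0]], [[1.0], [5.0]]]
loss = masked_mse(targets, predictions, lengths=[2, 1])
```

Because `zip` stops at the shorter of the two time axes, the sketch also tolerates targets that are padded to more steps than the decoder actually produced.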
References
[Paper] Semi-supervised Sequence Learning
[Paper] Sequence to Sequence Learning with Neural Networks
[API] TensorFlow dynamic_rnn
[API] TensorFlow raw_rnn
[GitHub] tensorflow-seq2seq-tutorials