Paper link:
https://arxiv.org/abs/1806.01822
Motivation:
If we treat an individual's experiences and trajectories as a kind of memory, we can define a memory network structure whose role is to infer, from those experiences, possible states and the goal-directed operations (i.e., the operations needed to reach a given state).
The simplest example is the RNN. From the perspective of seq2seq attention summarization, the context vector can be seen as a summary of the decoder-relevant memory on the encoder side; see the earlier post of notes on End-To-End Memory Networks. That article already presented a memory network imitating the LSTM memory structure; there, the memory can be viewed as information coming from the outside: the task is framed as a QA problem with candidate context information, and the memory part corresponds exactly to that context. A similar solution is a structure that uses the query to extract directly from the context, as in Bi-Directional Attention Flow for Machine Comprehension.
Now change the scenario, say to multi-turn QA. From the dialogue agent's point of view, memory should correspond to the context produced by the whole conversation, and that context is generated by the agent itself rather than given externally; it is a dynamically updated memory structure. This paper can be read as one network structure for updating an agent's experiential memory.
From the standpoint of basic neural networks there would already seem to be a solution: the hidden state of an RNN. Conceptually that is right, but capacity and design problems remain. In an LSTM the "memory" is a single vector, and c and h each carry specific meanings whose statistics mostly describe a single step, with each step depending only on the previous one. What we need is a "box" that accumulates all past memories (hidden states, or equivalently inputs), rather than filtering through a single-step memory update.
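The contrast can be made concrete with a toy NumPy sketch (illustrative only; the vector update below is not a real LSTM cell): a vector state overwrites its history at every step, while a matrix "box" keeps one row per recent step.

```python
import numpy as np

def vector_state(h_prev, x):
    # Vector-style: the entire past is squeezed into one vector,
    # overwritten at every step (toy update, not a real LSTM cell).
    return np.tanh(h_prev + x)

def matrix_memory_append(M, x):
    # Matrix-style "box": keep one slot per past step by rolling the
    # oldest slot out and writing the new input into the last row.
    M = np.roll(M, shift=-1, axis=0)
    M[-1] = x
    return M

d, slots = 4, 3
h = np.zeros(d)
M = np.zeros((slots, d))
for step in range(5):
    x = np.ones(d) * step
    h = vector_state(h, x)
    M = matrix_memory_append(M, x)

print(h.shape)  # (4,)  -- one vector, history entangled
print(M.shape)  # (3, 4) -- one row per recent step
```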
So the starting point of many papers can be seen as "turning" the LSTM hidden-state vector into a matrix. I highly recommend reading this one first: Hybrid computing using a neural network with dynamic external memory.
After that paper, the others pose no difficulty. Its authors construct a memory matrix and liken it to a block of RAM, designing read and write heads for it (and along the way an eraser that wipes useless memories), treating the network as a CPU; readers with a computer-science background will probably enjoy this fine-grained design, which also pays off when validating individual sub-blocks. It is just somewhat complex. Even more remarkably, the optimization is done with reinforcement learning.
Compared with that paper, this one is much simpler; in one sentence: Attention Is All You Need. Since the problem does not require remembering positions, all that is used is multi-head self-attention. The memory matrix is constructed directly from this network structure as follows:
The single-step update works as follows:
the input x is directly "compressed" (fused) into the memory.
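As a rough illustration of this update, here is a minimal single-head NumPy sketch (it omits the multi-head machinery and the MLP/LayerNorm residual blocks; all weights are random stand-ins): append x as an extra row of M, self-attend over all rows, then slice the input row back off.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def memory_update(M, x, Wq, Wk, Wv):
    # Fuse input x into memory M: append x as an extra row, run
    # self-attention over all rows, then drop the input row again.
    Mx = np.vstack([M, x[None, :]])          # [slots + 1, d]
    q, k, v = Mx @ Wq, Mx @ Wk, Mx @ Wv      # single head for brevity
    attn = softmax(q @ k.T / np.sqrt(k.shape[-1]))
    new = attn @ v                           # every slot reads all rows
    return new[:-1]                          # keep only the memory slots

rng = np.random.default_rng(0)
d, slots = 4, 3
M = np.eye(slots, d)                         # eye init, as in the code below
x = rng.normal(size=d)
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
M_next = memory_update(M, x, Wq, Wk, Wv)
print(M_next.shape)  # (3, 4): the memory shape is preserved
```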
Constructing and updating the memory is relatively simple; the important part is how to define, for a memory matrix, an LSTM-style gated "read" that filters memories, enabling dynamic storage, retrieval, and information integration. The formulas the paper gives for this are fairly involved, and the details only become clear together with the authors' code:
Before reading the code, though, the model's overall diagram matters more:
The only function the formulas do not give explicitly is g, which I consider the most important part: it uses an MLP to perform a row-wise, element-wise sum.
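The gating can be sketched in NumPy as follows (a single linear layer stands in for the MLP here, and all weights and shapes are made up for illustration): per memory row, a projection of tanh(M) is summed with a broadcast projection of the flattened input, then split into input and forget gates.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def create_gates(x, M, Wx, Wm, input_bias=0.0, forget_bias=1.0):
    # One linear layer stands in for g; the input projection broadcasts
    # over the rows of the memory projection before the split.
    gate_x = x @ Wx                      # [2 * d]
    gate_m = np.tanh(M) @ Wm             # [slots, 2 * d]
    gates = gate_m + gate_x[None, :]     # broadcast add to every row
    i_gate, f_gate = np.split(gates, 2, axis=-1)
    return sigmoid(i_gate + input_bias), sigmoid(f_gate + forget_bias)

def gated_update(M, M_tilde, i_gate, f_gate):
    # LSTM-style blend of the old memory and the attended candidate.
    return i_gate * np.tanh(M_tilde) + f_gate * M

rng = np.random.default_rng(0)
slots, d = 3, 4
M = np.eye(slots, d)
M_tilde = rng.normal(size=(slots, d))    # candidate from self-attention
x = rng.normal(size=d)
Wx = rng.normal(size=(d, 2 * d))
Wm = rng.normal(size=(d, 2 * d))
i_g, f_g = create_gates(x, M, Wx, Wm)
M_next = gated_update(M, M_tilde, i_g, f_g)
print(M_next.shape)  # (3, 4)
```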
For convenience, the implementation code is copied below:
```python
from sonnet.python.modules import basic
from sonnet.python.modules import layer_norm
from sonnet.python.modules import rnn_core
from sonnet.python.modules.nets import mlp
import tensorflow as tf


class RelationalMemory(rnn_core.RNNCore):
    def __init__(self, mem_slots=10, head_size=10, num_heads=3, num_blocks=1,
                 forget_bias=1.0, input_bias=0.0, gate_style="unit",
                 attention_mlp_layers=2, key_size=None,
                 name="relational_memory"):
        super(RelationalMemory, self).__init__(name=name)
        self._mem_slots = mem_slots
        # multi head size
        self._head_size = head_size
        self._num_heads = num_heads
        self._mem_size = self._head_size * self._num_heads
        if num_blocks < 1:
            raise ValueError("num_blocks must be >= 1, Got: {}.".format(num_blocks))
        self._num_blocks = num_blocks
        self._forget_bias = forget_bias
        self._input_bias = input_bias
        if gate_style not in ["unit", "memory", None]:
            raise ValueError(
                r"gate_style must be one of ['unit', 'memory', None] Got {}".format(gate_style)
            )
        self._gate_style = gate_style
        if attention_mlp_layers < 1:
            raise ValueError("attention_mlp_layers must be >= 1, Got: {}".format(
                attention_mlp_layers
            ))
        self._attention_mlp_layers = attention_mlp_layers
        # this size may be the size compatible with column num of memory
        self._key_size = key_size if key_size else self._head_size

    # init memory matrix
    def initial_state(self, batch_size, trainable=False):
        '''
        # [batch, mem_slots, mem_slots]
        init_state = tf.eye(self._mem_slots, batch_shape=[batch_size])
        if self._mem_size > self._mem_slots:
            difference = self._mem_size - self._mem_slots
            pad = tf.zeros((batch_size, self._mem_slots, difference))
            init_state = tf.concat([init_state, pad], -1)
        elif self._mem_size < self._mem_slots:
            init_state = init_state[:, :, :self._mem_size]
        return init_state
        '''
        init_state = tf.eye(self._mem_slots, self._mem_size, batch_shape=[batch_size])
        return init_state

    def _multihead_attention(self, memory):
        key_size = self._key_size
        value_size = self._head_size
        qkv_size = 2 * key_size + value_size
        total_size = qkv_size * self._num_heads
        qkv = basic.BatchApply(basic.Linear(total_size))(memory)
        qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)
        mem_slots = memory.get_shape().as_list()[1]
        qkv_reshape = basic.BatchReshape([mem_slots, self._num_heads, qkv_size])(qkv)
        qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
        # split into query/key/value (the last chunk has value_size columns)
        q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)
        q *= qkv_size ** -0.5
        dot_product = tf.matmul(q, k, transpose_b=True)
        weights = tf.nn.softmax(dot_product)
        output = tf.matmul(weights, v)
        output_transpose = tf.transpose(output, [0, 2, 1, 3])
        new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
        return new_memory

    @property
    def state_size(self):
        return tf.TensorShape([self._mem_slots, self._mem_size])

    @property
    def output_size(self):
        return tf.TensorShape(self._mem_slots * self._mem_size)

    def _calculate_gate_size(self):
        if self._gate_style == "unit":
            return self._mem_size
        elif self._gate_style == "memory":
            return 1
        else:
            return 0

    def _create_gates(self, inputs, memory):
        num_gates = 2 * self._calculate_gate_size()
        memory = tf.tanh(memory)
        # shape 2
        inputs = basic.BatchFlatten()(inputs)
        gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
        # shape 3
        gate_inputs = tf.expand_dims(gate_inputs, axis=1)
        gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)
        # broadcast add to every row of memory
        gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
        input_gate, forget_gate = gates
        input_gate = tf.sigmoid(input_gate + self._input_bias)
        forget_gate = tf.sigmoid(forget_gate + self._forget_bias)
        return input_gate, forget_gate

    def _attend_over_memory(self, memory):
        attention_mlp = basic.BatchApply(
            mlp.MLP([self._mem_size] * self._attention_mlp_layers)
        )
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory)
            memory = basic.BatchApply(layer_norm.LayerNorm())(
                memory + attended_memory
            )
            memory = basic.BatchApply(layer_norm.LayerNorm())(
                attention_mlp(memory) + memory
            )
        return memory

    def _build(self, inputs, memory, treat_input_as_matrix=False):
        if treat_input_as_matrix:
            inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
            inputs_reshape = basic.BatchApply(
                basic.Linear(self._mem_size), n_dims=2
            )(inputs)
        else:
            inputs = basic.BatchFlatten()(inputs)
            inputs = basic.Linear(self._mem_size)(inputs)
            inputs_reshape = tf.expand_dims(inputs, 1)
        memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
        next_memory = self._attend_over_memory(memory_plus_input)
        n = inputs_reshape.get_shape().as_list()[1]
        next_memory = next_memory[:, :-n, :]
        if self._gate_style == "unit" or self._gate_style == "memory":
            self._input_gate, self._forget_gate = self._create_gates(
                inputs_reshape, memory
            )
            next_memory = self._input_gate * tf.tanh(next_memory)
            next_memory += self._forget_gate * memory
        output = basic.BatchFlatten()(next_memory)
        return output, next_memory

    @property
    def input_gate(self):
        self._ensure_is_connected()
        return self._input_gate

    @property
    def forget_gate(self):
        self._ensure_is_connected()
        return self._forget_gate


if __name__ == "__main__":
    pass
```
A word on code style first: the most convenient thing about sonnet, in my view, is its batch-level operation modules; with those, tf.map_fn is never needed again. Once you subclass rnn_core.RNNCore, the _build you implement corresponds roughly to the call function of TensorFlow's RNN cells. I also cannot resist a remark: had I written the self-attention, I would probably have defined separate weights for Q, K, and V and computed each; the author instead applies a single linear transform and splits the result, which nicely carries on the writing style of the various components in TensorFlow's RNNs. Likewise, the memory update pushes the whole [M; x] through the compression and then slices.
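That one-projection-then-split style is easy to see in a small NumPy sketch (shapes chosen arbitrarily for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
slots, d, key_size, value_size = 3, 4, 4, 4

# One projection producing [q | k | v] in a single matmul, then a split,
# instead of three separately defined weight matrices.
W_qkv = rng.normal(size=(d, 2 * key_size + value_size))
M = rng.normal(size=(slots, d))
qkv = M @ W_qkv                                   # [slots, 2*key + value]
q, k, v = np.split(qkv, [key_size, 2 * key_size], axis=-1)
print(q.shape, k.shape, v.shape)  # (3, 4) (3, 4) (3, 4)
```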
Having understood the network above, we can try it on a simple example, say using the memory network to estimate a recursive computation, which is also the simplest experiment in the original paper (see Learning to Execute if interested). For examples of recursive-structure neural networks, see TensorFlow Fold or the post TensorFlow Fold 初探(一)——TreeLstm情感分类. Here we use exactly the same sum-function data to check the effect. The code is as follows:
```python
from model.study import RelationalMemory
from sonnet.python.modules import basic
import tensorflow as tf
import random
import numpy as np
import os

max_seq_len = 10


def random_example(fn, length=max_seq_len):
    length = random.randrange(1, length)
    data = [random.uniform(0, 1) for _ in range(length)]
    result = fn(data)
    return data, result


def random_generator(batch_num=2, fn=sum):
    while True:
        req_x, req_mask, req_y = [], [], []
        for _ in range(batch_num):
            data, result = random_example(fn)
            req_x.append(data)
            req_mask.append(len(data))
            req_y.append(result)
        req_x, req_mask, req_y = map(np.array, [req_x, req_mask, req_y])
        yield req_x, req_mask, req_y


# single model without sequence embedding
class RetionalModel(object):
    def __init__(self, max_seq_len=max_seq_len, dnn_size=100, epsilon=1.0):
        self.max_seq_len = max_seq_len
        self.dnn_size = dnn_size
        # use epsilon to identify accurate rate
        self.epsilon = epsilon
        self.input = tf.placeholder(tf.float32, [None, max_seq_len])
        self.input_mask = tf.placeholder(tf.int32, [None])
        self.y = tf.placeholder(tf.float32, [None])
        self.model_construct()

    def model_construct(self):
        relationalMemoryCell = RelationalMemory()
        outputs, state = tf.nn.dynamic_rnn(cell=relationalMemoryCell,
                                           inputs=tf.expand_dims(self.input, axis=-1),
                                           sequence_length=self.input_mask,
                                           dtype=tf.float32)
        flatten_outputs = basic.BatchFlatten()(outputs)
        h0 = tf.layers.dense(inputs=flatten_outputs, units=self.dnn_size, name="h0")
        self.prediction = tf.squeeze(tf.layers.dense(inputs=h0, units=1), name="prediction")
        self.accuracy = tf.reduce_mean(
            tf.cast((self.epsilon - tf.abs(self.prediction - self.y)) > 0, tf.float32))
        self.loss = tf.losses.mean_squared_error(labels=self.y, predictions=self.prediction)
        self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.loss)


def simple_alg_seq_test():
    train_gen = random_generator(batch_num=128)
    valid_gen = random_generator(batch_num=64)
    model = RetionalModel()
    step = 0
    saver = tf.train.Saver()
    with tf.Session() as sess:
        if os.path.exists(r"E:\Coding\python\retionalSonnetStudy\model.ckpt.index"):
            print("restore exists")
            saver.restore(sess, save_path=r"E:\Coding\python\retionalSonnetStudy\model.ckpt")
        else:
            print("init global")
            sess.run(tf.global_variables_initializer())

        while True:
            req_x, req_mask, req_y = train_gen.__next__()
            req_x_pad = np.zeros(shape=[128, max_seq_len])
            for e_idx, ele in enumerate(req_x):
                for c_idx, inner_ele in enumerate(ele):
                    req_x_pad[e_idx][c_idx] = inner_ele
            _, loss, train_acc = sess.run([model.train_op, model.loss, model.accuracy],
                                          feed_dict={
                                              model.input: req_x_pad,
                                              model.input_mask: req_mask,
                                              model.y: req_y
                                          })
            step += 1
            if step % 5 == 0:
                print("train, loss :{} acc : {}".format(loss, train_acc))
                req_x, req_mask, req_y = valid_gen.__next__()
                req_x_pad = np.zeros(shape=[64, max_seq_len])
                for e_idx, ele in enumerate(req_x):
                    for c_idx, inner_ele in enumerate(ele):
                        req_x_pad[e_idx][c_idx] = inner_ele
                loss, valid_acc = sess.run([model.loss, model.accuracy],
                                           feed_dict={
                                               model.input: req_x_pad,
                                               model.input_mask: req_mask,
                                               model.y: req_y
                                           })
                print("valid loss: {}, acc: {}".format(loss, valid_acc))
                saver.save(sess, save_path=r"E:\Coding\python\retionalSonnetStudy\model.ckpt")


if __name__ == "__main__":
    simple_alg_seq_test()
```
The results:
```
restore exists
train, loss :0.24382832646369934 acc : 0.953125
valid loss: 0.2279301881790161, acc: 0.984375
train, loss :0.15681743621826172 acc : 0.9921875
valid loss: 0.10683506727218628, acc: 1.0
train, loss :0.09039808064699173 acc : 1.0
valid loss: 0.12752145528793335, acc: 1.0
```
Compared with the recursive functions in TensorFlow Fold, the sheer length of this memory-network implementation invites some teasing, but its memory capability is general-purpose.
The memory network structure above has also been used in a recent reinforcement-learning paper:
Relational Deep Reinforcement Learning
Paper link: https://arxiv.org/abs/1806.01830
Its basic network structure is shown below:
The key relational module is essentially the same as the memory structure in this paper; the overall idea is roughly to extract image features first, then derive the value and policy functions and optimize them with reinforcement learning.
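The feature-to-entity step can be sketched as follows (shapes are hypothetical): each pixel of the conv feature map becomes one row, and the relational module self-attends over those rows just as it does over memory slots (treat_input_as_matrix=True in the code above).

```python
import numpy as np

# Hypothetical shapes: a conv feature map of height H, width W, C channels.
H, W, C = 5, 5, 8
rng = np.random.default_rng(0)
feature_map = rng.normal(size=(H, W, C))

# Flatten the spatial grid so every pixel becomes one "entity" row;
# self-attention then relates every entity to every other entity.
entities = feature_map.reshape(H * W, C)
print(entities.shape)  # (25, 8)
```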
One of its examples applies this to StarCraft II mini-games; sadly, no open-source dataset has appeared yet...