Paper link:
https://arxiv.org/abs/1804.05685
Purpose:
Long-document (discourse-level) abstractive summarization
Brief description:
This paper can essentially be read as a variant of the standard seq2seq summarization model in which the encoding is split along the document's discourse structure (its sections) and the attention information is weighted hierarchically over that structure. Beyond this core idea, the paper adds two further refinements: a CopyNet-style mechanism for copying from the source text (see "Copying from source" in the paper) and an adjustment of the section weights over decoding time (see "Decoder coverage" in the paper).
The example implementation in this post does not include these two refinements.
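To make the two-level attention concrete, here is a minimal NumPy sketch of a single decoder step, mirroring what model_construct below does: a section-level distribution beta rescales the word-level scores before they are normalized into the word-level distribution alpha, and the context vector is the alpha-weighted sum of all word states. The dot-product score is a simplification of the additive score used in the actual code; all shapes here are invented for illustration.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Illustrative shapes: 3 sections, 4 words per section, hidden size 8.
num_sec, sec_len, hidden = 3, 4, 8
sec_states = np.random.randn(num_sec, hidden)            # section-level encoder states
word_states = np.random.randn(num_sec, sec_len, hidden)  # word-level encoder states
dec_state = np.random.randn(hidden)                      # current decoder state

beta = softmax(sec_states @ dec_state)                   # [num_sec], section attention
word_scores = word_states @ dec_state                    # [num_sec, sec_len]
# Word attention over all words of all sections, rescaled by its section's beta.
alpha = softmax((word_scores * beta[:, None]).reshape(-1))
# Context vector: attention-weighted sum over every word of every section.
context = alpha @ word_states.reshape(-1, hidden)        # [hidden]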
Model architecture diagram:
Data description:
The paper uses the arXiv paper dataset (https://github.com/acohan/long-summarization); the body of each paper is the input and its abstract is the target summary.
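For orientation, each line of the raw dataset files is one JSON object; the preprocessing below only consumes the article_text and abstract_text fields, both lists of sentence strings. A schematic record with invented values:

# Invented, schematic example of one line of data/train.txt; only the two
# fields read by process_dataset below are shown.
example_record = {
    "article_text": ["in this paper we study ...", "...one sentence per list element..."],
    "abstract_text": ["we propose a model for ..."],
}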
Data preprocessing:
Because the inputs are long documents (which is exactly the need the paper addresses), preprocessing deserves some care, for example the lemmatization of English words and the runtime cost it incurs. Lemmatization depends on part of speech, so tokenization also has to produce POS tags, and each word is then lemmatized with WordNet according to its tag. To keep the runtime acceptable, the implementation below uses a process pool.
import json
from dl_text import dl
from collections import Counter
from functools import reduce
from nltk.corpus import wordnet
from nltk import word_tokenize, pos_tag
from nltk.stem import WordNetLemmatizer
import re
import multiprocessing


def get_wordnet_pos(treebank_tag):
    # map Penn Treebank tags to the coarse WordNet POS classes
    if treebank_tag.startswith('J'):
        return wordnet.ADJ
    elif treebank_tag.startswith('V'):
        return wordnet.VERB
    elif treebank_tag.startswith('N'):
        return wordnet.NOUN
    elif treebank_tag.startswith('R'):
        return wordnet.ADV
    else:
        return None


# map abstract_text article_text to summary and sec_text
def lower_tokenize(text):
    return dl.tokenize(dl.clean(text.lower()))


def varify_token(token):
    # keep purely alphabetic tokens longer than one character
    return re.match(r"^[a-zA-Z]+$", token) and (len(token) - 1)


lemmatizer = WordNetLemmatizer()


def lemmatize_sentence(sentence):
    sentence = dl.clean(sentence)
    res = []
    for word, pos in pos_tag(word_tokenize(sentence)):
        wordnet_pos = get_wordnet_pos(pos) or wordnet.NOUN
        res.append(lemmatizer.lemmatize(word, pos=wordnet_pos))
    res = list(filter(varify_token, res))
    return res


def process_dataset(filename, map_with_wordnet=True, multi_process=True):
    cnt = Counter()
    if multi_process:
        map_pool = multiprocessing.Pool(processes=7)
    finish_line = 0
    with open(filename + ".json", "w", encoding="utf-8") as o:
        with open(filename, "r", encoding="utf-8") as f:
            while True:
                text_line = f.readline().strip()
                if not text_line:
                    break
                json_obj = json.loads(text_line)
                article_text = json_obj["article_text"]
                abstract_text = json_obj["abstract_text"]
                summary = " ".join(abstract_text)
                sec_list = article_text
                # the summary is appended to the section list so that one
                # pool.map call processes all sentences of the record
                if map_with_wordnet:
                    if multi_process:
                        sec_summary = map_pool.map(lemmatize_sentence, sec_list + [summary])
                    else:
                        sec_summary = list(map(lemmatize_sentence, sec_list + [summary]))
                else:
                    if multi_process:
                        sec_summary = map_pool.map(lower_tokenize, sec_list + [summary])
                    else:
                        sec_summary = list(map(lower_tokenize, sec_list + [summary]))
                cnt.update(reduce(lambda x, y: x + y, sec_summary))
                sec_list, summary = sec_summary[:-1], sec_summary[-1]
                o.write("{}\n".format(json.dumps({
                    "summary": summary,
                    "sec_list": sec_list
                })))
                finish_line += 1
                if finish_line % 100 == 0:
                    print("finish line: {}".format(finish_line))
        # append the word-frequency counter after a blank line;
        # valid_cnt_and_idx reads it back from the last line
        o.write("\n{}\n".format(json.dumps(dict(cnt.items()))))


def tokenize_train_test():
    process_dataset("data/test.txt")
    process_dataset("data/train.txt")


def valid_cnt_and_idx(use_percent=False):
    import mmap
    import numpy as np

    def getlastline(fname):
        # memory-map the file and return its last line (the appended counter)
        with open(fname, "r") as source:
            mapping = mmap.mmap(source.fileno(), 0, access=mmap.ACCESS_READ)
            return mapping[mapping.rfind(b'\n', 0, -1) + 1:]

    lastline_train = getlastline("data/train.txt.json").strip()
    cnt_obj_train = json.loads(lastline_train.decode("utf-8"))
    lastline_test = getlastline("data/test.txt.json").strip()
    cnt_obj_test = json.loads(lastline_test.decode("utf-8"))

    from collections import defaultdict

    def union_items(x, y):
        # merge two (word, count) item lists by summing counts
        req = defaultdict(int)
        for k, v in x + y:
            req[k] += v
        return list(req.items())

    cnt_obj = reduce(union_items, [list(cnt_obj_train.items()), list(cnt_obj_test.items())])
    cnt = Counter(dict(cnt_obj))
    print("before filter num :")
    print(len(cnt))
    if use_percent:
        percent_list = np.percentile(list(cnt.values()), np.arange(0.0, 100.0, 10))
        print("percent_list :")
        print(percent_list)
        print("filter words by num: {}".format(np.mean(percent_list).astype(np.int32)))
        filter_cnt = Counter(dict(filter(
            lambda t2: t2[1] > np.mean(percent_list).astype(np.int32), cnt.items())))
    else:
        filter_cnt = Counter(dict(cnt.most_common(5000)))
    print("after filter num :")
    print(len(filter_cnt))

    # filter file and encode word to id
    word2idx = dict((w, i) for i, w in enumerate(filter_cnt.keys()))

    def filter_file(file_name, summary_len_threshold=50):
        finish_line = 0
        with open(file_name + ".idx", "w") as o:
            with open(file_name, "r") as f:
                while True:
                    line = f.readline().strip()
                    # stops at the blank line before the appended counter object
                    if not line:
                        break
                    json_obj = json.loads(line)
                    # drop out-of-vocabulary tokens, then map words to ids
                    summary = [word2idx[token] for token in json_obj["summary"]
                               if token in word2idx]
                    if len(summary) < summary_len_threshold:
                        continue
                    sec_list = [[word2idx[token] for token in inner_list
                                 if token in word2idx]
                                for inner_list in json_obj["sec_list"]]
                    o.write("{}\n".format(json.dumps({
                        "summary": summary,
                        "sec_list": sec_list
                    })))
                    finish_line += 1
                    if finish_line % 100 == 0:
                        print("finish line: {}".format(finish_line))

    filter_file("data/test.txt.json")
    filter_file("data/train.txt.json")
    with open("word2idx.json", "w") as f:
        json.dump({"word2idx": word2idx}, f)


if __name__ == "__main__":
    tokenize_train_test()
    valid_cnt_and_idx()
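Before running the whole pipeline, it may be worth sanity-checking the POS-aware lemmatization path on a single sentence. This sketch assumes the script above is in scope (for get_wordnet_pos and lemmatizer) and that the required nltk data packages have been downloaded:

import nltk
# One-time downloads used by word_tokenize, pos_tag and WordNetLemmatizer.
for pkg in ("punkt", "averaged_perceptron_tagger", "wordnet"):
    nltk.download(pkg)

sentence = "the models were trained on longer documents"
for word, pos in pos_tag(word_tokenize(sentence)):
    wordnet_pos = get_wordnet_pos(pos) or wordnet.NOUN  # same fallback as lemmatize_sentence
    print(word, "->", lemmatizer.lemmatize(word, pos=wordnet_pos))
# Expected output along the lines of: models -> model, were -> be,
# trained -> train, longer -> long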
Data loading and model construction:
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np
import json


def data_generator(file_name, batch_num=10, input_max_sec_num=5,
                   input_max_seq_len=500, output_max_seq_len=50):
    start_idx = 0
    # 5000 is the padding id (one past the largest word index)
    input = np.full(shape=[batch_num, input_max_sec_num, input_max_seq_len], fill_value=5000)
    input_seq_len = np.full(shape=[batch_num, input_max_sec_num], fill_value=0)
    output = np.full(shape=[batch_num, output_max_seq_len], fill_value=5000)
    output_seq_len = np.full(shape=[batch_num], fill_value=0)
    with open(file_name, "r") as f:
        while True:
            line = f.readline().strip()
            if not line:
                break
            json_obj = json.loads(line)
            summary = json_obj["summary"]
            sec_list = json_obj["sec_list"]
            for idx, ele in enumerate(summary[:output_max_seq_len]):
                output[start_idx][idx] = int(ele)
            # note: lengths are stored unclipped even though the sequences
            # themselves are truncated to the max lengths above
            output_seq_len[start_idx] = len(summary)
            for sec_idx, sec in enumerate(sec_list[:input_max_sec_num]):
                for seq_idx, ele in enumerate(sec[:input_max_seq_len]):
                    input[start_idx][sec_idx][seq_idx] = int(ele)
                input_seq_len[start_idx][sec_idx] = len(sec)
            start_idx += 1
            if start_idx == batch_num:
                yield (input.astype(np.int32), input_seq_len.astype(np.int32),
                       output.astype(np.int32), output_seq_len.astype(np.int32))
                start_idx = 0
                input = np.full(shape=[batch_num, input_max_sec_num, input_max_seq_len], fill_value=5000)
                input_seq_len = np.full(shape=[batch_num, input_max_sec_num], fill_value=0)
                output = np.full(shape=[batch_num, output_max_seq_len], fill_value=5000)
                output_seq_len = np.full(shape=[batch_num], fill_value=0)


class AAS(object):
    def __init__(self, word_embedding_dim=30, word_size=5000 + 1, hidden_state_dim=10,
                 v_dim=10, combine_state_dim=100, batch_num=10, input_max_sec_num=5,
                 input_max_seq_len=500, output_max_seq_len=50, voc_dim=5000 + 1):
        self.input = tf.placeholder(dtype=tf.int32, shape=[None, input_max_sec_num, input_max_seq_len])
        self.input_seq_len = tf.placeholder(dtype=tf.int32, shape=[None, input_max_sec_num])
        self.output = tf.placeholder(dtype=tf.int32, shape=[None, output_max_seq_len])
        self.output_seq_len = tf.placeholder(dtype=tf.int32, shape=[None])

        self.word_embedding_dim = word_embedding_dim
        self.word_size = word_size
        self.hidden_state_dim = hidden_state_dim
        self.v_dim = v_dim
        self.combine_state_dim = combine_state_dim
        self.batch_num = batch_num
        self.input_max_sec_num = input_max_sec_num
        self.input_max_seq_len = input_max_seq_len
        self.output_max_seq_len = output_max_seq_len
        self.voc_dim = voc_dim

        with tf.device('/cpu:0'), tf.name_scope("word_embedding"):
            # this layer maybe replaced by w2v or glove in the future
            self.word_W = tf.Variable(
                tf.random_uniform([self.word_size, self.word_embedding_dim], -1.0, 1.0),
                name="word_W")

        self.sec_batch_output = None
        self.sec_rnn_outputs = None
        self.doc_batch_output = None
        self.batch_pyt = None
        self.seq_loss = None
        self.output_states = None

        self.opt_construct()

    def word_embed_layer(self, input_word):
        with tf.name_scope("word_embedding"):
            embedded_word = tf.nn.embedding_lookup(self.word_W, input_word)
            return embedded_word

    # presume the output with max_len
    def bilstm_layer(self, input, seq_length, initial_state_fw=None, initial_state_bw=None):
        if initial_state_fw is None and initial_state_bw is None:
            fw_cell = rnn.BasicLSTMCell(self.hidden_state_dim, forget_bias=1., state_is_tuple=True)
            bw_cell = rnn.BasicLSTMCell(self.hidden_state_dim, forget_bias=1., state_is_tuple=True)
            rnn_outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, input, scope='bi-lstm-first',
                sequence_length=tf.cast(seq_length, tf.int32),
                dtype=tf.float32)
        else:
            fw_cell = rnn.BasicLSTMCell(int(initial_state_fw.h.get_shape()[-1]),
                                        forget_bias=1., state_is_tuple=True)
            bw_cell = rnn.BasicLSTMCell(int(initial_state_bw.h.get_shape()[-1]),
                                        forget_bias=1., state_is_tuple=True)
            rnn_outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, input, scope='bi-lstm-second',
                initial_state_fw=initial_state_fw,
                initial_state_bw=initial_state_bw,
                sequence_length=tf.cast(seq_length, tf.int32),
                dtype=tf.float32)
        # keep the raw LSTMStateTuples; the latest encoder states seed the decoder
        self.output_states = output_states
        output_states = [output_states[0].h, output_states[1].h]
        return (tf.concat(rnn_outputs, axis=-1, name="rnn_outputs"),
                tf.concat(output_states, axis=-1, name="output_states"))

    # input_state [batch, state_dim]
    # return [batch, self.combine_state_dim]
    def combine_state(self, input_state, name_index=1):
        input_state_dim = int(input_state.get_shape()[-1])
        w = tf.get_variable(
            "w_{}".format(name_index),
            shape=[input_state_dim, self.combine_state_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable("b_{}".format(name_index), shape=[self.combine_state_dim],
                            initializer=tf.constant_initializer(1.0))
        return tf.nn.relu(tf.nn.xw_plus_b(input_state, w, b), name="combine_state")

    def encoder_layer(self, input, input_len):
        embedding_dim = int(input.get_shape()[-1])
        max_seq_len = int(input.get_shape()[-2])
        max_sec_len = int(input.get_shape()[-3])
        # flatten (batch, sec) so every section runs through the word-level BiLSTM
        sec_input = tf.reshape(input, shape=[-1, max_seq_len, embedding_dim])
        sec_len_input = tf.reshape(input_len, shape=[-1])
        with tf.name_scope("sec_bilstm_layer"):
            with tf.variable_scope("combine_scope") as scope:
                sec_rnn_outputs, sec_output_states = self.bilstm_layer(sec_input, sec_len_input)
                sec_output_last_dim = int(sec_output_states.get_shape()[-1])
                sec_rnn_outputs = self.combine_state(
                    tf.reshape(sec_rnn_outputs, [-1, sec_output_last_dim]), name_index=1)
                scope.reuse_variables()
                sec_output_states = self.combine_state(sec_output_states, name_index=1)
                sec_output_last_dim = int(sec_output_states.get_shape()[-1])
        # [batch, sec, max_seq, hidden] -- word-level states
        self.sec_rnn_outputs = tf.reshape(
            sec_rnn_outputs, shape=[-1, max_sec_len, max_seq_len, sec_output_last_dim])
        # [batch, sec, hidden] -- one state per section
        self.sec_batch_output = tf.reshape(
            sec_output_states, shape=[-1, max_sec_len, sec_output_last_dim])
        # number of non-empty sections per document
        doc_len_input = tf.reshape(
            tf.reduce_sum(tf.cast(input_len > 0, tf.float32), axis=-1), shape=[-1])
        with tf.name_scope("doc_bilstm_layer"):
            _, doc_output_states = self.bilstm_layer(self.sec_batch_output, doc_len_input)
            doc_output_states = self.combine_state(doc_output_states, name_index=2)
            doc_output_last_dim = int(doc_output_states.get_shape()[-1])
            # [batch, hidden]
            self.doc_batch_output = tf.reshape(doc_output_states, shape=[-1, doc_output_last_dim])

    # must init batch num without -1
    def score(self, batch_h_left, batch_h_right, name_idx=1):
        # additive attention score v_a^T tanh(W1 h_left + W2 h_right + b),
        # computed for every (right, left) pair
        # batch_h_left shape [batch_left, left_dim]
        # batch_h_right shape [batch_right, right_dim]
        left_dim = int(batch_h_left.get_shape()[-1])
        right_dim = int(batch_h_right.get_shape()[-1])
        left_batch_num = int(batch_h_left.get_shape()[0])
        right_batch_num = int(batch_h_right.get_shape()[0])
        w1 = tf.get_variable(
            "w1_{}".format(name_idx), shape=[self.v_dim, left_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        w2 = tf.get_variable(
            "w2_{}".format(name_idx), shape=[self.v_dim, right_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable("b_{}".format(name_idx), shape=[self.v_dim],
                            initializer=tf.constant_initializer(1.0))
        left_part = tf.matmul(w1, tf.transpose(batch_h_left, [1, 0]))
        left_part = tf.tile(left_part, [right_batch_num, 1])
        right_part = tf.matmul(w2, tf.transpose(batch_h_right, [1, 0]))
        right_part_list = tf.unstack(right_part, axis=1)
        right_part = tf.expand_dims(tf.concat(right_part_list, axis=-1), axis=-1)
        right_part = tf.tile(right_part, [1, left_batch_num])
        b_part = tf.tile(tf.expand_dims(b, axis=-1), [right_batch_num, left_batch_num])
        assert left_part.get_shape() == right_part.get_shape() == b_part.get_shape()
        linear = left_part + right_part + b_part
        va = tf.get_variable(
            "va_{}".format(name_idx), shape=[self.v_dim, 1],
            initializer=tf.constant_initializer(1.0))
        va_part = tf.tile(va, [right_batch_num, left_batch_num])
        score_before_reduce_sum = tf.multiply(va_part, tf.nn.tanh(linear))
        # collapse each v_dim block into one scalar score
        score_list = []
        for i in range(0, int(score_before_reduce_sum.get_shape()[0]), self.v_dim):
            score_list.append(tf.reduce_sum(
                tf.slice(score_before_reduce_sum, [i, 0], [self.v_dim, -1]),
                axis=0, keep_dims=True))
        return tf.concat(score_list, axis=0, name="score_matrix")

    def decoder_layer(self, input, input_len):
        # the document-level rnn's last hidden state is used as the
        # initial state of the decoder
        max_seq_len = int(input.get_shape()[-2])
        with tf.name_scope("output_bilstm_layer"):
            batch_outputs, batch_output_states = self.bilstm_layer(
                input, input_len,
                initial_state_fw=self.output_states[0],
                initial_state_bw=self.output_states[1])
            batch_outputs_dim = int(batch_outputs.get_shape()[-1])
            batch_outputs = self.combine_state(
                tf.reshape(batch_outputs, [-1, batch_outputs_dim]), name_index=3)
            # [batch, max_seq_len, batch_outputs_dim]
            self.batch_outputs = tf.reshape(batch_outputs, shape=[-1, max_seq_len, batch_outputs_dim])

    def tile_beta_matrix(self, beta_matrix, max_seq_len):
        # repeat each section weight max_seq_len times so that it lines up
        # with the flattened word-level score matrix
        time_dim = int(beta_matrix.get_shape()[0])
        beta_list = tf.unstack(beta_matrix, axis=-1)
        beta_single = tf.expand_dims(tf.concat(beta_list, axis=-1), axis=-1)
        beta_tiled = tf.tile(beta_single, [1, max_seq_len])
        beta_tiled_list = []
        for i in range(0, int(beta_tiled.get_shape()[0]), time_dim):
            beta_tiled_list.append(tf.slice(beta_tiled, [i, 0], [time_dim, -1]))
        return tf.concat(beta_tiled_list, axis=-1)

    def model_construct(self):
        with tf.variable_scope("word_embedding") as scope:
            embedding_input = self.word_embed_layer(self.input)
            scope.reuse_variables()
            embedding_output = self.word_embed_layer(self.output)
        # construct encoder decoder with property setting
        self.encoder_layer(embedding_input, self.input_seq_len)
        self.decoder_layer(embedding_output, self.output_seq_len)
        # self.sec_batch_output  [batch, sec, hidden]
        # self.batch_outputs     [batch, max_seq_len, hidden]
        # self.sec_rnn_outputs   [batch, sec, max_seq, hidden]
        pyt_list = []
        with tf.variable_scope("score_scope") as scope:
            for i in range(0, self.batch_num):
                sec_hidden = tf.squeeze(tf.slice(self.sec_batch_output, [i, 0, 0], [1, -1, -1]))
                output_hidden = tf.squeeze(tf.slice(self.batch_outputs, [i, 0, 0], [1, -1, -1]))
                # section attention, [max_seq_len, sec] = [t_rows, sec_cols]
                beta_score_matrix = tf.nn.softmax(
                    self.score(sec_hidden, output_hidden, name_idx=1), dim=-1)
                sec_seq_hidden = tf.squeeze(
                    tf.slice(self.sec_rnn_outputs, [i, 0, 0, 0], [1, -1, -1, -1]))
                hidden_dim = int(sec_seq_hidden.get_shape()[-1])
                dict_hidden = tf.reshape(sec_seq_hidden, [-1, hidden_dim])
                dict_score_matrix = self.score(dict_hidden, output_hidden, name_idx=2)
                max_seq_len = int(self.sec_rnn_outputs.get_shape()[2])
                beta_matrix_tiled = self.tile_beta_matrix(beta_score_matrix, max_seq_len)
                # assert shape equal
                assert dict_score_matrix.get_shape() == beta_matrix_tiled.get_shape()
                # word attention: word scores rescaled by their section weights
                alpha_ori_score_matrix = tf.nn.softmax(
                    tf.multiply(dict_score_matrix, beta_matrix_tiled), dim=-1)
                # context vectors, [t, hidden]
                c_matrix = tf.matmul(alpha_ori_score_matrix, dict_hidden)
                h_matrix = output_hidden
                c_hidden_dim = int(c_matrix.get_shape()[-1])
                h_hidden_dim = int(h_matrix.get_shape()[-1])
                lin_w1 = tf.get_variable(
                    "lin_w1", shape=[self.voc_dim, h_hidden_dim],
                    initializer=tf.contrib.layers.xavier_initializer())
                lin_w2 = tf.get_variable(
                    "lin_w2", shape=[self.voc_dim, c_hidden_dim],
                    initializer=tf.contrib.layers.xavier_initializer())
                lin_b = tf.get_variable("lin_b", shape=[int(c_matrix.get_shape()[0])],
                                        initializer=tf.constant_initializer(1.0))
                # [voc_dim, t]
                linear = tf.nn.bias_add(
                    tf.matmul(lin_w1, tf.transpose(h_matrix, [1, 0])) +
                    tf.matmul(lin_w2, tf.transpose(c_matrix, [1, 0])),
                    lin_b)
                V = tf.get_variable(
                    "V", shape=[self.voc_dim, self.voc_dim],
                    initializer=tf.contrib.layers.xavier_initializer())
                # vocabulary logits per step, [t, voc_dim]
                pyt = tf.transpose(tf.matmul(V, linear), [1, 0])
                pyt_list.append(tf.expand_dims(pyt, dim=0))
                scope.reuse_variables()
        # [batch, t, voc_dim]
        batch_pyt = tf.concat(pyt_list, axis=0)
        return batch_pyt

    def opt_construct(self):
        self.batch_pyt = self.model_construct()
        logits = self.batch_pyt
        targets = tf.cast(self.output, tf.int32)
        output_seq_len_mask = self.output_seq_len + 1
        output_max_len = int(self.output.get_shape()[-1])
        # 1 - cumsum(one_hot(len + 1)) yields a 1/0 prefix mask over time
        # steps, so padded positions do not contribute to the loss
        weights = tf.cast(
            tf.slice(1 - tf.cumsum(tf.one_hot(output_seq_len_mask, depth=output_max_len + 1), axis=-1),
                     [0, 0], [-1, output_max_len]),
            tf.float32)
        self.seq_loss = tf.contrib.seq2seq.sequence_loss(logits=logits, targets=targets, weights=weights)
        self.pred = tf.cast(tf.argmax(tf.nn.softmax(self.batch_pyt, dim=-1), axis=-1), dtype=tf.int32)
        self.accuracy = tf.reduce_mean(tf.cast(tf.equal(self.pred, self.output), tf.float32))
        self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.seq_loss)

    @staticmethod
    def train():
        train_gen = data_generator("data/train.txt.json.idx")
        test_gen = data_generator("data/test.txt.json.idx")
        aas_ext = AAS()
        num_epochs = 100
        now_epoch = 0
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print("model construct end")
            for i in range(int(1e20)):
                try:
                    input, input_seq_len, output, output_seq_len = train_gen.__next__()
                except StopIteration:
                    # generator exhausted: one epoch is finished, restart it
                    print("epoch {} end".format(now_epoch))
                    now_epoch += 1
                    if now_epoch == num_epochs:
                        return
                    train_gen = data_generator("data/train.txt.json.idx")
                    input, input_seq_len, output, output_seq_len = train_gen.__next__()
                _, loss = sess.run(
                    [aas_ext.train_op, aas_ext.seq_loss],
                    feed_dict={
                        aas_ext.input: input,
                        aas_ext.input_seq_len: input_seq_len,
                        aas_ext.output: output,
                        aas_ext.output_seq_len: output_seq_len
                    })
                if i % 100 == 0:
                    print("train loss: {}".format(loss))
                if i % 1000 == 0:
                    try:
                        input, input_seq_len, output, output_seq_len = test_gen.__next__()
                    except StopIteration:
                        test_gen = data_generator("data/test.txt.json.idx")
                        input, input_seq_len, output, output_seq_len = test_gen.__next__()
                    loss = sess.run(
                        aas_ext.seq_loss,
                        feed_dict={
                            aas_ext.input: input,
                            aas_ext.input_seq_len: input_seq_len,
                            aas_ext.output: output,
                            aas_ext.output_seq_len: output_seq_len
                        })
                    print("test loss: {}".format(loss))


if __name__ == "__main__":
    AAS.train()
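The least obvious line above is the weights construction in opt_construct. A minimal NumPy re-enactment (my sketch, not part of the original code) shows what 1 - cumsum(one_hot(len + 1)) produces: a prefix mask that keeps positions 0..len and zeroes out the padded tail before tf.contrib.seq2seq.sequence_loss averages over time steps.

import numpy as np

def length_mask(lengths, max_len):
    # one_hot marks position len + 1; cumsum sets everything from there on
    # to 1; "1 -" flips that into ones on the prefix and zeros on the tail.
    # (Assumes lengths + 1 <= max_len; tf.one_hot would simply produce an
    # all-ones mask for longer sequences instead of raising.)
    one_hot = np.eye(max_len + 1)[lengths + 1]   # [batch, max_len + 1]
    mask = 1.0 - np.cumsum(one_hot, axis=-1)     # [batch, max_len + 1]
    return mask[:, :max_len]                     # slice back to [batch, max_len]

print(length_mask(np.array([2, 4]), 6))
# [[1. 1. 1. 0. 0. 0.]
#  [1. 1. 1. 1. 1. 0.]]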