A Discourse-Aware Attention Model for Abstractive Summarization of Long Documents: Paper Implementation

Paper link:

https://arxiv.org/abs/1804.05685

 

Paper objective:

Discourse-aware abstractive summarization of long documents.

 

Brief overview:

This paper can essentially be read as a variant of the standard seq2seq summarization model in which the encoder splits the document into its discourse sections and the attention weights are computed hierarchically over those sections. Beyond this core idea, the original paper adds two further refinements: a CopyNet-like mechanism for copying tokens from the source (see "Copying from source" in the paper), and an adjustment of the section attention weights during decoding (see "Decoder coverage" in the paper).

The example implementation in this post does not include these two refinements.
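A rough sketch of the discourse-aware attention, with notation paraphrased and simplified from the paper (j indexes sections, i indexes words within a section, t the decoder step); see the paper for the exact formulation:

\beta_j^{(t)} = \mathrm{softmax}_j\big(\mathrm{score}(h_j^{\mathrm{sec}}, h_t^{\mathrm{dec}})\big)

\alpha_{j,i}^{(t)} \propto \beta_j^{(t)} \exp\big(\mathrm{score}(h_{j,i}^{\mathrm{word}}, h_t^{\mathrm{dec}})\big) \quad \text{(normalized over all } (j, i)\text{)}

c_t = \sum_j \sum_i \alpha_{j,i}^{(t)} \, h_{j,i}^{\mathrm{word}}

\mathrm{score}(h, h') = v^{\top} \tanh(W_1 h + W_2 h' + b)

In the implementation below these correspond to beta_score_matrix, alpha_ori_score_matrix, c_matrix, and the score method, respectively.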

 

Model architecture diagram:


Dataset:

The original paper uses the arXiv papers dataset (https://github.com/acohan/long-summarization): the body of each paper is the input and its abstract is the target summary.
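Each line of the raw dataset files is a JSON object describing one paper; the preprocessing code below only reads the article_text and abstract_text fields (both lists of sentences). A made-up, heavily truncated record for illustration (the real files contain further fields such as section information):

{"article_text": ["in this paper we study neural abstractive summarization of long documents .", "the remainder of this paper is organized as follows ."],
 "abstract_text": ["we propose a discourse - aware summarization model .", "experiments on arxiv papers show improvements ."]}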

 

Data preprocessing:

Because the input documents are very long (which is exactly the setting the paper addresses), preprocessing deserves some care, for example the lemmatization of English word forms and the performance cost it incurs. Lemmatization needs part-of-speech information, so tokenization must also produce POS tags; each Treebank tag is mapped to a WordNet POS, and the WordNet lemmatizer is applied to the word. Given the performance requirements, the implementation below uses a process pool.

import json
from dl_text import dl
from collections import Counter

from functools import reduce
from nltk.corpus import wordnet
from nltk import word_tokenize, pos_tag
from nltk.stem import WordNetLemmatizer
import re

import multiprocessing

def get_wordnet_pos(treebank_tag):
    if treebank_tag.startswith('J'):
        return wordnet.ADJ
    elif treebank_tag.startswith('V'):
        return wordnet.VERB
    elif treebank_tag.startswith('N'):
        return wordnet.NOUN
    elif treebank_tag.startswith('R'):
        return wordnet.ADV
    else:
        return None

# map article_text / abstract_text to sec_list and summary
def lower_tokenize(text):
    return dl.tokenize(dl.clean(text.lower()))

def verify_token(token):
    # keep purely alphabetic tokens longer than one character
    return re.match(r"^[a-zA-Z]+$", token) and (len(token) - 1)

lemmatizer = WordNetLemmatizer()
def lemmatize_sentence(sentence):
    sentence = dl.clean(sentence)
    res = []
    for word, pos in pos_tag(word_tokenize(sentence)):
        wordnet_pos = get_wordnet_pos(pos) or wordnet.NOUN
        res.append(lemmatizer.lemmatize(word, pos=wordnet_pos))

    res = list(filter(verify_token, res))
    return res


def process_dataset(filename, map_with_wordnet = True, multi_process = True):
    cnt = Counter()

    if multi_process:
        map_pool = multiprocessing.Pool(processes=7)

    finish_line = 0
    with open(filename + ".json", "w", encoding="utf-8") as o:
        with open(filename, "r", encoding="utf-8") as f:
            while True:
                text_line = f.readline().strip()
                if not text_line:
                    break

                json_obj = json.loads(text_line)
                article_text = json_obj["article_text"]
                abstract_text = json_obj["abstract_text"]

                summary = " ".join(abstract_text)
                sec_list = article_text

                if map_with_wordnet:
                    if multi_process:
                        sec_summary = map_pool.map(lemmatize_sentence, sec_list + [summary])
                    else:
                        sec_summary = list(map(lemmatize_sentence, sec_list + [summary]))
                else:
                    if multi_process:
                        sec_summary = map_pool.map(lower_tokenize, sec_list + [summary])
                    else:
                        sec_summary = list(map(lower_tokenize, sec_list + [summary]))

                cnt.update(reduce(lambda x, y: x + y, sec_summary))
                sec_list, summary = sec_summary[:-1], sec_summary[-1]

                o.write("{}\n".format(json.dumps({
                    "summary": summary,
                    "sec_list": sec_list
                })))

                finish_line += 1

                if finish_line % 100 == 0:
                    print("finish line: {}".format(finish_line))

        o.write("\n{}\n".format(json.dumps(dict(cnt.items()))))

def tokenize_train_test():
    process_dataset("data/test.txt")
    process_dataset("data/train.txt")

def valid_cnt_and_idx(use_percent = False):
    import mmap
    import numpy as np

    def getlastline(fname):
        with open(fname, "r") as source:
            mapping = mmap.mmap(source.fileno(), 0, access=mmap.ACCESS_READ)
        return mapping[mapping.rfind(b'\n', 0, -1)+1:]

    lastline_train = getlastline("data/train.txt.json").strip()
    cnt_obj_train = json.loads(lastline_train.decode("utf-8"))

    lastline_test = getlastline("data/test.txt.json").strip()
    cnt_obj_test = json.loads(lastline_test.decode("utf-8"))

    from collections import defaultdict
    def union_items(x, y):
        req = defaultdict(int)
        for k, v in x + y:
            req[k] += v
        return list(req.items())

    cnt_obj = reduce(union_items, [list(cnt_obj_train.items()), list(cnt_obj_test.items())])

    cnt = Counter(dict(cnt_obj))
    print("before filter num :")
    print(len(cnt))
    if use_percent:
        percent_list = np.percentile(list(cnt.values()), np.arange(0.0, 100.0, 10))
        print("percent_list :")
        print(percent_list)
        print("filter words by num: {}".format(np.mean(percent_list).astype(np.int32)))
        filter_cnt = Counter(dict(filter(lambda t2: t2[1] > np.mean(percent_list).astype(np.int32) , cnt.items())))
    else:
        filter_cnt = Counter(dict(cnt.most_common(5000)))
    print("after filter num :")
    print(len(filter_cnt))

    # filter file and encoding word to id
    word2idx = dict((w, i) for i, w in enumerate(filter_cnt.keys()))

    def filter_file(file_name, summary_len_threshold = 50):
        finish_line = 0
        with open(file_name + ".idx", "w") as o:
            with open(file_name, "r") as f:
                while True:
                    line = f.readline().strip()
                    if not line:
                        break
                    json_obj = json.loads(line)
                    summary = [word2idx[token] for token in json_obj["summary"] if token in word2idx]
                    if len(summary) < summary_len_threshold:
                        continue

                    sec_list = [[word2idx[token] for token in inner_list if token in word2idx]
                                for inner_list in json_obj["sec_list"]]
                    o.write("{}\n".format(json.dumps({
                        "summary": summary,
                        "sec_list": sec_list
                    })))

                    finish_line += 1
                    if finish_line % 100 == 0:
                        print("finish line: {}".format(finish_line))

    filter_file("data/test.txt.json")
    filter_file("data/train.txt.json")
    with open("word2idx.json", "w") as f:
        json.dump({"word2idx": word2idx}, f)



if __name__ == "__main__":
    tokenize_train_test()
    valid_cnt_and_idx()
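
After preprocessing, data/train.txt.json and data/test.txt.json contain one tokenized record per line plus a final line of token counts, the corresponding .idx files contain the id-encoded records used for training, and the vocabulary is saved to word2idx.json. A small optional sanity check (assuming the paths used above) that decodes the first id-encoded record back into tokens:

import json

# load the vocabulary written by valid_cnt_and_idx() and invert it
with open("word2idx.json", "r") as f:
    word2idx = json.load(f)["word2idx"]
idx2word = {idx: word for word, idx in word2idx.items()}

# decode the first id-encoded training record back into tokens
with open("data/train.txt.json.idx", "r") as f:
    record = json.loads(f.readline())

print("summary tokens:", [idx2word[i] for i in record["summary"]][:20])
print("number of sections:", len(record["sec_list"]))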

 

Data loading and model construction:

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np
import json

def data_generator(file_name, batch_num = 10,
                   input_max_sec_num = 5, input_max_seq_len = 500, output_max_seq_len = 50):
    start_idx = 0
    input = np.full(shape=[batch_num, input_max_sec_num, input_max_seq_len], fill_value=5000)
    input_seq_len = np.full(shape=[batch_num, input_max_sec_num], fill_value=0)

    output = np.full(shape=[batch_num, output_max_seq_len], fill_value=5000)
    output_seq_len = np.full(shape=[batch_num], fill_value=0)

    with open(file_name, "r") as f:
        while True:
            line = f.readline().strip()
            if not line:
                break
            json_obj = json.loads(line)

            summary = json_obj["summary"]
            sec_list = json_obj["sec_list"]

            for idx, ele in enumerate(summary[:output_max_seq_len]):
                output[start_idx][idx] = int(ele)
            # store the (possibly truncated) summary length for the loss mask
            output_seq_len[start_idx] = min(len(summary), output_max_seq_len)

            for sec_idx, sec in enumerate(sec_list[:input_max_sec_num]):
                for seq_idx, ele in enumerate(sec[:input_max_seq_len]):
                    input[start_idx][sec_idx][seq_idx] = int(ele)
                input_seq_len[start_idx][sec_idx] = min(len(sec), input_max_seq_len)

            start_idx += 1
            if start_idx == batch_num:
                yield input.astype(np.int32), input_seq_len.astype(np.int32), output.astype(np.int32), output_seq_len.astype(np.int32)

                start_idx = 0
                input = np.full(shape=[batch_num, input_max_sec_num, input_max_seq_len], fill_value=5000)
                input_seq_len = np.full(shape=[batch_num, input_max_sec_num], fill_value=0)

                output = np.full(shape=[batch_num, output_max_seq_len], fill_value=5000)
                output_seq_len = np.full(shape=[batch_num], fill_value=0)
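
# With the default arguments, each batch yielded by data_generator has shapes:
#   input          -> (batch_num, input_max_sec_num, input_max_seq_len) = (10, 5, 500)
#   input_seq_len  -> (batch_num, input_max_sec_num)                    = (10, 5)
#   output         -> (batch_num, output_max_seq_len)                   = (10, 50)
#   output_seq_len -> (batch_num,)                                      = (10,)
# Padding positions are filled with id 5000, which matches word_size = 5000 + 1 below.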


class AAS(object):
    def __init__(self, word_embedding_dim = 30, word_size = 5000 + 1, hidden_state_dim = 10,
                 v_dim = 10, combine_state_dim = 100, batch_num = 10,
                 input_max_sec_num = 5, input_max_seq_len = 500, output_max_seq_len = 50 ,voc_dim = 5000 + 1):
        self.input = tf.placeholder(dtype=tf.int32, shape=[None, input_max_sec_num, input_max_seq_len])
        self.input_seq_len = tf.placeholder(dtype=tf.int32, shape=[None, input_max_sec_num])

        self.output = tf.placeholder(dtype=tf.int32, shape=[None, output_max_seq_len])
        self.output_seq_len = tf.placeholder(dtype=tf.int32, shape=[None])

        self.word_embedding_dim = word_embedding_dim
        self.word_size = word_size
        self.hidden_state_dim = hidden_state_dim
        self.v_dim = v_dim
        self.combine_state_dim = combine_state_dim
        self.batch_num = batch_num

        self.input_max_sec_num = input_max_sec_num
        self.input_max_seq_len = input_max_seq_len
        self.output_max_seq_len = output_max_seq_len

        self.voc_dim = voc_dim

        with tf.device('/cpu:0'), tf.name_scope("word_embedding"):
            # this embedding could later be replaced by pre-trained w2v or GloVe vectors
            self.word_W = tf.Variable(
                tf.random_uniform([self.word_size, self.word_embedding_dim], -1.0, 1.0),
                name="word_W")

        self.sec_batch_output = None
        self.sec_rnn_outputs = None
        self.doc_batch_output = None
        self.batch_pyt = None
        self.seq_loss = None
        self.output_states = None

        self.opt_construct()


    def word_embed_layer(self, input_word):
        with tf.name_scope("word_embedding"):
            embedded_word = tf.nn.embedding_lookup(self.word_W, input_word)
            return embedded_word

    # assumes inputs are padded to the maximum sequence length
    def bilstm_layer(self, input, seq_length, initial_state_fw = None, initial_state_bw = None):
        if initial_state_fw is None and initial_state_bw is None:
            fw_cell = rnn.BasicLSTMCell(self.hidden_state_dim, forget_bias=1., state_is_tuple=True)
            bw_cell = rnn.BasicLSTMCell(self.hidden_state_dim, forget_bias=1., state_is_tuple=True)

            rnn_outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, input, scope='bi-lstm-first', sequence_length=tf.cast(seq_length, tf.int32),
                dtype=tf.float32)
        else:
            fw_cell = rnn.BasicLSTMCell(int(initial_state_fw.h.get_shape()[-1]), forget_bias=1., state_is_tuple=True)
            bw_cell = rnn.BasicLSTMCell(int(initial_state_bw.h.get_shape()[-1]), forget_bias=1., state_is_tuple=True)

            rnn_outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, input, scope='bi-lstm-second', initial_state_fw=initial_state_fw,
                initial_state_bw=initial_state_bw, sequence_length=tf.cast(seq_length, tf.int32),
                dtype=tf.float32)

        self.output_states = output_states
        output_states = [output_states[0].h, output_states[1].h]

        return tf.concat(rnn_outputs, axis=-1, name="rnn_outputs") ,tf.concat(output_states, axis=-1, name="output_states")

    # input_state [batch, state_dim]
    # return [batch, self.combine_state_dim]
    def combine_state(self, input_state, name_index = 1):
        input_state_dim = int(input_state.get_shape()[-1])

        w = tf.get_variable(
            "w_{}".format(name_index),
            shape=[input_state_dim, self.combine_state_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable("b_{}".format(name_index),
                            shape=[self.combine_state_dim], initializer=tf.constant_initializer(1.0))

        return tf.nn.relu(tf.nn.xw_plus_b(input_state, w, b), name="combine_state")

    def encoder_layer(self, input, input_len):
        embedding_dim = int(input.get_shape()[-1])
        max_seq_len = int(input.get_shape()[-2])
        max_sec_len = int(input.get_shape()[-3])
        sec_input = tf.reshape(input, shape=[-1, max_seq_len, embedding_dim])
        sec_len_input = tf.reshape(input_len, shape=[-1])
        with tf.name_scope("sec_bilstm_layer"):
            with tf.variable_scope("combine_scope") as scope:
                sec_rnn_outputs ,sec_output_states = self.bilstm_layer(sec_input, sec_len_input)
                sec_output_last_dim = int(sec_output_states.get_shape()[-1])
                sec_rnn_outputs = self.combine_state(tf.reshape(sec_rnn_outputs, [-1, sec_output_last_dim]), name_index=1)
                scope.reuse_variables()
                sec_output_states = self.combine_state(sec_output_states, name_index=1)
                sec_output_last_dim = int(sec_output_states.get_shape()[-1])

            # [batch, sec, max_seq, hidden]
            self.sec_rnn_outputs = tf.reshape(sec_rnn_outputs, shape=[-1, max_sec_len, max_seq_len, sec_output_last_dim])
            # [batch, sec, hidden]
            self.sec_batch_output = tf.reshape(sec_output_states, shape=[-1, max_sec_len, sec_output_last_dim])

        doc_len_input = tf.reshape(tf.reduce_sum(tf.cast(input_len > 0, tf.float32), axis=-1), shape=[-1])
        with tf.name_scope("doc_bilstm_layer"):
            _ ,doc_output_states = self.bilstm_layer(self.sec_batch_output, doc_len_input)
            doc_output_states = self.combine_state(doc_output_states, name_index=2)
            doc_output_last_dim = int(doc_output_states.get_shape()[-1])

            # [batch, hidden]
            self.doc_batch_output = tf.reshape(doc_output_states, shape=[-1, doc_output_last_dim])

    # note: the batch dimensions must be statically known (no -1) because of the tiling below
    def score(self, batch_h_left, batch_h_right, name_idx = 1):
        # batch_h_left shape [batch_left, left_dim]
        # batch_h_right shape [batch_right, right_dim]

        left_dim = int(batch_h_left.get_shape()[-1])
        right_dim = int(batch_h_right.get_shape()[-1])
        left_batch_num = int(batch_h_left.get_shape()[0])
        right_batch_num = int(batch_h_right.get_shape()[0])

        w1 = tf.get_variable(
            "w1_{}".format(name_idx),
            shape=[self.v_dim, left_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        w2 = tf.get_variable(
            "w2_{}".format(name_idx),
            shape=[self.v_dim, right_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable("b_{}".format(name_idx),
                            shape=[self.v_dim], initializer=tf.constant_initializer(1.0))

        left_part = tf.matmul(w1, tf.transpose(batch_h_left, [1, 0]))
        left_part = tf.tile(left_part, [right_batch_num, 1])
        right_part = tf.matmul(w2, tf.transpose(batch_h_right,
                                                 [1, 0]))
        right_part_list = tf.unstack(right_part, axis=1)
        right_part = tf.expand_dims(tf.concat(right_part_list, axis=-1), axis=-1)
        right_part = tf.tile(right_part, [1, left_batch_num])

        b_part = tf.tile(tf.expand_dims(b, axis=-1), [right_batch_num, left_batch_num])

        assert left_part.get_shape() == right_part.get_shape() == b_part.get_shape()
        linear = left_part + right_part + b_part

        va = tf.get_variable(
            "va_{}".format(name_idx),
            shape=[self.v_dim, 1],
            initializer=tf.constant_initializer(1.0))
        va_part = tf.tile(va, [right_batch_num, left_batch_num])
        score_before_reduce_sum = tf.multiply(va_part, tf.nn.tanh(linear))
        score_list = []
        for i in range(0, int(score_before_reduce_sum.get_shape()[0]), self.v_dim):
            score_list.append(tf.reduce_sum(tf.slice(score_before_reduce_sum, [i, 0], [self.v_dim, -1]), axis=0, keep_dims=True))
        return tf.concat(score_list, axis=0, name="score_matrix")

    def decoder_layer(self, input, input_len):
        # the final hidden state of the document-level BiLSTM (the last bilstm_layer
        # call in encoder_layer) is used as the initial state of the decoder
        max_seq_len = int(input.get_shape()[-2])

        with tf.name_scope("output_bilstm_layer"):
            batch_outputs ,batch_output_states = self.bilstm_layer(input, input_len,
                        initial_state_fw=self.output_states[0], initial_state_bw=self.output_states[1])

            batch_outputs_dim = int(batch_outputs.get_shape()[-1])
            batch_outputs = self.combine_state(tf.reshape(batch_outputs, [-1, batch_outputs_dim]), name_index=3)

            # [batch, max_seq_len, batch_outputs_dim]
            self.batch_outputs = tf.reshape(batch_outputs, shape=[-1, max_seq_len ,batch_outputs_dim])

    def tile_beta_matrix(self, beta_matrix, max_seq_len):
        time_dim = int(beta_matrix.get_shape()[0])
        beta_list = tf.unstack(beta_matrix, axis=-1)
        beta_single = tf.expand_dims(tf.concat(beta_list, axis=-1), axis=-1)
        beta_tiled = tf.tile(beta_single, [1, max_seq_len])
        beta_tiled_list = []
        for i in range(0, int(beta_tiled.get_shape()[0]), time_dim):
            beta_tiled_list.append(tf.slice(beta_tiled, [i, 0], [time_dim, -1]))
        return tf.concat(beta_tiled_list, axis=-1)

    def model_construct(self):
        with tf.variable_scope("word_embedding") as scope:
            embedding_input = self.word_embed_layer(self.input)
            scope.reuse_variables()
            embedding_output = self.word_embed_layer(self.output)

        # construct encoder decoder with property setting
        self.encoder_layer(embedding_input, self.input_seq_len)
        self.decoder_layer(embedding_output, self.output_seq_len)

        # [batch, sec, hidden] self.sec_batch_output
        # [batch, max_seq_len, hidden] self.batch_outputs
        # [batch, sec, max_seq, hidden] self.sec_rnn_outputs
        pyt_list = []
        with tf.variable_scope("score_scope") as scope:
            for i in range(0, self.batch_num):
                sec_hidden = tf.squeeze(tf.slice(self.sec_batch_output, [i, 0, 0], [1, -1, -1]))
                output_hidden = tf.squeeze(tf.slice(self.batch_outputs, [i, 0, 0], [1, -1, -1]))

                # [max_seq_len, sec] [t_rows, sec_cols]
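                # section-level attention weights (beta in the paper): one row per
                # decoder step, one column per discourse section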
                beta_score_matrix = tf.nn.softmax(self.score(sec_hidden, output_hidden, name_idx=1), dim=-1)

                sec_seq_hidden = tf.squeeze(tf.slice(self.sec_rnn_outputs, [i, 0, 0, 0], [1, -1, -1, -1]))
                hidden_dim = int(sec_seq_hidden.get_shape()[-1])
                dict_hidden = tf.reshape(sec_seq_hidden, [-1, hidden_dim])
                dict_score_matrix = self.score(dict_hidden, output_hidden, name_idx=2)

                max_seq_len = int(self.sec_rnn_outputs.get_shape()[2])
                beta_matrix_tiled = self.tile_beta_matrix(beta_score_matrix, max_seq_len)

                # assert shape equal
                assert dict_score_matrix.get_shape() == beta_matrix_tiled.get_shape()
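                # word-level attention rescaled by the section-level weights (alpha in the paper)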
                alpha_ori_score_matrix = tf.nn.softmax(tf.multiply(dict_score_matrix, beta_matrix_tiled), dim=-1)

                # [t, hidden]
                c_matrix = tf.matmul(alpha_ori_score_matrix, dict_hidden)
                h_matrix = output_hidden
                c_hidden_dim = int(c_matrix.get_shape()[-1])
                h_hidden_dim = int(h_matrix.get_shape()[-1])

                lin_w1 = tf.get_variable(
                    "lin_w1",
                    shape=[self.voc_dim, h_hidden_dim],
                    initializer=tf.contrib.layers.xavier_initializer())
                lin_w2 = tf.get_variable(
                    "lin_w2",
                    shape=[self.voc_dim, c_hidden_dim],
                    initializer=tf.contrib.layers.xavier_initializer())
                lin_b = tf.get_variable("lin_b",
                                    shape=[int(c_matrix.get_shape()[0])], initializer=tf.constant_initializer(1.0))

                # [voc_dim, t]
                linear = tf.nn.bias_add(tf.matmul(lin_w1, tf.transpose(h_matrix, [1, 0])) + tf.matmul(lin_w2, tf.transpose(c_matrix, [1, 0]))
                               , lin_b)
                V = tf.get_variable(
                    "V",
                    shape=[self.voc_dim, self.voc_dim],
                    initializer=tf.contrib.layers.xavier_initializer())

                # [t, voc_dim]
                pyt = tf.transpose(tf.matmul(V, linear), [1, 0])
                pyt_list.append(tf.expand_dims(pyt, dim=0))
                scope.reuse_variables()

        # [batch, t, voc_dim]
        batch_pyt = tf.concat(pyt_list, axis=0)

        return batch_pyt

    def opt_construct(self):
        self.batch_pyt = self.model_construct()
        logits = self.batch_pyt
        targets = tf.cast(self.output, tf.int32)

        # build a [batch, output_max_len] 0/1 mask from output_seq_len so that
        # padded positions do not contribute to the sequence loss
        output_seq_len_mask = self.output_seq_len + 1
        output_max_len = int(self.output.get_shape()[-1])

        weights = tf.cast(
            tf.slice(1 - tf.cumsum(tf.one_hot(output_seq_len_mask, depth=output_max_len + 1), axis=-1),
                     [0, 0], [-1, output_max_len]),
            tf.float32)

        self.seq_loss = tf.contrib.seq2seq.sequence_loss(logits = logits, targets = targets, weights = weights)
        self.pred = tf.cast(tf.argmax(tf.nn.softmax(self.batch_pyt, dim=-1), axis=-1), dtype=tf.int32)
        self.accuracy = tf.reduce_mean(tf.cast(tf.equal(self.pred, self.output), tf.float32))

        self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.seq_loss)

    @staticmethod
    def train():
        train_gen = data_generator("data/train.txt.json.idx")
        test_gen = data_generator("data/test.txt.json.idx")

        aas_ext = AAS()
        num_epochs = 100
        now_epoch = 0
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print("model construct end")

            for i in range(int(1e20)):
                try:
                    input, input_seq_len, output, output_seq_len = next(train_gen)
                except StopIteration:
                    print("epoch {} end".format(now_epoch))
                    now_epoch += 1
                    if now_epoch == num_epochs:
                        return

                    train_gen = data_generator("data/train.txt.json.idx")
                    input, input_seq_len, output, output_seq_len = next(train_gen)

                _, loss = sess.run(
                    [aas_ext.train_op, aas_ext.seq_loss],
                    feed_dict={
                        aas_ext.input: input,
                        aas_ext.input_seq_len: input_seq_len,
                        aas_ext.output: output,
                        aas_ext.output_seq_len: output_seq_len
                    }
                )

                if i % 100 == 0:
                    print("train loss: {}".format(loss))
                if i % 1000 == 0:
                    try:
                        input, input_seq_len, output, output_seq_len = next(test_gen)
                    except StopIteration:
                        test_gen = data_generator("data/test.txt.json.idx")
                        input, input_seq_len, output, output_seq_len = next(test_gen)

                    loss = sess.run(
                        aas_ext.seq_loss,
                        feed_dict={
                            aas_ext.input: input,
                            aas_ext.input_seq_len: input_seq_len,
                            aas_ext.output: output,
                            aas_ext.output_seq_len: output_seq_len
                        }
                    )

                    print("test loss: {}".format(loss))


if __name__ == "__main__":
    AAS.train()




Reposted from blog.csdn.net/sinat_30665603/article/details/80174101