BI-DIRECTIONAL ATTENTION FLOW FOR MACHINE COMPREHENSION: Paper Reading and Implementation

Paper link:

https://arxiv.org/abs/1611.01603

 

Paper objective:

Select the content in the context that is relevant to the query (the example in this post selects a context index), which can be viewed as a kind of QA structure. For something similar, see the earlier post — End-To-End Memory Networks paper reading.

 

Model features:

Overall: three levels of embedding are used: character-level, word-level, and contextual-level.

Bi-directional attention is used to obtain a query-aware representation of the context.
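
To make the bi-directional attention concrete, here is a minimal NumPy sketch (an illustration added here, not code from the paper or from the implementation below) of the two attention directions for a single example, given contextual embeddings H (context), U (query) and a similarity matrix S of shape [T, J]:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

T, J, two_d = 5, 3, 4                    # toy sizes: context length, query length, 2d
H = np.random.randn(T, two_d)            # contextual embeddings of the context words
U = np.random.randn(J, two_d)            # contextual embeddings of the query words
S = np.random.randn(T, J)                # similarity matrix S[t, j]

# Context-to-query: each context word attends over the query words.
A = softmax(S, axis=-1)                  # [T, J]
U_bar = A @ U                            # [T, 2d]

# Query-to-context: one attention distribution over the context words.
b = softmax(S.max(axis=-1))              # [T]
h_bar = b @ H                            # [2d]
H_bar = np.tile(h_bar, (T, 1))           # [T, 2d]

# Query-aware representation fed into the modeling layers.
G = np.concatenate([H, U_bar, H * U_bar, H * H_bar], axis=-1)   # [T, 8d]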



Data description:

Here we reuse the QA dataset built in the earlier post, Deep Learning for Extreme Multi-label Text Classification: the Question body serves as the Context, the Title as the Query, and the Tag as the answer to be located (a Tag is required to appear in the Question; samples that do not satisfy this are filtered out). Only the start index is estimated, because a Tag is usually a single word.
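
As a made-up example: if the Question is "how to parse json in python", the Title is "parsing json" and the Tag is "json", the sample is kept because "json" appears in the Question, and the label to predict is the start index of "json" in the Question token sequence (here index 3).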

 

Data preparation:

Take the data file X_y_file_12e5.txt produced by the previous post's processing.

Check whether each sample satisfies the modeling requirement and split the data into training and test sets:

def context_query_data_process():
    # Keep a sample only if at least one tag token appears in the question text;
    # return "question<TAB>tag" for the first matching tag, otherwise None.
    def valid_func(question_text, tag_text):
        q_list = question_text.split(" ")
        t_list = tag_text.split(" ")
        for t in t_list:
            if t in q_list:
                return "{}\t{}".format(question_text, t)
    with open("context_query_tag_test.txt", "w") as test_o:
        with open("context_query_tag_train.txt", "w") as train_o:
            with open("X_y_file_12e5.txt") as f:
                line_num = 0
                while True:
                    line = f.readline()
                    if not line:
                        break
                    question_text, s, title_text, tag_text = line[:-1].split("\t")
                    valid = valid_func(question_text, tag_text)
                    if valid:
                        ff, s = valid.split("\t")
                        if line_num % 10 >= 8:  # roughly 20% of samples go to the test set
                            test_o.write("{}\t{}\t{}\n".format(ff, title_text, s))
                        else:
                            train_o.write("{}\t{}\t{}\n".format(ff, title_text, s))
                        line_num += 1
                        if line_num % 10000 == 0:
                            print("line_num :{}".format(line_num))

Dataset encoding:

def index_cq():
    from functools import reduce
    from collections import Counter

    with open("context_query_tag_test.txt") as f:
        test_context = f.read().replace("\t", " ").replace("\n", " ")
        cnt = Counter(test_context.split(" "))
        words = list(map(lambda x: x[0],cnt.most_common(10000)))
        all_char_set = reduce(lambda x, y: x.union(y),map(set, words)).union(set(["<", ">"]))
        print("w c start")

    word2idx = dict([(w, i) for i, w in enumerate(words)] + [("<unk>", len(words)) ,("<pad>", len(words) + 1)])
    char2idx = dict((c, i) for i, c in enumerate(all_char_set))
    print("w c end")

    def map_word(input_list):
        # words -> word indices; out-of-vocabulary words map to <unk>
        return map(lambda x: word2idx.get(x, word2idx["<unk>"]), input_list)

    def map_char(input_list):
        # words -> "_"-joined character indices, one string per word
        char_nest_map = map(lambda char_list: map(lambda x: str(char2idx[x]), char_list), map(list, input_list))
        return map(lambda char_map: "_".join(char_map), char_nest_map)

    with open("cqtt.txt", "w") as o:
        with open("context_query_tag_test.txt") as f:
            line_num = 0
            while True:
                line = f.readline()
                if not line:
                    break
                q_t, t_t, t = map(lambda x: x.split(" ") ,line[:-1].split("\t"))
                qtw = map_word(q_t)
                qtc = map_char(q_t)
                ttw = map_word(t_t)
                ttc = map_char(t_t)
                tw = map_word(t)
                qtw, qtc, ttw, ttc, tw = map(lambda inner_list: " ".join(map(str ,inner_list)), [qtw, qtc, ttw, ttc, tw])
                o.write("{}\t{}\t{}\t{}\t{}\n".format(qtw, qtc, ttw, ttc, tw))
                line_num += 1
                if line_num % 1000 == 0:
                    print(line_num)
    print("cqtt end")
    with open("cqtn.txt", "w") as o:
        with open("context_query_tag_train.txt") as f:
            line_num = 0
            while True:
                line = f.readline()
                if not line:
                    break
                q_t, t_t, t = map(lambda x: x.split(" ") ,line[:-1].split("\t"))
                qtw = map_word(q_t)
                qtc = map_char(q_t)
                ttw = map_word(t_t)
                ttc = map_char(t_t)
                tw = map_word(t)
                qtw, qtc, ttw, ttc, tw = map(lambda inner_list: " ".join(map(str ,inner_list)), [qtw, qtc, ttw, ttc, tw])
                o.write("{}\t{}\t{}\t{}\t{}\n".format(qtw, qtc, ttw, ttc, tw))
                line_num += 1
                if line_num % 1000 == 0:
                    print(line_num)
    print("cqtn end")

    import pickle
    with open("idx.pkl", "wb") as f:
        pickle.dump({
            "word2idx": word2idx,
            "char2idx": char2idx
        }, f)
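
A quick sanity check of the encoding (a sketch, assuming idx.pkl and cqtt.txt have just been written to the working directory):

import pickle

with open("idx.pkl", "rb") as f:
    d = pickle.load(f)
print(len(d["word2idx"]), len(d["char2idx"]))  # vocabulary and character-set sizes

with open("cqtt.txt") as f:
    fields = f.readline().rstrip("\n").split("\t")
# five tab-separated fields: question word ids, question char ids ("_"-joined per word),
# title word ids, title char ids, tag word id
print([field[:40] for field in fields])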


Data export (batch generator used for training):

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

import pickle
with open("idx.pkl", "rb") as f:
    d = pickle.load(f)
    word2idx = d["word2idx"]
    char2idx = d["char2idx"]

padding_word_idx = word2idx["<pad>"]
padding_char_idx = len(char2idx)


def data_generator(gen_type="n", batch_num=64,
                   T=100, J=10, word_length=10):
    # Yields padded batches (context word ids, context char ids, query word ids,
    # query char ids, start one-hot, end one-hot).
    # gen_type "n" reads the training file cqtn.txt, "t" the test file cqtt.txt.
    assert gen_type in ["n", "t"]

    def split_word(word_join_str, word_padding):
        # space-joined word ids -> length-word_padding int32 vector, padded with <pad>
        input_list = list(map(int, word_join_str.split(" ")))[:word_padding]
        return np.array(input_list + [padding_word_idx] * (word_padding - len(input_list))).astype(np.int32)

    def split_char(char_join_str, char_padding, word_padding):
        # "_"-joined char ids per word -> [word_padding, char_padding] int32 matrix,
        # padded with padding_char_idx along both axes
        def char_nest_process(char_inner_str):
            input_list = char_inner_str.split("_")[:char_padding]
            return input_list + [padding_char_idx] * (char_padding - len(input_list))

        input_split = char_join_str.split(" ")[:word_padding]
        req = np.array(list(map(char_nest_process ,input_split))).astype(np.int32)
        trail = np.full(shape=[word_padding - len(input_split), char_padding], fill_value=padding_char_idx)
        req = np.append(req, trail, axis=0).astype(np.int32)

        return req

    start_idx = 0

    c_word_batch = np.zeros(shape=[batch_num, T]).astype(np.int32)
    c_char_batch = np.zeros(shape=[batch_num, T, word_length]).astype(np.int32)
    q_word_batch = np.zeros(shape=[batch_num, J]).astype(np.int32)
    q_char_batch = np.zeros(shape=[batch_num, J, word_length]).astype(np.int32)
    p1_fake_batch = np.zeros(shape=[batch_num, T]).astype(np.int32)
    p2_fake_batch = np.zeros(shape=[batch_num, T]).astype(np.int32)


    times = 0
    with open("cqt{}.txt".format(gen_type)) as f:
        while True:
            line = f.readline()
            if not line:
                return

            qtw, qtc, ttw, ttc, tw = line[:-1].split("\t")
            tw = split_word(tw, 1)
            qtw = split_word(qtw, T)
            if tw[0] not in qtw:
                continue
            else:
                p1idx = qtw.tolist().index(tw[0])
                if p1idx == (T - 1):
                    continue
            qtc = split_char(qtc, word_length, T)
            ttw = split_word(ttw, J)
            ttc = split_char(ttc, word_length, J)
            # One-hot start position at the answer index; the end position is faked
            # as start + 1, since the tags here are single words.
            p1, p2 = [0] * T, [0] * T
            p1[p1idx] = 1
            p2[p1idx + 1] = 1
            p1 = np.array(p1).astype(np.int32)
            p2 = np.array(p2).astype(np.int32)

            c_word_batch[start_idx] = qtw
            c_char_batch[start_idx] = qtc
            q_word_batch[start_idx] = ttw
            q_char_batch[start_idx] = ttc
            p1_fake_batch[start_idx] = p1
            p2_fake_batch[start_idx] = p2

            start_idx += 1
            if start_idx == batch_num:
                times += 1
                if times == 1e10:
                    return

                yield (c_word_batch,
                        c_char_batch,
                        q_word_batch,
                        q_char_batch,
                        p1_fake_batch,
                        p2_fake_batch)

                start_idx = 0
                c_word_batch = np.zeros(shape=[batch_num, T]).astype(np.int32)
                c_char_batch = np.zeros(shape=[batch_num, T, word_length]).astype(np.int32)
                q_word_batch = np.zeros(shape=[batch_num, J]).astype(np.int32)
                q_char_batch = np.zeros(shape=[batch_num, J, word_length]).astype(np.int32)
                p1_fake_batch = np.zeros(shape=[batch_num, T]).astype(np.int32)
                p2_fake_batch = np.zeros(shape=[batch_num, T]).astype(np.int32)
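
A minimal smoke test of the generator (a sketch; it assumes cqtn.txt has already been produced by index_cq):

gen = data_generator(gen_type="n", batch_num=64, T=100, J=10, word_length=10)
c_word, c_char, q_word, q_char, p1, p2 = next(gen)
# expected shapes: (64, 100) (64, 100, 10) (64, 10) (64, 10, 10) (64, 100) (64, 100)
print(c_word.shape, c_char.shape, q_word.shape, q_char.shape, p1.shape, p2.shape)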

Model construction:

 
 
class BIDAF(object):
    '''
    char_embed_size: embed size for single character
    char_size: like [a-zA-Z...] element num
    word_length: max single word length
    '''
    def __init__(self, char_embed_size = 10, char_size = padding_char_idx + 1, word_length = 10,
                 word_size = len(word2idx), word_embed_size = 50, T = 200, J = 20, batch_num = 64):

        self.char_embed_size = char_embed_size
        self.char_size = char_size
        self.word_length = word_length

        self.word_size = word_size
        self.word_embed_size = word_embed_size

        self.loss = None
        self.p1_accuracy = None
        self.p2_accuracy = None
        self.accuracy = None

        self.batch_num = batch_num

        with tf.device('/cpu:0'), tf.name_scope("char_embedding"):
            self.char_W = tf.Variable(
                tf.random_uniform([self.char_size, self.char_embed_size], -1.0, 1.0),
                name="char_W")

        with tf.device('/cpu:0'), tf.name_scope("word_embedding"):
            # this layer may be replaced by pre-trained w2v or GloVe embeddings in the future
            self.word_W = tf.Variable(
                tf.random_uniform([self.word_size, self.word_embed_size], -1.0, 1.0),
                name="word_W")
        self.word_length = word_length

        self.T = T
        self.J = J

        self.c_char = tf.placeholder(dtype=tf.int32, shape=[None, T, word_length], name="c_char")
        self.c_word = tf.placeholder(dtype=tf.int32, shape=[None, T], name = "c_word")

        self.q_char = tf.placeholder(dtype=tf.int32, shape=[None, J, word_length], name="q_char")
        self.q_word = tf.placeholder(dtype=tf.int32, shape=[None, J], name="q_word")

        self.p1_seq = tf.placeholder(dtype=tf.int32, shape=[None, T], name="p1_seq")
        self.p2_seq = tf.placeholder(dtype=tf.int32, shape=[None, T], name="p2_seq")

        self.model_construct()
        self.opt_construct()

    def model_construct(self):
        # embedding scope
        # Context Embed
        self.c_char_embed_flat = self.char_embed_layer(tf.reshape(self.c_char, [-1, self.word_length]))
        self.c_char_embed = tf.reshape(self.c_char_embed_flat, [-1, self.T, int(self.c_char_embed_flat.get_shape()[-1])], name="c_char_embed")
        self.c_word_embed = self.word_embed_layer(self.c_word)
        self.c_embed = tf.concat([self.c_char_embed, self.c_word_embed], axis=-1, name="c_embed")
        c_embed_last_dim = int(self.c_embed.get_shape()[-1])

        # Query Embed
        self.q_char_embed_flat = self.char_embed_layer(tf.reshape(self.q_char, [-1, self.word_length]))
        self.q_char_embed = tf.reshape(self.q_char_embed_flat, [-1, self.J, int(self.q_char_embed_flat.get_shape()[-1])], name="q_char_embed")
        self.q_word_embed = self.word_embed_layer(self.q_word)
        self.q_embed = tf.concat([self.q_char_embed, self.q_word_embed], axis=-1, name="q_embed")
        q_embed_last_dim = int(self.q_embed.get_shape()[-1])

        # high way scope
        with tf.variable_scope("high_way_layer") as scope:
            self.X = tf.reshape(self.high_way_layer(tf.reshape(self.c_embed, [-1, c_embed_last_dim])), [-1, self.T, c_embed_last_dim])
            scope.reuse_variables()
            self.Q = tf.reshape(self.high_way_layer(tf.reshape(self.q_embed, [-1, q_embed_last_dim])), [-1, self.J, q_embed_last_dim])

        # first bilstm layer
        with tf.variable_scope("first_lstm_layer") as scope:
            self.H = self.first_bilstm_layer(self.X)
            scope.reuse_variables()
            self.U = self.first_bilstm_layer(self.Q)

        # Similarity layer
        # [batch_num, T, J]
        self.S = self.similarity_layer(self.H, self.U)

        # Context-to-query attention: each context word attends over the query words;
        # U_bar[t] = sum_j A[t, j] * U[j]
        self.A = tf.nn.softmax(self.S, dim=-1, name="A")
        d = int(self.U.get_shape()[-1]) / 2
        self.U_bar_list = []
        for i in range(self.batch_num):
            A = tf.squeeze(tf.slice(self.A, [i, 0, 0], [1, -1, -1]))
            U = tf.squeeze(tf.slice(self.U, [i, 0, 0], [1, -1, -1]))

            self.U_bar_list.append(tf.expand_dims(tf.matmul(A, U), 0))

        self.U_bar = tf.concat(self.U_bar_list, axis=0)

        # Query-to-context attention: b = softmax_t(max_j S[t, j]), h_bar = sum_t b[t] * H[t],
        # then tiled across the T context positions
        self.b = tf.nn.softmax(tf.reduce_max(self.S, axis=-1), dim=-1, name="b")

        self.h_bar_list = []
        for i in range(self.batch_num):
            b = tf.slice(self.b, [i, 0], [1, -1])
            H = tf.squeeze(tf.slice(self.H, [i, 0, 0], [1, -1, -1]))
            self.h_bar_list.append(tf.matmul(b, H))


        # Each element of h_bar_list is [1, 2d]; stack along the batch axis, then
        # tile h_bar across the T context positions.
        self.h_bar = tf.concat(self.h_bar_list, axis=0)
        self.H_bar = tf.reshape(tf.tile(self.h_bar, [1, self.T]), [-1, self.T, int(2 * d)])

        # G layer
        self.G = self.G_layer(self.H, self.H_bar, self.U_bar)

        # second lstm layer
        self.M = self.second_bilstm_layer(self.G)

        # third lstm layer
        self.M2 = self.third_bilstm_layer(self.M)

        # p1 layer
        GM = tf.concat([self.G, self.M], axis = -1)
        Pw1 = tf.get_variable(
            "Pw1",
            shape=[10 * d, 1],
            initializer=tf.contrib.layers.xavier_initializer())
        self.p1 = tf.reshape(tf.matmul(tf.reshape(GM, [-1, int(10 * d)]), Pw1), [-1, self.T], name="p1")

        # p2 layer
        GM2 = tf.concat([self.G, self.M2], axis = -1)
        Pw2 = tf.get_variable(
            "Pw2",
            shape=[10 * d, 1],
            initializer=tf.contrib.layers.xavier_initializer())
        self.p2 = tf.reshape(tf.matmul(tf.reshape(GM2, [-1, int(10 * d)]), Pw2), [-1, self.T], name="p2")

    def opt_construct(self, use_single_p = True):
        self.softmax_p1 = tf.nn.softmax(self.p1)
        self.softmax_p2 = tf.nn.softmax(self.p2)
        self.p1_labels = tf.cast(self.p1_seq, tf.float32)
        self.p2_labels = tf.cast(self.p2_seq, tf.float32)
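        # Note: the start/end distributions are trained position-wise with a sigmoid
        # cross-entropy against one-hot targets (rather than a softmax over all T positions).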

        p1_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.p1, labels=self.p1_labels), name="p1_loss")
        p2_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.p2, labels=self.p2_labels), name="p2_loss")
        if use_single_p:
            self.loss = p1_loss
        else:
            self.loss = p1_loss + p2_loss

        self.opt = tf.train.AdamOptimizer(learning_rate=0.001)
        self.train_op = self.opt.minimize(self.loss)

        self.pred_p1 = tf.argmax(self.softmax_p1, 1, name="pred_p1")
        self.pred_p2 = tf.argmax(self.softmax_p2, 1, name="pred_p2")

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_pred_p1 = tf.equal(self.pred_p1, tf.argmax(self.p1_labels, 1))
            correct_pred_p2 = tf.equal(self.pred_p2, tf.argmax(self.p2_labels, 1))
            correct_pred = tf.multiply(tf.cast(correct_pred_p1, tf.float32), tf.cast(correct_pred_p2, tf.float32))

            if use_single_p:
                self.p1_accuracy = self.p2_accuracy = self.accuracy = tf.reduce_mean(tf.cast(correct_pred_p1, "float"), name="p1_accuracy")
            else:
                self.p1_accuracy = tf.reduce_mean(tf.cast(correct_pred_p1, "float"), name="p1_accuracy")
                self.p2_accuracy = tf.reduce_mean(tf.cast(correct_pred_p2, "float"), name="p2_accuracy")
                self.accuracy = tf.reduce_mean(correct_pred, name="accuracy")


    def char_embed_layer(self, input_char, num_filters = 3, filter_size = 2):
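        # CNN over character embeddings: a width-`filter_size` convolution with
        # `num_filters` feature maps, max-pooled over the word length, yields one
        # fixed-size character-level vector per word.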
        with tf.name_scope("char_embedding"):
            embedded_chars = tf.nn.embedding_lookup(self.char_W, input_char)
            embedded_chars_expanded = tf.expand_dims(embedded_chars, -1)

            filter_shape = [filter_size, self.char_embed_size, 1, num_filters]
            W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
            b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
            conv = tf.nn.conv2d(
                embedded_chars_expanded,
                W,
                strides=[1, 1, 1, 1],
                padding="VALID",
                name="conv")
            # Apply nonlinearity
            h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
            # Maxpooling over the outputs
            pooled = tf.nn.max_pool(
                h,
                ksize=[1, self.word_length - filter_size + 1, 1, 1],
                strides=[1, 1, 1, 1],
                padding='VALID',
                name="pool")

            pooled_shape = pooled.get_shape()
            total_size = None
            for idx in list(range(len(pooled_shape)))[1:]:
                if total_size is None:
                    total_size = int(pooled_shape[-1 * idx])
                else:
                    total_size *= int(pooled_shape[-1 * idx])

            print("total size of char_embed_layer: {}".format(total_size))

            return tf.reshape(pooled, [-1, int(total_size)])

    def word_embed_layer(self, input_word):
        with tf.name_scope("word_embedding"):
            embedded_word = tf.nn.embedding_lookup(self.word_W, input_word)
            return embedded_word

    def high_way_layer(self, input):
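        # Highway layer: y = H(x) * T(x) + x * (1 - T(x)), applied position-wise.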
        input_shape = input.get_shape()
        last_dim = int(input_shape[-1])

        HW = tf.get_variable(
            "HW",
            shape=[last_dim, last_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        Hb = tf.get_variable( shape=[last_dim], name="Hb", initializer=tf.contrib.layers.xavier_initializer())

        H = tf.nn.relu(tf.nn.xw_plus_b(input, HW, Hb))  # transform path with a ReLU nonlinearity

        TW = tf.get_variable(
            "TW",
            shape=[last_dim, last_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        Tb = tf.get_variable( shape=[last_dim], name="Tb", initializer=tf.contrib.layers.xavier_initializer())

        T = tf.sigmoid(tf.nn.xw_plus_b(input, TW, Tb))  # carry gate, kept in [0, 1]

        return H * T + input * (1 - T)

    def first_bilstm_layer(self, input):
        input_shape = input.get_shape()
        d = int(input_shape[-1])

        fw_cell = rnn.BasicLSTMCell(d, forget_bias=1., state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
        bw_cell = rnn.BasicLSTMCell(d, forget_bias=1., state_is_tuple=True, reuse=tf.get_variable_scope().reuse)

        rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            fw_cell, bw_cell, input, scope='first-bi-lstm',
            dtype=tf.float32)

        return tf.concat(rnn_outputs, axis=2, name='first_bilstm_output')

    def similarity_layer(self, H, U):
        # Trilinear similarity from the paper: S[b, t, j] = Sw^T [H_t ; U_j ; H_t ∘ U_j]
        d = int(H.get_shape()[-1]) / 2
        h_dim = int(H.get_shape()[-2])
        u_dim = int(U.get_shape()[-2])
        print("similarity layer h_dim: {}, u_dim: {}".format(h_dim, u_dim))

        Sw = tf.get_variable(
            "Sw",
            shape=[int(6 * d), 1],
            initializer=tf.contrib.layers.xavier_initializer())

        # Broadcast H along the query axis and U along the context axis.
        fH = tf.tile(tf.expand_dims(H, 2), [1, 1, u_dim, 1])  # [batch_num, T, J, 2d]
        fU = tf.tile(tf.expand_dims(U, 1), [1, h_dim, 1, 1])  # [batch_num, T, J, 2d]
        f = tf.concat([fH, fU, fH * fU], axis=-1)             # [batch_num, T, J, 6d]

        # [batch_num, T, J]
        return tf.reshape(tf.matmul(tf.reshape(f, [-1, int(6 * d)]), Sw), [-1, h_dim, u_dim])


    def G_layer(self, H, H_bar, U_bar):
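        # G[:, t, :] = [H_t ; U_bar_t ; H_t ∘ U_bar_t ; H_t ∘ H_bar_t], an 8d-dim vector per context position.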
        d = int(H.get_shape()[-1]) / 2

        def g(h, h_bar, u_bar):
            return tf.concat([h, u_bar, h * u_bar, h * h_bar], axis=-1)

        h_list = tf.unstack(H, axis=1, name="h_list")
        h_bar_list = tf.unstack(H_bar, axis=1, name="h_bar_list")
        u_bar_list = tf.unstack(U_bar, axis=1, name="u_bar_list")

        g_list = []
        for idx in range(len(h_list)):
            h = h_list[idx]
            h_bar = h_bar_list[idx]
            u_bar = u_bar_list[idx]
            g_ele = g(h, h_bar, u_bar)
            g_list.append(g_ele)

        # [batch_num, T*8d] -> [batch_num, T, 8d]; the concat order is (position, channel),
        # so a direct reshape restores the per-position vectors.
        return tf.reshape(tf.concat(g_list, axis=-1), [-1, self.T, int(8 * d)], name="G")

    def second_bilstm_layer(self, input):
        input_shape = input.get_shape()
        d = int(int(input_shape[-1]) / 8)

        with tf.name_scope("second_lstm_layer"):
            fw_cell = rnn.BasicLSTMCell(d, forget_bias=1., state_is_tuple=True)
            bw_cell = rnn.BasicLSTMCell(d, forget_bias=1., state_is_tuple=True)

            rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, input, scope='second-bi-lstm',
                dtype=tf.float32)

            return tf.concat(rnn_outputs, axis=2, name='second_bilstm_output')

    def third_bilstm_layer(self, input):
        input_shape = input.get_shape()
        d = int(int(input_shape[-1]) / 2)

        with tf.name_scope("third_lstm_layer"):
            fw_cell = rnn.BasicLSTMCell(d, forget_bias=1., state_is_tuple=True)
            bw_cell = rnn.BasicLSTMCell(d, forget_bias=1., state_is_tuple=True)

            rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, input, scope='third-bi-lstm',
                dtype=tf.float32)

            return tf.concat(rnn_outputs, axis=2, name='third_bilstm_output')


    @staticmethod
    def train():
        from time import time
        bidaf_ext = BIDAF(T=100, J=10)  # T and J must match the values used by data_generator
        print("model construct end :")

        tg = data_generator(gen_type="n")
        ttg = data_generator(gen_type="t")
        num_epoch = 100

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for now_epoch in range(num_epoch):
                step = 0
                while True:
                    try:
                        c_word,c_char,q_word,q_char,p1_fake,p2_fake = tg.__next__()
                    except StopIteration:
                        tg = data_generator(gen_type="n")
                        ttg = data_generator(gen_type="t")
                        print("epoch {} end".format(now_epoch))
                        break

                    _ , \
                    loss, \
                    p1_accuracy, \
                    p2_accuracy, \
                    accuracy,\
                    H, U_bar, H_bar \
                        = sess.run([
                        bidaf_ext.train_op,
                        bidaf_ext.loss,
                        bidaf_ext.p1_accuracy,
                        bidaf_ext.p2_accuracy,
                        bidaf_ext.accuracy,

                        bidaf_ext.H,
                        bidaf_ext.U_bar,
                        bidaf_ext.H_bar,
                    ],
                        feed_dict={
                            bidaf_ext.c_word: c_word,
                            bidaf_ext.c_char: c_char,
                            bidaf_ext.q_word: q_word,
                            bidaf_ext.q_char: q_char,
                            bidaf_ext.p1_seq: p1_fake,
                            bidaf_ext.p2_seq: p2_fake
                        })

                    if step % 10 == 0:
                        print("train loss :{} p1_accuracy :{} p2_accuracy :{} accuracy :{}".
                              format(loss, p1_accuracy, p2_accuracy, accuracy))

                    if step % 100 == 0:
                        try:
                            c_word,c_char,q_word,q_char,p1_fake,p2_fake = ttg.__next__()
                        except StopIteration:
                            ttg = data_generator(gen_type="t")
                            c_word,c_char,q_word,q_char,p1_fake,p2_fake = ttg.__next__()


                        loss, \
                        p1_accuracy, \
                        p2_accuracy, \
                        accuracy  = sess.run([
                            bidaf_ext.loss,
                            bidaf_ext.p1_accuracy,
                            bidaf_ext.p2_accuracy,
                            bidaf_ext.accuracy,
                        ],
                            feed_dict={
                                bidaf_ext.c_word: c_word,
                                bidaf_ext.c_char: c_char,
                                bidaf_ext.q_word: q_word,
                                bidaf_ext.q_char: q_char,
                                bidaf_ext.p1_seq: p1_fake,
                                bidaf_ext.p2_seq: p2_fake
                            })
                        print("test loss :{} p1_accuracy :{} p2_accuracy :{} accuracy :{}".
                              format(loss, p1_accuracy, p2_accuracy, accuracy))

                    step += 1


Because longer sequences hurt performance considerably, T and J are restricted to fairly small values here.

use_single_p controls whether the end index is taken into account as well.
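
Training can be started through the static method (a usage sketch):

BIDAF.train()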

 

Test-set accuracy across epochs:

test loss :0.6456390023231506 p1_accuracy :0.017999999225139618
test loss :0.050495319068431854 p1_accuracy :0.057999998331069946
epoch 0 end
test loss :0.0502440445125103 p1_accuracy :0.03999999910593033
test loss :0.05010032281279564 p1_accuracy :0.06199999898672104
epoch 1 end
test loss :0.05005083978176117 p1_accuracy :0.035999998450279236
test loss :0.04990968108177185 p1_accuracy :0.06599999964237213
epoch 2 end
test loss :0.04992210119962692 p1_accuracy :0.04399999976158142
test loss :0.04977176710963249 p1_accuracy :0.07199999690055847
epoch 3 end
test loss :0.04978775978088379 p1_accuracy :0.03400000184774399
test loss :0.049564428627491 p1_accuracy :0.07199999690055847
epoch 4 end
test loss :0.049547359347343445 p1_accuracy :0.057999998331069946
test loss :0.04935314133763313 p1_accuracy :0.09799999743700027
epoch 5 end
test loss :0.04918454959988594 p1_accuracy :0.06800000369548798
test loss :0.04861310124397278 p1_accuracy :0.14000000059604645
epoch 6 end
test loss :0.04811424762010574 p1_accuracy :0.10999999940395355
test loss :0.04380807653069496 p1_accuracy :0.22200000286102295
epoch 7 end
test loss :0.040549907833337784 p1_accuracy :0.24799999594688416
test loss :0.034186024218797684 p1_accuracy :0.414000004529953
epoch 8 end
test loss :0.032501496374607086 p1_accuracy :0.43799999356269836
test loss :0.02821703627705574 p1_accuracy :0.5040000081062317
epoch 9 end
test loss :0.02707063965499401 p1_accuracy :0.5440000295639038
test loss :0.024162035435438156 p1_accuracy :0.5879999995231628
epoch 10 end
test loss :0.02386888675391674 p1_accuracy :0.5979999899864197
test loss :0.021960947662591934 p1_accuracy :0.6320000290870667
epoch 11 end
test loss :0.022267578169703484 p1_accuracy :0.6240000128746033
test loss :0.0203940998762846 p1_accuracy :0.6620000004768372
epoch 12 end
test loss :0.021203691139817238 p1_accuracy :0.6380000114440918
test loss :0.019260596483945847 p1_accuracy :0.6779999732971191
epoch 13 end
test loss :0.020675132051110268 p1_accuracy :0.6499999761581421
test loss :0.018571924418210983 p1_accuracy :0.6859999895095825
epoch 14 end
test loss :0.020104004070162773 p1_accuracy :0.6639999747276306
test loss :0.01809442788362503 p1_accuracy :0.6940000057220459
epoch 15 end
test loss :0.019806142896413803 p1_accuracy :0.6620000004768372
test loss :0.018006963655352592 p1_accuracy :0.6980000138282776
epoch 16 end
test loss :0.019820837303996086 p1_accuracy :0.6639999747276306
test loss :0.018198983743786812 p1_accuracy :0.6959999799728394
epoch 17 end
test loss :0.01980511285364628 p1_accuracy :0.6679999828338623
test loss :0.018297001719474792 p1_accuracy :0.6919999718666077
epoch 18 end
test loss :0.0200297012925148 p1_accuracy :0.6639999747276306
test loss :0.0184195376932621 p1_accuracy :0.699999988079071
epoch 19 end
test loss :0.020458657294511795 p1_accuracy :0.6620000004768372
test loss :0.018728474155068398 p1_accuracy :0.699999988079071
epoch 20 end
test loss :0.020740188658237457 p1_accuracy :0.6620000004768372
test loss :0.01889195665717125 p1_accuracy :0.7039999961853027
epoch 21 end
test loss :0.021275488659739494 p1_accuracy :0.6660000085830688
test loss :0.019528238102793694 p1_accuracy :0.6959999799728394
epoch 22 end
test loss :0.021669382229447365 p1_accuracy :0.671999990940094
test loss :0.019826726987957954 p1_accuracy :0.6819999814033508
epoch 23 end
test loss :0.022092627361416817 p1_accuracy :0.656000018119812
test loss :0.020767534151673317 p1_accuracy :0.6819999814033508
epoch 24 end
test loss :0.022723300382494926 p1_accuracy :0.6579999923706055
test loss :0.021448861807584763 p1_accuracy :0.6800000071525574
epoch 25 end
test loss :0.023822899907827377 p1_accuracy :0.628000020980835
test loss :0.0220272745937109 p1_accuracy :0.671999990940094
epoch 26 end
test loss :0.02495754137635231 p1_accuracy :0.6299999952316284




Reposted from blog.csdn.net/sinat_30665603/article/details/79888354