R-NET简单实现

import json
import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn

class R_NET:
    """Graph builder for a simplified R-NET reading-comprehension model (TensorFlow 1.x).

    The graph built by `build_model` has four stages: a bidirectional LSTM
    encoder shared by question and paragraph, a gated question-to-paragraph
    attention RNN, a gated paragraph self-attention layer followed by a
    bidirectional LSTM, and a two-step pointer that emits start and end
    distributions over paragraph positions.
    """

    def __init__(self, batch_size=10, p_length=300, q_length=30, emb_dim=300):
        """Store the fixed graph dimensions.

        Args:
            batch_size: number of examples per batch (static in the graph).
            p_length: number of paragraph positions.
            q_length: number of question positions.
            emb_dim: size of each input embedding vector.
        """
        self.batch_size = batch_size
        self.p_length = p_length
        self.q_length = q_length
        self.emb_dim = emb_dim
        self.state_size = 75    # hidden units of every LSTM cell (per direction)
        self.span_length = 20   # number of end offsets tried for each start position

    def mat_weight_mul(self, mat, weight):
        """Multiply every matrix of a batch by one shared weight matrix.

        Args:
            mat: tensor of static shape [batch_size, n, m].
            weight: tensor of static shape [m, p].

        Returns:
            Tensor of shape [batch_size, n, p].
        """
        # [batch_size, n, m] * [m, p] = [batch_size, n, p]
        mat_shape = mat.get_shape().as_list()
        weight_shape = weight.get_shape().as_list()
        assert mat_shape[-1] == weight_shape[0]
        mat_reshape = tf.reshape(mat, [-1, mat_shape[-1]])  # [batch_size * n, m]
        mul = tf.matmul(mat_reshape, weight)  # [batch_size * n, p]
        return tf.reshape(mul, [-1, mat_shape[1], weight_shape[-1]])

    def _attention_pool(self, hidden, v, memory):
        # Score each position of `hidden` with vector `v`, softmax over positions,
        # and return (weights [batch, n], weighted sum of `memory` over positions).
        # The explicit squeeze axis keeps the batch axis even when batch_size == 1.
        scores = tf.squeeze(self.mat_weight_mul(hidden, tf.reshape(v, [-1, 1])), axis=[2])
        weights = tf.nn.softmax(scores, 1)
        pooled = tf.reduce_sum(tf.multiply(tf.expand_dims(weights, 2), memory), 1)
        return weights, pooled

    def _gated_concat(self, inputs, context, weight):
        # Concatenate `inputs` and `context` ([batch, d] each side) and scale the
        # result element-wise by a sigmoid gate computed from that concatenation.
        joined = tf.concat([inputs, context], 1)
        gate = tf.sigmoid(tf.matmul(joined, weight))
        return tf.multiply(joined, gate)

    def build_model(self):
        """Build the full training/inference graph in the current default graph.

        Returns:
            A tuple `(input_tensors, loss, accuracy, pred_si, pred_ei)`:
            input_tensors: dict of placeholders keyed 'p', 'q', 'a_si', 'a_ei'
                (paragraph embeddings, question embeddings, one-hot start,
                one-hot end).
            loss: scalar, negative summed log of p_start[si] * p_end[ei].
            accuracy: scalar, fraction of examples whose predicted start and
                end both equal the labelled positions.
            pred_si, pred_ei: float64 tensors of shape [batch_size] holding the
                predicted start / end positions as whole numbers.
        """
        # placeholders
        paragraph = tf.placeholder(tf.float32, [self.batch_size, self.p_length, self.emb_dim])
        question = tf.placeholder(tf.float32, [self.batch_size, self.q_length, self.emb_dim])
        answer_si = tf.placeholder(tf.float32, [self.batch_size, self.p_length])
        answer_ei = tf.placeholder(tf.float32, [self.batch_size, self.p_length])

        # encoding: the same bidirectional LSTM reads the question and the paragraph
        unstack_question = tf.unstack(question, self.q_length, 1)
        unstack_paragraph = tf.unstack(paragraph, self.p_length, 1)
        with tf.variable_scope('encoding'):
            fw_cell = rnn.BasicLSTMCell(self.state_size)
            bw_cell = rnn.BasicLSTMCell(self.state_size)

            q_enc, _, _ = rnn.static_bidirectional_rnn(fw_cell, bw_cell, unstack_question, dtype=tf.float32)
            tf.get_variable_scope().reuse_variables()
            p_enc, _, _ = rnn.static_bidirectional_rnn(fw_cell, bw_cell, unstack_paragraph, dtype=tf.float32)

            u_Q = tf.stack(q_enc, 1)  # [batch_size, q_length, 2 * state_size]

        # question-paragraph match
        v_P_steps = []
        with tf.variable_scope('QP_match'):
            # NOTE: variables are created in a fixed order because their
            # auto-generated names are what checkpoints are keyed on.
            W_uQ = tf.Variable(tf.truncated_normal([2*self.state_size, self.state_size]))
            W_uP = tf.Variable(tf.truncated_normal([2*self.state_size, self.state_size]))
            W_vP = tf.Variable(tf.truncated_normal([self.state_size, self.state_size]))
            W_g_QP = tf.Variable(tf.truncated_normal([4*self.state_size, 4*self.state_size]))
            B_v_QP = tf.Variable(tf.truncated_normal([self.state_size]))

            qp_match_cell = rnn.BasicLSTMCell(self.state_size)
            qp_match_state = qp_match_cell.zero_state(self.batch_size, dtype=tf.float32)

            # The question projection does not depend on t, so build it once.
            W_uQ_u_Q = self.mat_weight_mul(u_Q, W_uQ)  # [batch_size, q_length, state_size]

            for t in range(self.p_length):
                u_tP = p_enc[t]  # [batch_size, 2 * state_size]

                # c_t: attention over question positions. The per-step terms are
                # [batch_size, 1, state_size] and broadcast over q_length.
                pre_act = W_uQ_u_Q + tf.expand_dims(tf.matmul(u_tP, W_uP), 1)
                if t > 0:
                    pre_act = pre_act + tf.expand_dims(tf.matmul(v_P_steps[t - 1], W_vP), 1)
                _, c_t = self._attention_pool(tf.tanh(pre_act), B_v_QP, u_Q)  # [batch_size, 2 * state_size]

                # gate
                u_tP_c_t_star = self._gated_concat(u_tP, c_t, W_g_QP)  # [batch_size, 4 * state_size]

                qp_match_output, qp_match_state = qp_match_cell(u_tP_c_t_star, qp_match_state)
                v_P_steps.append(qp_match_output)

        v_P = tf.stack(v_P_steps, 1)  # [batch_size, p_length, state_size]

        # self-match
        with tf.variable_scope('Self_match'):
            W_smP1 = tf.Variable(tf.truncated_normal([self.state_size, self.state_size]))
            W_smP2 = tf.Variable(tf.truncated_normal([self.state_size, self.state_size]))
            W_g_SM = tf.Variable(tf.truncated_normal([2*self.state_size, 2*self.state_size]))
            B_v_SM = tf.Variable(tf.truncated_normal([self.state_size]))

            # Projection of the whole passage is the same for every t.
            W_p1_v_P = self.mat_weight_mul(v_P, W_smP1)  # [batch_size, p_length, state_size]

            SM_star = []
            for t in range(self.p_length):
                v_tP = v_P_steps[t]  # [batch_size, state_size]

                # s_t / c_t: attention of position t over all paragraph positions
                hidden = tf.tanh(W_p1_v_P + tf.expand_dims(tf.matmul(v_tP, W_smP2), 1))
                _, c_t = self._attention_pool(hidden, B_v_SM, v_P)  # [batch_size, state_size]

                # gate
                SM_star.append(self._gated_concat(v_tP, c_t, W_g_SM))  # [batch_size, 2 * state_size]

            SM_fw_cell = rnn.BasicLSTMCell(self.state_size)
            SM_bw_cell = rnn.BasicLSTMCell(self.state_size)

            SM_outputs, _, _ = rnn.static_bidirectional_rnn(SM_fw_cell, SM_bw_cell, SM_star, dtype=tf.float32)
            h_P = tf.stack(SM_outputs, 1)  # [batch_size, p_length, 2 * state_size]

        # output
        pointer_probs = []
        with tf.variable_scope('Ans_ptr'):
            W_ruQ = tf.Variable(tf.truncated_normal([2*self.state_size, 2*self.state_size]))
            W_vQ = tf.Variable(tf.truncated_normal([self.state_size, 2*self.state_size]))
            W_VrQ = tf.Variable(tf.truncated_normal([self.q_length, self.state_size]))
            B_v_rQ = tf.Variable(tf.truncated_normal([2*self.state_size]))

            # r_Q: attention-pooled question vector used as the pointer's initial state
            W_ruQ_u_Q = self.mat_weight_mul(u_Q, W_ruQ)  # [batch_size, q_length, 2 * state_size]
            W_vQ_V_rQ = tf.matmul(W_VrQ, W_vQ)  # [q_length, 2 * state_size], shared by the batch
            hidden = tf.tanh(W_ruQ_u_Q + tf.expand_dims(W_vQ_V_rQ, 0))
            _, r_Q = self._attention_pool(hidden, B_v_rQ, u_Q)  # [batch_size, 2 * state_size]

            # answer pointer
            W_hP = tf.Variable(tf.truncated_normal([2*self.state_size, self.state_size]))
            W_ha = tf.Variable(tf.truncated_normal([2*self.state_size, self.state_size]))
            B_v_ap = tf.Variable(tf.truncated_normal([self.state_size]))

            ans_ptr_cell = rnn.BasicLSTMCell(2*self.state_size)
            ans_ptr_cell_state = ans_ptr_cell.zero_state(self.batch_size, dtype=tf.float32)

            W_hP_h_P = self.mat_weight_mul(h_P, W_hP)  # same for both pointer steps

            h_t1a = r_Q
            for t in range(2):
                hidden = tf.tanh(W_hP_h_P + tf.expand_dims(tf.matmul(h_t1a, W_ha), 1))
                a_t, c_t = self._attention_pool(hidden, B_v_ap, h_P)  # [batch_size, p_length], [batch_size, 2 * state_size]
                pointer_probs.append(a_t)

                if t == 0:
                    # Advance the pointer state once; r_Q stands in for the hidden state.
                    h_t1a, ans_ptr_cell_state = ans_ptr_cell(c_t, (ans_ptr_cell_state.c, r_Q))

        # loss
        p1, p2 = pointer_probs

        answer_si_idx = tf.cast(tf.argmax(answer_si, 1), tf.int32)
        answer_ei_idx = tf.cast(tf.argmax(answer_ei, 1), tf.int32)

        batch_idx = tf.reshape(tf.range(0, self.batch_size), [-1, 1])
        batch_idx_si = tf.concat([batch_idx, tf.reshape(answer_si_idx, [-1, 1])], 1)
        batch_idx_ei = tf.concat([batch_idx, tf.reshape(answer_ei_idx, [-1, 1])], 1)

        # Joint probability of the labelled (start, end) pair for each example.
        log_prob = tf.multiply(tf.gather_nd(p1, batch_idx_si), tf.gather_nd(p2, batch_idx_ei))
        loss = -tf.reduce_sum(tf.log(log_prob+0.0000001))

        # accuracy: score every (start i, end i + j) pair with j < span_length and
        # i < p_length - span_length; column index of a pair is i * span_length + j.
        search_range = self.p_length - self.span_length
        prob = tf.stack(
            [tf.multiply(p1[:, i], p2[:, i + j])
             for i in range(search_range)
             for j in range(self.span_length)],
            axis=1)
        argmax_idx = tf.argmax(prob, axis=1)

        # Floor the quotient so the returned positions are whole numbers
        # (true division alone would leave the j / span_length fraction in both).
        pred_si = tf.floor(argmax_idx / self.span_length)
        pred_ei = pred_si + tf.cast(tf.mod(argmax_idx, self.span_length), tf.float64)

        correct = tf.logical_and(tf.equal(tf.cast(pred_si, tf.int64), tf.cast(answer_si_idx, tf.int64)),
                                 tf.equal(tf.cast(pred_ei, tf.int64), tf.cast(answer_ei_idx, tf.int64)))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

        input_tensors = {
            'p': paragraph,
            'q': question,
            'a_si': answer_si,
            'a_ei': answer_ei,
        }

        return input_tensors, loss, accuracy, pred_si, pred_ei
        
        
class DataProcessor:
    """Load the QA dataset and word vectors and turn them into model batches.

    Text is embedded one element at a time (`text[i]`), so a plain string is
    embedded character by character.
    """

    def __init__(self, batch_size=10, p_length=300, q_length=30, emb_dim=300,
                 data_path='paragraph.json', vec_path='wiki.zh.vec'):
        """Read the dataset and the vector file from disk.

        Args:
            batch_size: rows per training batch.
            p_length: number of paragraph positions kept (longer text is cut).
            q_length: number of question positions kept.
            emb_dim: embedding size; must match the vectors in `vec_path`.
            data_path: JSON file with the paragraphs and their questions.
            vec_path: text file of word vectors with a "<count> <dim>" header line.
        """
        self.batch_size = batch_size
        self.p_length = p_length
        self.q_length = q_length
        self.emb_dim = emb_dim

        self.data = self.load_data(data_path)
        self.vec = self.load_vec(vec_path)

    def load_data(self, path):
        """Return the parsed content of the UTF-8 JSON file at `path`."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_vec(self, path):
        """Parse a text vector file into a dict mapping word -> float32 array.

        The first line must hold two integers (entry count and dimension).
        Every other line is "<word> <v1> ... <v_dim>". The vector is taken from
        the last `dim` fields so that a word containing spaces still parses;
        lines with fewer than `dim + 1` fields are skipped.

        Raises:
            ValueError: if the header is not two integers or a value is not a float.
        """
        vec = {}
        with open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
            _, dim = map(int, f.readline().split())
            for line in f:
                tokens = line.rstrip().split(' ')
                if len(tokens) <= dim:
                    # Truncated or blank line: a vector of the wrong length
                    # would only fail later, when it is copied into a batch.
                    continue
                word = ' '.join(tokens[:-dim])
                # float32 arrays take far less memory than lists of Python floats.
                vec[word] = np.array([float(v) for v in tokens[-dim:]], dtype=np.float32)
        return vec

    def word2vec(self, word):
        """Return the stored vector for `word`, or a zero vector of size `emb_dim`."""
        vector = self.vec.get(word)
        if vector is None:
            return np.zeros((self.emb_dim))
        return vector

    def get_emb(self, text, length):
        """Embed the first `length` elements of `text`; remaining rows stay zero.

        Returns:
            Array of shape (length, emb_dim).
        """
        emb = np.zeros((length, self.emb_dim))
        for i, token in enumerate(text[:length]):
            emb[i] = self.word2vec(token)
        return emb

    def get_num_samples(self):
        """Return the total number of question entries across all paragraphs."""
        return sum(len(d['qas']) for d in self.data)

    def gen_test(self):
        """Yield one feed batch per paragraph/question pair typed on stdin.

        Each batch has two identical rows and all-zero answer arrays; this
        generator never finishes on its own.
        """
        while True:
            p = input('paragraph:')
            q = input('question:')

            # Both rows carry the same example.
            paragraph = np.stack([self.get_emb(p, self.p_length)] * 2)
            question = np.stack([self.get_emb(q, self.q_length)] * 2)

            yield {
                'paragraph_text': p,
                'question_text': q,
                'paragraph': paragraph,
                'question': question,
                'answer_si': np.zeros((2, self.p_length)),
                'answer_ei': np.zeros((2, self.p_length))
            }

    def gen_training_batch(self):
        """Yield full training batches forever, cycling over the dataset.

        Questions whose start or end index is not below `p_length` are skipped.
        A partially filled batch at the end of a pass is discarded. The
        '*_text', 'si' and 'ei' entries describe only the last row of the batch.

        Raises:
            ValueError: if a complete pass over the data cannot fill even one
                batch (the generator would otherwise loop without yielding).
        """
        while True:
            count = 0
            yielded = False
            for d in self.data:
                p = d['paragraph']
                p_emb = None  # built lazily, then shared by all questions of this paragraph
                for qa in d['qas']:
                    q = qa['question']
                    si = qa['start']
                    ei = qa['end']

                    if si >= self.p_length or ei >= self.p_length:
                        continue

                    if count == 0:
                        # Fresh arrays per batch so a yielded batch is never overwritten.
                        paragraph = np.zeros((self.batch_size, self.p_length, self.emb_dim))
                        question = np.zeros((self.batch_size, self.q_length, self.emb_dim))
                        answer_si = np.zeros((self.batch_size, self.p_length))
                        answer_ei = np.zeros((self.batch_size, self.p_length))

                    if p_emb is None:
                        p_emb = self.get_emb(p, self.p_length)

                    paragraph[count] = p_emb
                    question[count] = self.get_emb(q, self.q_length)
                    answer_si[count][si] = 1.0
                    answer_ei[count][ei] = 1.0

                    count += 1

                    if count == self.batch_size:
                        yield {
                            'paragraph_text': p,
                            'question_text': q,
                            'si': si,
                            'ei': ei,
                            'paragraph': paragraph,
                            'question': question,
                            'answer_si': answer_si,
                            'answer_ei': answer_ei
                        }
                        yielded = True
                        count = 0

            if not yielded:
                raise ValueError(
                    'not enough usable samples to fill a batch of size {}'.format(self.batch_size))
            
        
def run(start_epoch=0, num_epochs=200):
    """Train the model, printing loss and accuracy for every batch.

    The model is built in its own graph and the session is closed on exit, so
    another model can be built in the same process afterwards. Checkpoints are
    written to "model/rnet_epoch_<epoch>.ckpt" on every 50th epoch and after
    the final epoch.

    Args:
        start_epoch: first epoch index to run. When greater than 0, weights
            are restored from the checkpoint of epoch `start_epoch - 1`, which
            must exist.
        num_epochs: epoch index at which training stops (exclusive).
    """
    batch_size = 10
    p_length = 300
    q_length = 30
    emb_dim = 300

    print('load data...')
    dp = DataProcessor(batch_size, p_length, q_length, emb_dim)

    num_samples = dp.get_num_samples()
    num_batches = num_samples // batch_size

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # A private graph keeps the auto-generated variable names identical between
    # training and later evaluation, whatever was built earlier in the process.
    with tf.Graph().as_default():
        print('build model...')
        rnet_model = R_NET(batch_size, p_length, q_length, emb_dim)
        input_tensors, loss, acc, pred_si, pred_ei = rnet_model.build_model()
        train_op = tf.train.AdadeltaOptimizer(1.0, rho=0.95, epsilon=1e-06).minimize(loss)

        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            if start_epoch > 0:
                saver.restore(sess, "model/rnet_epoch_{}.ckpt".format(start_epoch-1))
            else:
                sess.run(tf.global_variables_initializer())

            # Saver.save does not create the checkpoint's parent directory.
            os.makedirs('model', exist_ok=True)

            print('training...')
            gen = dp.gen_training_batch()
            for ei in range(start_epoch, num_epochs):
                for bi in range(num_batches):
                    tensor_dict = next(gen)
                    feed_dict = {
                        input_tensors['p']: tensor_dict['paragraph'],
                        input_tensors['q']: tensor_dict['question'],
                        input_tensors['a_si']: tensor_dict['answer_si'],
                        input_tensors['a_ei']: tensor_dict['answer_ei'],
                    }
                    _, loss_value, accuracy = sess.run([train_op, loss, acc], feed_dict=feed_dict)

                    print("{} epoch {} batch, Loss:{:.2f}, Acc:{:.2f}".format(ei, bi, loss_value, accuracy))

                # Also save after the last epoch so the final weights are not lost.
                if ei % 50 == 0 or ei == num_epochs - 1:
                    saver.save(sess, "model/rnet_epoch_{}.ckpt".format(ei))

def test(epoch=150):
    """Answer questions typed on stdin using a saved checkpoint.

    Builds the model in its own graph (so it can run after training in the
    same process), restores "model/rnet_epoch_<epoch>.ckpt", then repeatedly
    reads a paragraph and a question and prints the predicted answer span.
    The loop ends when stdin reaches end of input.

    Args:
        epoch: epoch number of the checkpoint to restore.
    """
    p_length = 300
    q_length = 30
    emb_dim = 300

    print('load data...')
    dp = DataProcessor(2, p_length, q_length, emb_dim)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # A private graph avoids clashing with variables already present in the
    # process-wide default graph and keeps variable names checkpoint-compatible.
    with tf.Graph().as_default():
        print('build model...')
        rnet_model = R_NET(2, p_length, q_length, emb_dim)
        input_tensors, loss, acc, pred_si, pred_ei = rnet_model.build_model()

        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            saver.restore(sess, "model/rnet_epoch_{}.ckpt".format(epoch))

            gen = dp.gen_test()
            while True:
                try:
                    tensor_dict = next(gen)
                except EOFError:
                    # stdin closed: leave the interactive loop cleanly.
                    break

                feed_dict = {
                    input_tensors['p']: tensor_dict['paragraph'],
                    input_tensors['q']: tensor_dict['question'],
                    input_tensors['a_si']: tensor_dict['answer_si'],
                    input_tensors['a_ei']: tensor_dict['answer_ei'],
                }
                predictions_si, predictions_ei = sess.run([pred_si, pred_ei], feed_dict=feed_dict)

                p = tensor_dict['paragraph_text']
                q = tensor_dict['question_text']

                print('paragraph: '+p)
                print('question: '+q)
                # Both batch rows hold the same example; row 0 is used. The end
                # index is inclusive, hence the + 1 for slicing.
                print('answer: '+p[int(predictions_si[0]):int(predictions_ei[0]+1)])

            
if __name__ == '__main__':
    # Script entry: train first, then start the interactive question loop.
    for entry_point in (run, test):
        entry_point()

数据:

paragraph.json:形如[{"paragraph":"段落文字", "qas":[{"question":"问题", "start": 答案起始位置, "end": 答案结束位置}]}]

wiki.zh.vec:https://fasttext.cc/docs/en/pretrained-vectors.html 中文向量字典

参考:

1、https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf

2、https://github.com/unilight/R-NET-in-Tensorflow

猜你喜欢

转载自my.oschina.net/airxiechao/blog/1808478
今日推荐