Towards Better Text Understanding and Retrieval through Kernel Entity Salience Modeling: Paper Notes and Implementation

Paper link:

https://arxiv.org/abs/1805.01334

Problem statement:

Given a document, we can extract many useful entities from it. How do we score how important each entity is to that document?

Application background:

If we can score an entity's importance to a document, then extracting the entities from a user query and scoring them against each candidate document is enough to rank candidate documents in a search-engine setting.

Model approach:

The paper's way of relating an entity to a document is to score whether the entity is important to the document (salient or not). The input dataset can be understood as candidate entity sets (each entity carrying a label for whether it is salient to the document) paired one-to-one with their documents.

Beyond the entity and its document, the proposed model also brings in the entity's description. The description embeddings are fed through a convolutional neural network to extract features, which are then fused with the entity embedding to form the entity's Knowledge Enriched Embedding (KEE), a dense representation of the entity. This design is especially helpful when a domain knowledge base is available for domain-specific entity search, ranking, and recommendation.
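Concretely (this is a sketch in my own notation, matching the implementation later in this post rather than the paper's exact equations): with $\text{emb}(e)$ the entity embedding and $D_e$ the matrix of description token embeddings,

$$v_e = W\,[\,\text{CNN}(D_e)\,;\ \text{emb}(e)\,] + b$$

where CNN is a convolution-plus-max-pooling encoder over the description and $[\cdot\,;\cdot]$ denotes concatenation; $v_e$ is the KEE used everywhere below.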

For each entity, the model computes its similarity to every entity in the candidate set: the cosine of the two embeddings is mapped into a bank of RBF kernels, and the kernel values are summed per kernel, much like a non-parametric kernel density estimate — here, an estimate of how semantically close the pair is (the same holds for the similarities below). It then computes the entity's similarity to every token embedding of the candidate document in the same way. These pooled similarities are fused into one vector, which is projected down to a single dimension to yield the entity-document relevance score.
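Before the formulas, here is a minimal NumPy sketch of this kernel pooling. The function and variable names are mine, not from the paper's code; the /sigma scaling matches the TF code later in this post (the more common form divides by 2σ²):

import numpy as np

def kernel_pooling(target, others, mus, sigma=0.1):
    """RBF kernel pooling of cosine similarities.

    target: [dim] embedding of the entity being scored.
    others: [n, dim] embeddings of the other entities (or of document tokens).
    Returns one pooled value per kernel mean mu.
    """
    target = target / np.linalg.norm(target)
    others = others / np.linalg.norm(others, axis=-1, keepdims=True)
    cos = others @ target                       # [n] cosine similarities
    # one Gaussian kernel per mu, summed over the n similarities
    return np.array([np.exp(-(cos - mu) ** 2 / sigma).sum() for mu in mus])

mus = np.arange(-0.9, 1.0, 0.1)                 # same 19 kernel means as the TF code below
pooled = kernel_pooling(np.random.randn(100), np.random.randn(30, 100), mus)
print(pooled.shape)                             # (19,) -> one feature per kernel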

Denote the salience score of entity $e$ in document $d$ as $f(e, d)$.

Denote the salience label as $y(e, d) \in \{+1, -1\}$, where $y = +1$ means $e$ is salient in $d$.

The loss can then be defined as a pairwise hinge loss over (salient, non-salient) entity pairs, which is exactly what the training code below implements:

$$L(d) = \sum_{e^{+} \in E^{+}(d)} \sum_{e^{-} \in E^{-}(d)} \max\big(0,\ 1 - f(e^{+}, d) + f(e^{-}, d)\big)$$

After optimization, the entity-document scoring function $f$ can be used to define the score of a query against a document:

$$f(q, d) = \sum_{e \in E(q)} f(e, d)$$

where $q$ is the query, $d$ is the document, and $E(q)$ is the set of entities extracted from the query.
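As a quick sanity check of the loss: if a salient entity scores $f(e^{+}, d) = 0.3$ while a non-salient one scores $f(e^{-}, d) = 0.5$, the pair contributes $\max(0,\ 1 - 0.3 + 0.5) = 1.2$; the contribution only reaches 0 once the salient entity outscores the non-salient one by a margin of at least 1.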

Below is an implementation of training the scoring function.

Data description:

We use the second dataset from the paper, the Semantic Scholar corpus. Each line of the corpus is a JSON object.

Data link: http://labs.semanticscholar.org/corpus/

As described in the paper, the entities field holds all candidate entities, whether an entity appears in the title field is the criterion for salience, and the paperAbstract field serves as the document against which entities are scored.
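For illustration, a heavily abridged, hypothetical record (showing only the three fields used here) looks roughly like:

{"title": "Convolutional Neural Networks for Text Classification",
 "paperAbstract": "We apply convolutional neural networks to ...",
 "entities": ["Convolutional neural network", "Text classification", "Support vector machine"]}

Under the rule above, the first two entities would be labeled salient (they appear in the title) and the third non-salient.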

Entity descriptions are fetched by calling the Wikidata API (see the crawler code below).

Implementation code:

Download and extract the dataset to D:\download\corpus. Below is the Scrapy code that queries the Wikidata API.

Item definition:

import scrapy

class WikidatacrawlerItem(scrapy.Item):
    # define the fields for your item here like:
    # name = scrapy.Field()
    entity = scrapy.Field()
    description = scrapy.Field()

Spider:

from scrapy.http import Request
from scrapy.spiders import Spider
from WikiDataCrawler.items import WikidatacrawlerItem

import os
import json

class file_Iter(object):
    def __init__(self, rootDir):
        # collect every file path under rootDir
        req = []
        for root, dirs, files in os.walk(rootDir):
            for f in files:
                req.append(os.path.join(root, f))
        self.all_req_pos_files = req

    def generate_entity(self, json_line):
        # note: json.loads no longer accepts an encoding argument on Python 3.9+
        json_obj = json.loads(json_line)
        if json_obj.get("entities"):
            for ele in json_obj["entities"]:
                yield ele

    def __iter__(self):
        for file in self.all_req_pos_files:
            print("begin read {}".format(file) + "-" * 100)
            with open(file, "r", encoding="utf-8") as f:
                while True:
                    line = f.readline().strip()
                    if line:
                        for ele in self.generate_entity(line):
                            yield ele
                    else:
                        break

class DataSpider(Spider):
    name = "dataSpider"
    description_url_format = "https://www.wikidata.org/w/api.php?action=wbsearchentities&search={}&language=en&limit=20&format=json"

    def data_loader(self):
        print("call data_loader :")
        self.file_Iter_ext = file_Iter(r"D:\download\corpus")

    def start_requests(self):
        self.data_loader()
        for entity_str in self.file_Iter_ext:
            yield Request(url = self.description_url_format.format(entity_str), meta={"entity": entity_str})

    def parse(self, response):
        if response.body:
            json_obj = json.loads(response.body)
            if json_obj.get("search"):
                first_item = json_obj["search"][0]
                if first_item.get("description"):
                    item = WikidatacrawlerItem()
                    item["entity"] = response.meta["entity"]
                    item["description"] = first_item["description"]
                    yield item
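For reference, parse() above assumes a wbsearchentities response of roughly this shape (abridged to the fields used here, with illustrative values; the real payload carries more):

{"search": [{"id": "Q11660", "label": "artificial intelligence",
             "description": "intelligence exhibited by machines"}, ...]}

Only the description of the first hit is kept.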

Pipeline (persisting items to disk):

import pandas as pd

class WikidatacrawlerPipeline(object):
    def open_spider(self, spider):
        self.df_chunk_size = 1e2
        self.temp_list = []
        self.df_dir = "ed.csv"
        self.times = 0

    def process_item(self, item, spider):
        self.temp_list.append([item["entity"], item["description"]])
        self.serialize()
        return item

    def serialize(self, final=False):
        # flush when the buffer is full, or unconditionally when the spider closes;
        # the original version silently dropped the final partial chunk
        if self.temp_list and (final or len(self.temp_list) >= self.df_chunk_size):
            df = pd.DataFrame(self.temp_list, columns=["entity", "description"])
            if self.times:
                with open(self.df_dir, "a", encoding="utf-8") as f:
                    df.to_csv(f, header=False)
            else:
                df.to_csv(self.df_dir, header=True, encoding="utf-8")
            self.temp_list = []
            self.times += 1
            print("write times :{}".format(self.times))

    def close_spider(self, spider):
        print("spider closing, flushing remaining items")
        self.serialize(final=True)
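One thing the snippets leave implicit: the pipeline only runs if it is registered in the project's settings.py. Assuming the Scrapy project module is named WikiDataCrawler (as the imports above suggest) and the pipeline lives in its pipelines.py, that is:

ITEM_PIPELINES = {
    "WikiDataCrawler.pipelines.WikidatacrawlerPipeline": 300,
}

The crawl is then started from the project root with scrapy crawl dataSpider.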

With the entity descriptions collected, we move on to data processing:

import pandas as pd
from collections import Counter
from dl_text import dl
import nltk
import re
import json
import os
import pickle

class file_Iter(object):
    def __init__(self, rootDir, entities_only = False):
        # collect every file path under rootDir
        req = []
        for root, dirs, files in os.walk(rootDir):
            for f in files:
                req.append(os.path.join(root, f))
        self.all_req_pos_files = req
        # __iter__ cannot take extra arguments, so the mode flag lives on the instance
        self.entities_only = entities_only

    def generate_entity(self, json_line):
        json_obj = json.loads(json_line)
        if json_obj.get("entities"):
            for ele in json_obj["entities"]:
                yield ele

    def generate_corpus_req(self, json_line):
        json_obj = json.loads(json_line)
        if json_obj.get("entities") and json_obj.get("paperAbstract") and json_obj.get("title"):
            candidate_entities = json_obj["entities"]
            lower_title = json_obj["title"].lower()

            # entities appearing in the title are labeled salient
            true_target_entities = list(filter(lambda x: x.lower() in lower_title, candidate_entities))
            false_target_entities = list(set(candidate_entities).difference(set(true_target_entities)))
            document_pos = clean_pos(json_obj["paperAbstract"])
            yield [true_target_entities, false_target_entities, document_pos]

    def __iter__(self):
        for file in self.all_req_pos_files:
            print("begin read {}".format(file) + "-" * 100)
            with open(file, "r", encoding="utf-8") as f:
                while True:
                    line = f.readline().strip()
                    if not line:
                        break
                    if self.entities_only:
                        for ele in self.generate_entity(line):
                            yield ele
                    else:
                        for ele in self.generate_corpus_req(line):
                            yield ele

def verify_token(token):
    # keep only tokens that start with a lowercase letter
    return re.match("[a-z]+", token)

def clean_pos(text):
    return list(filter(verify_token, nltk.tokenize.word_tokenize(dl.clean(text).lower())))

def process_corpus():
    temp_list = []
    df_chunk_size = int(1e4)
    times = 0
    df_dir = "corpus_df.csv"

    def serialize(times, temp_list, final = False):
        # flush when the buffer is full, or unconditionally on the final call;
        # the original version could drop the last partial chunk
        if temp_list and (final or len(temp_list) >= df_chunk_size):
            df = pd.DataFrame(temp_list, columns=["true_e", "false_e", "doc_pos"])
            if times:
                with open(df_dir, "a", encoding="utf-8") as f:
                    df.to_csv(f, header=False)
            else:
                df.to_csv(df_dir, header=True, encoding="utf-8")
            temp_list = []
            times += 1
            print("write times :{}".format(times))
        return times, temp_list

    file_Iter_ext = file_Iter(r"D:\download\corpus")
    for l3 in file_Iter_ext:
        temp_list.append(l3)
        times, temp_list = serialize(times, temp_list)
    serialize(times, temp_list, final = True)

def entity_and_word_idx(entity_num_filter_size = 10):
    word_cnt = Counter()
    entity_cnt = Counter()

    entity_num_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\corpus_with_entity_num.csv", encoding="utf-8")
    req_entity_num_df = entity_num_df[entity_num_df["entity_intersect_count"] > entity_num_filter_size]
    del entity_num_df
    req_entity_num_df.to_csv("req_corpus_df.csv", encoding="utf-8")
    print("serialize req_corpus_df end")

    for ridx, r in req_entity_num_df.iterrows():
        entities = eval(r["true_e"]) + eval(r["false_e"])
        doc_pos = eval(r["doc_pos"])
        entity_cnt.update(entities)
        word_cnt.update(doc_pos)
        if ridx % 10000 == 0:
            print("ridx :{}".format(ridx))

    entity_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\entity_df.csv", encoding="utf-8", index_col=0)
    for ridx, r in entity_df.iterrows():
        #entity = r["entity"]
        desp_token = eval(r["desp_token"])
        word_cnt.update(desp_token)
        if ridx % 10000 == 0:
            print("ridx :{}".format(ridx))

    with open("cnt.pkl", "wb") as f:
        pickle.dump({
            "word_cnt": word_cnt,
            "entity_cnt": entity_cnt
        }, f)
    print("cnt serialize end")

def idx_file():
    # build the full vocabulary encodings; the mapping is applied dynamically below
    req_entity_num_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\req_corpus_df.csv", encoding="utf-8")
    from collections import defaultdict
    with open("cnt.pkl", "rb") as f:
        cnt_dict = pickle.load(f)
        word_cnt = cnt_dict["word_cnt"]
        entity_cnt = cnt_dict["entity_cnt"]

    word2idx = dict((t2[0], idx) for idx ,t2 in enumerate(sorted(word_cnt.items(), key = lambda x: -1 * x[1])))
    word2idx = dict(list(word2idx.items()) + [("<unk>", len(word2idx))])
    entity2idx = dict((t2[0], idx) for idx ,t2 in enumerate(sorted(entity_cnt.items(), key = lambda x: -1 * x[1])))
    entity2idx = dict(list(entity2idx.items()) + [("<unk>", len(entity2idx))])

    req_corpus_idx_collection = defaultdict(list)
    times = 0
    for ridx, r in req_entity_num_df.iterrows():
        # use dict.get with a default: index 0 is falsy, so the original
        # `d[x] if d.get(x) else d["<unk>"]` wrongly mapped the most frequent
        # entity/word (index 0) to <unk>
        req_corpus_idx_collection["true_e"].append(list(map(lambda x: entity2idx.get(x, entity2idx["<unk>"]), eval(r["true_e"]))))
        req_corpus_idx_collection["false_e"].append(list(map(lambda x: entity2idx.get(x, entity2idx["<unk>"]), eval(r["false_e"]))))
        req_corpus_idx_collection["doc_pos"].append(list(map(lambda x: word2idx.get(x, word2idx["<unk>"]), eval(r["doc_pos"]))))

        if len(req_corpus_idx_collection["true_e"]) == 1000:
            temp_df = pd.DataFrame.from_dict(req_corpus_idx_collection)
            req_corpus_idx_collection = defaultdict(list)
            if times == 0:
                temp_df.to_csv("req_corpus_idx_df.csv", encoding="utf-8", header=True)
            else:
                with open("req_corpus_idx_df.csv", "a", encoding="utf-8") as f:
                    temp_df.to_csv(f, header=False)
            times += 1
            print("times :{}".format(times))
    # flush the final partial chunk
    if req_corpus_idx_collection["true_e"]:
        temp_df = pd.DataFrame.from_dict(req_corpus_idx_collection)
        with open("req_corpus_idx_df.csv", "a", encoding="utf-8") as f:
            temp_df.to_csv(f, header=(times == 0))
    print("corpus serialize end")

    entity_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\entity_df.csv", encoding="utf-8", index_col=0)
    req_entity_idx_collection = defaultdict(list)
    times = 0
    for ridx, r in entity_df.iterrows():
        req_entity_idx_collection["entity"].append(entity2idx.get(r["entity"], entity2idx["<unk>"]))
        req_entity_idx_collection["desp_token"].append(list(map(lambda x: word2idx.get(x, word2idx["<unk>"]), eval(r["desp_token"]))))

        if len(req_entity_idx_collection["entity"]) == 1000:
            temp_df = pd.DataFrame.from_dict(req_entity_idx_collection)
            req_entity_idx_collection = defaultdict(list)
            if times == 0:
                temp_df.to_csv("entity_idx_df.csv", encoding="utf-8", header=True)
            else:
                with open("entity_idx_df.csv", "a", encoding="utf-8") as f:
                    temp_df.to_csv(f, header=False)
            times += 1
            print("times: {}".format(times))
    # flush the final partial chunk
    if req_entity_idx_collection["entity"]:
        temp_df = pd.DataFrame.from_dict(req_entity_idx_collection)
        with open("entity_idx_df.csv", "a", encoding="utf-8") as f:
            temp_df.to_csv(f, header=(times == 0))
    print("entity serialize end")
    with open("idx_dict.pkl", "wb") as f:
        pickle.dump({
            "word2idx": word2idx,
            "entity2idx": entity2idx
        }, f)

def process_ed_csv():
    temp_list = []
    df_chunk_size = int(1e4)
    times = 0
    df_dir = "entity_df.csv"

    def serialize(times, temp_list, final = False):
        # flush when the buffer is full, or unconditionally on the final call
        if temp_list and (final or len(temp_list) >= df_chunk_size):
            df = pd.DataFrame(temp_list, columns=["entity", "desp_token"])
            if times:
                with open(df_dir, "a", encoding="utf-8") as f:
                    df.to_csv(f, header=False)
            else:
                df.to_csv(df_dir, header=True, encoding="utf-8")
            temp_list = []
            times += 1
            print("write times :{}".format(times))
        return times, temp_list

    entity_description_df = pd.read_csv(r"D:\Coding\python\KESM\WikiDataCrawler\ed.csv", encoding="utf-8")
    for row_idx, row in entity_description_df.iterrows():
        entity = row["entity"]
        description = row["description"]
        desp_token = clean_pos(description)
        temp_list.append([entity, desp_token])
        times, temp_list = serialize(times, temp_list)
    serialize(times, temp_list, final=True)

def sort_corpus_df_entity_have_num():
    entity_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\entity_df.csv", encoding="utf-8", header=0)
    corpus_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\corpus_df.csv", encoding="utf-8", header=0)
    all_entity_set = set(entity_df["entity"].tolist())
    print("all_entity_set num :{}".format(len(all_entity_set)))
    def count_intersection(x):
        return len(set(eval(x["true_e"]) + eval(x["false_e"])).intersection(all_entity_set))
    corpus_df["entity_intersect_count"] = corpus_df.apply(count_intersection, axis=1)
    print("col calculate end")
    corpus_df.to_csv("corpus_with_entity_num.csv", encoding="utf-8")
    print("df serialize end")


if __name__ == "__main__":
    process_ed_csv()
    process_corpus()
    sort_corpus_df_entity_have_num()
    entity_and_word_idx()
    idx_file()

The steps above perform simple data cleaning and index encoding. The two functions sort_corpus_df_entity_have_num() and entity_and_word_idx() count, for each sample, how many of its entities have a description; to make the model's effect more visible, only samples with more than 10 such entities are used in the experiments.


Data export:

import tensorflow as tf
import numpy as np
from functools import reduce
import pandas as pd
import gc
import pickle

with open(r"D:\Coding\python\KESM\data_process\idx_dict.pkl", "rb") as f:
    idx_dict = pickle.load(f)
    word2idx = idx_dict["word2idx"]
    entity2idx = idx_dict["entity2idx"]

corpus_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\req_corpus_idx_df.csv", encoding="utf-8", header=0, index_col=0)
test_basic_sample_num = int(1e4)

entity_df = pd.read_csv(r"D:\Coding\python\KESM\data_process\entity_idx_df.csv", encoding="utf-8", header=0, index_col=0)
entity_desp_dict = dict()
for _, r in entity_df.iterrows():
    entity_desp_dict[r["entity"]] = r["desp_token"]
del entity_df
gc.collect()

# epi: entity padding index
# dpi: word padding index
def data_generator(type = "train", batch_num = 128, max_kee_num = 10, desp_max_length = 50, doc_max_length = 500,
                   epi = len(entity2idx) - 1, dpi = len(word2idx) - 1):
    take_nums = 1e10
    take_num = 0
    # slice into a local frame: reassigning the global corpus_df here (as the
    # original code did) would shrink the dataset every time the generator is
    # re-created, and let the test split drift into the training data
    if type == "train":
        df = corpus_df.iloc[test_basic_sample_num:, :]
    else:
        df = corpus_df.iloc[:test_basic_sample_num, :]

    input_kee_1 = np.full(shape=[batch_num, max_kee_num], fill_value=epi, dtype=np.int32)
    input_kee_desp_1 = np.full(shape=[batch_num, max_kee_num, desp_max_length], fill_value=dpi, dtype=np.int32)
    input_kee_2 = np.full(shape=[batch_num, max_kee_num], fill_value=epi, dtype=np.int32)
    input_kee_desp_2 = np.full(shape=[batch_num, max_kee_num, desp_max_length], fill_value=dpi, dtype=np.int32)
    input_doc = np.full(shape=[batch_num, doc_max_length], fill_value=dpi, dtype=np.int32)
    input_kee_mask = np.ones(shape=[batch_num], dtype=np.int32)

    start_idx = 0
    for idx, r in df.iterrows():
        doc_pos, true_e, false_e = map(eval ,[r["doc_pos"], r["true_e"], r["false_e"]])
        all_e = (true_e + false_e)[:max_kee_num]
        false_e = list(set(all_e).difference(set(true_e)))[:len(true_e)]
        if len(false_e) != len(true_e):
            continue

        # input_kee* part
        for ele_idx ,true_e_ele in enumerate(true_e):
            input_kee_1[start_idx][0] = true_e_ele
            entity_desp_str = entity_desp_dict.get(true_e_ele)
            input_kee_desp_1[start_idx][0] = eval(entity_desp_str)[:desp_max_length] + \
                                           [dpi] * (desp_max_length - len(eval(entity_desp_str))) \
                if entity_desp_str else [dpi] * desp_max_length

            for oidx ,other in enumerate(set(all_e).difference(set([true_e_ele]))):
                if oidx == max_kee_num - 2:
                    break
                input_kee_1[start_idx][oidx + 1] = other
                entity_desp_str = entity_desp_dict.get(other)
                input_kee_desp_1[start_idx][oidx + 1] = eval(entity_desp_str)[:desp_max_length] \
                                                      + [dpi] * (desp_max_length - len(eval(entity_desp_str))) \
                    if entity_desp_str else [dpi] * desp_max_length

            false_e_ele = false_e[ele_idx]
            input_kee_2[start_idx][0] = false_e_ele
            entity_desp_str = entity_desp_dict.get(false_e_ele)
            input_kee_desp_2[start_idx][0] = eval(entity_desp_str)[:desp_max_length] + \
                                           [dpi] * (desp_max_length - len(eval(entity_desp_str))) \
                if entity_desp_str else [dpi] * desp_max_length
            for oidx ,other in enumerate(set(all_e).difference(set([false_e_ele]))):
                if oidx == max_kee_num - 2:
                    break
                input_kee_2[start_idx][oidx + 1] = other
                entity_desp_str = entity_desp_dict.get(other)
                input_kee_desp_2[start_idx][oidx + 1] = eval(entity_desp_str)[:desp_max_length] \
                                                      + [dpi] * (desp_max_length - len(eval(entity_desp_str))) \
                    if entity_desp_str else [dpi] * desp_max_length

            # input doc part
            input_doc[start_idx] = doc_pos[:doc_max_length] + [dpi] * (doc_max_length - len(doc_pos))


            # input_kee_mask part
            input_kee_mask[start_idx] = len(all_e)

            start_idx += 1
            if start_idx  == batch_num:
                yield (input_kee_1, input_kee_desp_1, input_kee_2, input_kee_desp_2, input_doc, input_kee_mask)
                start_idx = 0
                take_num += 1
                if take_num == take_nums:
                    return

                input_kee_1 = np.full(shape=[batch_num, max_kee_num], fill_value=epi, dtype=np.int32)
                input_kee_desp_1 = np.full(shape=[batch_num, max_kee_num, desp_max_length], fill_value=dpi, dtype=np.int32)
                input_kee_2 = np.full(shape=[batch_num, max_kee_num], fill_value=epi, dtype=np.int32)
                input_kee_desp_2 = np.full(shape=[batch_num, max_kee_num, desp_max_length], fill_value=dpi, dtype=np.int32)
                input_doc = np.full(shape=[batch_num, doc_max_length], fill_value=dpi, dtype=np.int32)
                input_kee_mask = np.ones(shape=[batch_num], dtype=np.int32)

The data-export code above reveals the format of the model's inputs. Since the loss is defined over pairs of scores (this is no accident; see learning-to-rank models. One remark: if the pairwise scoring setup is replaced by plain binary classification of individual entities, results are poor, because every entity, salient or not, is related to the document, so binary classification cannot capture accuracy in the ordering), the entity inputs are likewise organized as positive/negative pairs:

input_kee_1 carries the salient entities' information and input_kee_desp_1 their descriptions; input_kee_2 holds the non-salient entities (with descriptions in input_kee_desp_2).

For input_kee_*, the first dimension is the batch; along the second dimension, the first element is the target entity (the one whose similarity to all other entities and to the document is computed). In practice input_kee_1 and input_kee_2 are identical along the second dimension except for that first element (I did not bother to change the network structure), which is redundant from a programming standpoint; feel free to refactor. input_doc is the document input for the sample, and input_kee_mask gives the actual length along the second dimension of input_kee_*. A quick sketch of the batch shapes follows below.
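For concreteness, under data_generator's default arguments one yielded batch has these shapes (read off the code above):

input_kee_1, input_kee_2:            [128, 10]        entity indices (element 0 of dim 2 = target)
input_kee_desp_1, input_kee_desp_2:  [128, 10, 50]    description token indices
input_doc:                           [128, 500]       document token indices
input_kee_mask:                      [128]            number of valid entities per sample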

With these inputs, the network is defined as follows:

# To keep the network input simple, the target KEE and the KEEs of the other
# entities in the doc are concatenated into one tensor: the first element is
# the target (with its description); the rest can be recovered by slicing.
kernel_params_list = []
mu_array = np.arange(-0.9, 1.0, 0.1)
sigma = 0.1
for mu in mu_array:
    kernel_params_list.append((mu, sigma, mu, sigma))
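# np.arange(-0.9, 1.0, 0.1) yields 19 kernel means; each kernel produces an
# entity-side and a document-side pooled value in kernel_interaction_kernel,
# so the fused KIM vector in model_construct below has 2 * 19 = 38 dimensions.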

class KESM(object):
    def __init__(self, desp_max_length = 50, doc_max_length = 500, entity_size = int(1e4),
                 vocab_size = int(1e5), entity_embedding_dim = 100, word_embedding_dim = 100,
                 cnn_filter_sizes = [3, 4, 5], num_filters = 3, batch_num = 128,
                 kernel_params_list = kernel_params_list):
        self.desp_max_length = desp_max_length
        self.doc_max_length = doc_max_length
        self.entity_size = entity_size
        self.vocab_size = vocab_size
        self.entity_embedding_dim = entity_embedding_dim
        self.word_embedding_dim = word_embedding_dim
        self.batch_num = batch_num
        self.kernel_params_list = kernel_params_list

        # nn params
        self.cnn_filter_sizes = cnn_filter_sizes
        self.num_filters = num_filters

        self.input_kee_1 = tf.placeholder(tf.int32, shape=[None, None])
        self.input_kee_desp_1 = tf.placeholder(tf.int32, shape=[None, None, desp_max_length])

        self.input_kee_2 = tf.placeholder(tf.int32, shape=[None, None])
        self.input_kee_desp_2 = tf.placeholder(tf.int32, shape=[None, None, desp_max_length])

        self.input_doc = tf.placeholder(tf.int32, shape=[None, doc_max_length])

        # number of valid entities (KEEs) per sample, as an int
        self.input_kee_mask = tf.placeholder(tf.int32, [None])

        # tensors initialized at graph-construction time
        self.ve_before_reshape = None

        self.max_mask_num = None
        self.loss = None

        with tf.name_scope("entity_embeddings"):
            self.Entity_Embedding = tf.Variable(
                tf.random_normal(shape=[self.entity_size, self.entity_embedding_dim]),
                name="EE"
            )

        with tf.name_scope("word_embeddings"):
            self.Word_Embedding = tf.Variable(
                tf.random_normal(shape=[self.vocab_size, self.word_embedding_dim])
                ,name="WE"
            )

        # model construct
        self.opt_construct()

    def model_construct(self, input_kee_desp, input_kee):
        reshape_padding_ve = self.knowledge_enriched_embedding_layer(input_kee_desp, input_kee)
        doc_lookup = tf.nn.embedding_lookup(self.Word_Embedding ,self.input_doc)
        param_tuple_map = list(map(lambda inner_tuple: tuple([reshape_padding_ve, doc_lookup] + list(inner_tuple)) ,self.kernel_params_list))
        batch_kernel_list = reduce(lambda x, y: x + y,map(self.kernel_interaction_kernel, param_tuple_map))

        # [batch, total_kernel_dim]
        # fuse kee doc information into a vector
        KIM = tf.concat(batch_kernel_list, axis = -1)
        f_score = tf.layers.dense(KIM, units=1)
        return f_score

    def opt_construct(self):
        with tf.variable_scope("f_score_layer", reuse=tf.AUTO_REUSE):
            f_score_1 = self.model_construct(self.input_kee_desp_1, self.input_kee_1)
            f_score_2 = self.model_construct(self.input_kee_desp_2, self.input_kee_2)

        self.f_score_1 = f_score_1
        self.f_score_2 = f_score_2

        # f_score_1 indicate positive f_score_2 indicate negative
        self.loss = tf.reduce_mean(tf.nn.relu(1 - f_score_1 + f_score_2))
        self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.loss)

    def cnn_layer(self, M):
        # Create a convolution + maxpool layer for each filter size
        sequence_length = int(M.get_shape()[1])
        pooled_outputs = []
        for i, filter_size in enumerate(self.cnn_filter_sizes):
            with tf.variable_scope("conv-maxpool-%s" % filter_size):
                # Convolution Layer
                filter_shape = [filter_size, self.word_embedding_dim, 1, self.num_filters]
                W = tf.get_variable(
                    shape=filter_shape,dtype=tf.float32,initializer=tf.initializers.random_normal(),
                    name="cnn_W_{}".format(filter_size)
                )
                b = tf.get_variable(
                    shape=[self.num_filters], dtype=tf.float32, initializer=tf.initializers.constant(1.0),
                    name= "cnn_b_{}".format(filter_size)
                )
                conv = tf.nn.conv2d(
                    M,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                # Apply nonlinearity
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # Maxpooling over the outputs
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, sequence_length - filter_size + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name="pool")
                pooled_final_shape = pooled.get_shape()
                final_size = int(pooled_final_shape[-1]) * int(pooled_final_shape[-2]) * int(pooled_final_shape[-3])
                pooled_outputs.append(tf.reshape(pooled, [-1, final_size]))

        return tf.concat(pooled_outputs, -1, name="kee_desp_cnn_output")

    def knowledge_enriched_embedding_layer(self, input_kee_desp, input_kee):
        input_kee_desp_reshape = tf.reshape(input_kee_desp, [-1, self.desp_max_length])
        input_kee_reshape = tf.reshape(input_kee, [-1])

        with tf.name_scope("kee_desp_cnn_layer"):
            kee_desp_lookup = tf.nn.embedding_lookup(self.Word_Embedding, input_kee_desp_reshape)
            #[batch_hat, desp_len, embedding, 1]
            kee_desp_lookup_expd = tf.expand_dims(kee_desp_lookup, -1)
            #[batch_hat, cnn_flat_dim]
            kee_desp_cnn_output = self.cnn_layer(kee_desp_lookup_expd)

        #[batch_hat, entity_dim]
        kee_lookup = tf.nn.embedding_lookup(self.Entity_Embedding, input_kee_reshape)
        fuse_desp_entity = tf.concat([kee_desp_cnn_output, kee_lookup], axis=-1)
        ve_dim = self.word_embedding_dim

        ve_before_reshape = tf.layers.dense(fuse_desp_entity, units=ve_dim, name="ve_layer")
        self.ve_before_reshape = ve_before_reshape

        # split the flattened (batch * entities) axis back into per-sample rows via boolean masks
        sum_mask_num = tf.reduce_sum(self.input_kee_mask)
        cumsum_mask_num_second = tf.cumsum(self.input_kee_mask)
        cumsum_mask_num_first = tf.slice(tf.concat([tf.constant([0]), cumsum_mask_num_second], axis=0), [0], [self.batch_num])
        cumsum_seq_mask_first = tf.cast(tf.sequence_mask(cumsum_mask_num_first, maxlen=sum_mask_num), tf.int32)
        cumsum_seq_mask_second = tf.cast(tf.sequence_mask(cumsum_mask_num_second, maxlen=sum_mask_num), tf.int32)
        cumsum_seq_mask = tf.cast(cumsum_seq_mask_second - cumsum_seq_mask_first, tf.bool)
        max_mask_num = tf.reduce_max(self.input_kee_mask)

        self.max_mask_num = max_mask_num
        def reshape_padding_ve_func(cumsum_seq_mask_ele):
            head = tf.boolean_mask(self.ve_before_reshape, cumsum_seq_mask_ele)
            padding = tf.zeros(shape=[max_mask_num - tf.reduce_sum(tf.cast(cumsum_seq_mask_ele, tf.int32)), ve_dim])
            return tf.concat([head, padding], axis=0)

        # [batch, max_mask_num, fuse_dim]
        reshape_padding_ve = tf.map_fn(reshape_padding_ve_func, cumsum_seq_mask, dtype=tf.float32)
        return reshape_padding_ve

    def kernel_interaction_kernel(self, input):
        kee_input, doc_input, mu_k_e, sigma_k_e, mu_k_d, sigma_k_d = input
        # kee_input [batch, max_mask_num, fuse_dim]
        # doc_input [batch, doc_max_len, word_embedding_dim]

        # normalize along the embedding dimension so the dot products below are
        # true cosines (the original dim=1 normalized across the entity axis)
        kee_target_part = tf.nn.l2_normalize(tf.slice(kee_input, [0, 0, 0], [-1, 1, -1]), axis=-1)
        kee_others_part = tf.nn.l2_normalize(tf.slice(kee_input, [0, 1, 0], [-1, -1, -1]), axis=-1)
        doc_input = tf.nn.l2_normalize(doc_input, axis=-1)
        kee_target_part_tiled = tf.tile(kee_target_part, [1, self.max_mask_num - 1 ,1])

        # [batch, max_mask_num - 1]
        kee_cos_part = tf.reduce_sum(kee_target_part_tiled * kee_others_part, axis=-1)
        # RBF kernel value; note the exp(-(x - mu)^2 / sigma) scaling, not 1 / (2 sigma^2)
        exp_kee_cos_part = tf.exp(-1 * (kee_cos_part - mu_k_e) * (kee_cos_part - mu_k_e) / sigma_k_e)
        kee_valid_part_mask = tf.cast(tf.sequence_mask(self.input_kee_mask - 1, maxlen=self.max_mask_num - 1), tf.float32)
        # [batch]
        kee_batch_kernel = tf.reduce_sum(exp_kee_cos_part * kee_valid_part_mask, axis=-1)

        doc_max_len = int(doc_input.get_shape()[1])
        doc_target_part_tiled = tf.tile(kee_target_part, [1, doc_max_len ,1])
        doc_cos_part = tf.reduce_sum(doc_target_part_tiled * doc_input, axis=-1)
        exp_doc_cos_part = tf.exp(-1 * (doc_cos_part - mu_k_d) * (doc_cos_part - mu_k_d) / sigma_k_d)
        # [batch]
        doc_batch_kernel = tf.reduce_sum(exp_doc_cos_part, axis=-1)

        return [tf.expand_dims(kee_batch_kernel, -1), tf.expand_dims(doc_batch_kernel, -1)]

    @staticmethod
    def train():
        batch_num = 128
        model = KESM(batch_num=batch_num, vocab_size=len(word2idx), entity_size=len(entity2idx))

        times = 0
        train_gen = data_generator(batch_num = batch_num)
        test_gen = data_generator(type="test" ,batch_num = batch_num)
        epoch = 0

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            while True:
                try:
                    input_kee_1, input_kee_desp_1, input_kee_2, input_kee_desp_2, input_doc, input_kee_mask = next(train_gen)
                except StopIteration:
                    # restart the generator instead of re-training on a stale batch
                    print("epoch {} end".format(epoch))
                    train_gen = data_generator(batch_num = batch_num)
                    epoch += 1
                    continue

                _, train_loss = sess.run([model.train_op, model.loss], feed_dict={
                    model.input_kee_1: input_kee_1,
                    model.input_kee_desp_1: input_kee_desp_1,
                    model.input_kee_2: input_kee_2,
                    model.input_kee_desp_2: input_kee_desp_2,
                    model.input_doc: input_doc,
                    model.input_kee_mask: input_kee_mask
                })
                print("train loss :{}".format(train_loss))
                if times % 10 == 0:
                    try:
                        input_kee_1, input_kee_desp_1, input_kee_2, input_kee_desp_2, input_doc, input_kee_mask = next(test_gen)
                    except StopIteration:
                        # restart the test generator and draw a fresh batch
                        test_gen = data_generator(type="test", batch_num = batch_num)
                        input_kee_1, input_kee_desp_1, input_kee_2, input_kee_desp_2, input_doc, input_kee_mask = next(test_gen)

                    test_loss, f_score_1, f_score_2 = sess.run([model.loss, model.f_score_1, model.f_score_2], feed_dict={
                        model.input_kee_1: input_kee_1,
                        model.input_kee_desp_1: input_kee_desp_1,
                        model.input_kee_2: input_kee_2,
                        model.input_kee_desp_2: input_kee_desp_2,
                        model.input_doc: input_doc,
                        model.input_kee_mask: input_kee_mask
                    })

                    # pairwise accuracy: fraction of pairs where the salient entity outscores the non-salient one
                    acc = np.array(f_score_1 > f_score_2, dtype=np.float32).mean()
                    print("test loss :{} acc:{}".format(test_loss, acc))

                times += 1



if __name__ == "__main__":
    KESM.train()

The RBF kernel hyperparameters follow the paper's settings. Accuracy is computed on the test set (the training set fits well during training and is not logged). Convergence accuracy here is defined as the fraction of pairs with f_score_1 > f_score_2, since by construction of the inputs this is what should hold.

Because the preprocessing kept only samples with plenty of entity descriptions, convergence is good. Here is the log from early in training:

train loss :32.136802673339844
test loss :29.226028442382812 acc:0.0234375
train loss :31.882282257080078
train loss :30.8996639251709
train loss :28.286611557006836
train loss :31.86393165588379
train loss :27.266090393066406
train loss :20.51707649230957
train loss :19.864543914794922
train loss :24.485538482666016
train loss :21.631492614746094
train loss :20.699556350708008
test loss :20.515953063964844 acc:0.078125
train loss :21.719375610351562
train loss :18.175708770751953
train loss :18.79627227783203
train loss :15.973711967468262
train loss :22.51378059387207
train loss :16.266401290893555
train loss :16.994487762451172
train loss :19.631580352783203
train loss :19.612144470214844
train loss :16.92822265625
test loss :17.485687255859375 acc:0.0546875
train loss :14.07847785949707
train loss :17.884946823120117
train loss :18.433307647705078
train loss :15.116429328918457
train loss :14.5866117477417
train loss :12.568949699401855
train loss :16.726009368896484
train loss :16.02698516845703
train loss :12.45231819152832
train loss :18.996814727783203
test loss :14.900018692016602 acc:0.078125
train loss :12.213709831237793
train loss :14.977048873901367
train loss :14.515562057495117
train loss :16.51100730895996
train loss :19.332717895507812
train loss :14.827428817749023
train loss :16.82861328125
train loss :9.731084823608398
train loss :10.145099639892578
train loss :10.770788192749023
test loss :13.10962200164795 acc:0.21875
train loss :10.626840591430664
train loss :10.669485092163086
train loss :9.882623672485352
train loss :14.196492195129395
train loss :12.102956771850586
train loss :10.761802673339844
train loss :6.499850273132324
train loss :9.159351348876953
train loss :6.764830112457275
train loss :9.540651321411133
test loss :7.350927829742432 acc:0.375
train loss :11.662117004394531
train loss :10.092960357666016
train loss :8.116621017456055
train loss :7.211489677429199
train loss :6.5933051109313965
train loss :8.281878471374512
train loss :9.252912521362305
train loss :8.742061614990234
train loss :11.99124526977539
train loss :7.668788909912109
test loss :5.253876686096191 acc:0.46875
train loss :10.226776123046875
train loss :10.507610321044922
train loss :10.862887382507324
train loss :6.5200276374816895
train loss :6.80426025390625
train loss :3.6178395748138428
train loss :5.657415866851807
train loss :5.388607978820801
train loss :6.250310897827148
train loss :5.500130653381348
test loss :8.942505836486816 acc:0.6328125
train loss :5.482635498046875
train loss :5.60318660736084
train loss :6.55861234664917
train loss :5.086772918701172
train loss :3.8880820274353027
train loss :8.441230773925781
train loss :7.050711631774902
train loss :6.250699043273926
train loss :5.3991241455078125
train loss :2.348048210144043
test loss :6.638607978820801 acc:0.7734375
train loss :4.900678634643555
train loss :3.669212579727173
train loss :5.357037544250488
train loss :3.8981685638427734
train loss :4.366262435913086
train loss :1.9637587070465088
train loss :3.004202365875244
train loss :4.238396167755127
train loss :4.1686506271362305
train loss :6.312758445739746
test loss :2.175901412963867 acc:0.921875
train loss :2.532480478286743
train loss :2.8864879608154297
train loss :2.166124105453491
train loss :2.396486282348633
train loss :3.142045021057129
train loss :1.9758951663970947
train loss :4.525737762451172
train loss :1.556884527206421
train loss :2.350480794906616
train loss :1.3652114868164062
test loss :1.55498206615448 acc:0.9140625
train loss :1.553002119064331
train loss :1.5114221572875977
train loss :1.7723044157028198
train loss :2.9699597358703613
train loss :1.4816454648971558
train loss :1.3616077899932861
train loss :1.8693952560424805
train loss :2.3389508724212646
train loss :2.7014565467834473
train loss :1.9040240049362183
test loss :2.8285391330718994 acc:0.9140625
train loss :2.5975096225738525
train loss :1.7298221588134766
train loss :2.5001325607299805
train loss :2.3010199069976807
train loss :0.4678676426410675
train loss :0.6357259750366211
train loss :0.695709764957428
train loss :2.253744125366211
train loss :1.4345471858978271
train loss :1.1282814741134644
test loss :0.5373156070709229 acc:0.953125
train loss :0.311814546585083
train loss :1.4651544094085693
train loss :0.8269515037536621
train loss :0.5220216512680054
train loss :0.1606338918209076
train loss :0.1128043532371521
train loss :0.8103121519088745
train loss :0.4834401309490204
train loss :1.2986313104629517
train loss :0.4324950575828552
test loss :0.2764070928096771 acc:0.9453125
train loss :0.18162190914154053
train loss :0.47013553977012634
train loss :0.34739431738853455
train loss :0.3420702815055847
train loss :1.5287249088287354
train loss :0.9801853895187378
train loss :0.6511420607566833
train loss :0.6557278633117676
train loss :0.4798915982246399
train loss :0.5535762906074524
test loss :0.07494109869003296 acc:0.984375
train loss :0.511849045753479
train loss :0.8412278294563293
train loss :0.4313739538192749
train loss :0.449565052986145
train loss :0.252207487821579
train loss :0.30550551414489746
train loss :0.4194437265396118
train loss :0.5475836396217346
train loss :0.6809948086738586
train loss :0.21016159653663635
test loss :0.10047291219234467 acc:0.984375
train loss :0.5512081980705261
train loss :0.46260714530944824
train loss :0.1845911145210266
train loss :1.8909573554992676e-05
train loss :0.00864720344543457
train loss :0.003664463758468628
train loss :0.3173838257789612
train loss :0.3461763858795166
train loss :0.0414513535797596
train loss :0.02685767412185669
test loss :0.3495218753814697 acc:0.9765625

If you are interested, you could try building the search-engine part on top of this; a rough sketch of the query-document scoring step is given below.
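The sketch follows the f(q, d) = Σ f(e, d) formula from earlier in the post. build_entity_batch is a hypothetical helper you would write to pack one query entity plus the document into arrays with the same shapes data_generator yields (full batch_num rows, entity of interest in row 0); nothing here is from the paper's code:

def query_doc_score(sess, model, query_entity_idxs, doc_idxs, build_entity_batch):
    # f(q, d): sum the salience scores of each query entity against the document.
    # f_score_1 only depends on the *_1 placeholders, the document and the mask,
    # so the negative-pair placeholders need not be fed here.
    score = 0.0
    for e in query_entity_idxs:
        kee, kee_desp, doc, mask = build_entity_batch(e, doc_idxs)  # hypothetical helper
        f = sess.run(model.f_score_1, feed_dict={
            model.input_kee_1: kee,
            model.input_kee_desp_1: kee_desp,
            model.input_doc: doc,
            model.input_kee_mask: mask,
        })
        score += float(f[0, 0])  # row 0 holds the entity of interest
    return score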
