NLP in Practice: FiD, a Knowledge-Graph Question Answering Model

0. Introduction

It has been a while since my last update. Today I want to introduce a knowledge base question answering (KBQA) model. Until now I had mostly been using the QA models offered by huggingface's Pipeline, which are very convenient but not especially accurate. The model introduced here is FiD (Fusion-in-Decoder), proposed by Facebook back in 2021 and published at EACL.

Paper: https://aclanthology.org/2021.eacl-main.74.pdf
Code: https://github.com/facebookresearch/FiD

Originally I had actually wanted to look at a paper from EMNLP 2022 whose code is also open source. That project is called Grape. Its basic idea builds on FiD: it uses two T5 encoders and, before decoding, constructs a GNN from the entities found in the query and the candidate passages, then applies attention over the graph nodes to strengthen the encoder representations.

Paper: https://arxiv.org/pdf/2210.02933.pdf

When I experimented with the Grape project, however, I ran into environment problems: the authors depend on a fairly obscure version of dgl that was never officially built for Linux x86_64. I tried compiling it myself, hit a pile of cmake and gcc version issues, and gave up. Following the references in Grape's paper, though, I found the FiD project.

1. Model Architecture

KBQA means that, on top of an ordinary QA model, we consider not only the given source text but also additional corpus information from a knowledge base. The principle behind this model is simple: a generative model combined with a retrieval task.

A retrieval model first recalls a number of passages from the knowledge base that are relevant to the given source text. The question is then concatenated with the source text and with each retrieved passage; every concatenated sequence is encoded separately, the encodings are concatenated together, and the concatenated result is handed to the decoder, which generates the answer.

(Figure: FiD model architecture)
The base model is T5, trained separately on two datasets, NaturalQuestions and TriviaQA; both the data and the trained checkpoints can be found in the GitHub repo.
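
To make the fusion step concrete, here is a minimal sketch (my own, not the repo's src/model.py) of how a FiD-style wrapper pushes a (batch, n_passages, seq_len) input through a stock T5: each passage is encoded independently, the per-passage encodings are concatenated along the sequence axis, and the decoder cross-attends over the fused sequence.

import torch
from transformers import T5ForConditionalGeneration

def fid_forward_sketch(t5: T5ForConditionalGeneration, input_ids, attention_mask, labels=None):
    # input_ids / attention_mask: (batch, n_passages, seq_len)
    bsz, n_passages, seq_len = input_ids.shape
    # Flatten so the encoder treats every passage as an independent sequence
    enc = t5.encoder(
        input_ids=input_ids.view(bsz * n_passages, seq_len),
        attention_mask=attention_mask.view(bsz * n_passages, seq_len),
    )
    # Fuse: concatenate the per-passage encodings along the sequence axis,
    # giving (batch, n_passages * seq_len, hidden)
    fused = enc.last_hidden_state.view(bsz, n_passages * seq_len, -1)
    fused_mask = attention_mask.view(bsz, n_passages * seq_len)
    # The decoder now cross-attends over all passages at once
    return t5(encoder_outputs=(fused,), attention_mask=fused_mask, labels=labels)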

2. Retrieval

There is not much to the retrieval part: the official repo simply encodes passages with bert-base. I did not follow that approach; interested readers can look at the retrieval-related .py files themselves.

I find it more convenient to do the encoding myself. You can use a pretrained Sentence Transformers model directly, or whatever encoder you have trained yourself; building a faiss or milvus index on top will also speed retrieval up considerably. Sentence Transformers was covered in an earlier post on this blog.
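
As an illustration of that route, here is a small sketch using sentence-transformers together with a faiss index; the encoder name and the toy corpus below are placeholders of my own, not anything shipped with the FiD repo.

import faiss
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer('all-MiniLM-L6-v2')  # any sentence encoder will do
corpus = ["first passage of your knowledge base ...",
          "second passage ..."]

# Encode the corpus; normalized embeddings make inner product equal cosine similarity
embeddings = encoder.encode(corpus, normalize_embeddings=True)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Retrieve the top-k passages for a question; these are what get fed to FiD later
query = encoder.encode(["where is Syria?"], normalize_embeddings=True)
scores, ids = index.search(query, 2)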

3. Question Answering

Although this is billed as a KBQA model, the repo does not seem to provide the fusion part of the code directly. So let us first write a prediction function ourselves and use the released checkpoints for plain QA.

Since the model is essentially a T5, anyone reasonably familiar with the transformers library can write the prediction function without much trouble.

First, load the model and tokenizer:

from transformers import AutoTokenizer
from src.model import FiDT5  # mind the import path; if the import fails, just copy this class into your own code

# Download the checkpoint you want from the repo (e.g. nq), put all the files
# in one directory, and load it with from_pretrained
model = FiDT5.from_pretrained('your_path_to_Fid_model/nq_reader_base/')
tokenizer = AutoTokenizer.from_pretrained('t5-large')  # downloaded from the hub, or point it at a local copy

# Switch to eval mode to disable dropout (if you are feeling rebellious, you may leave it on)
model.eval()

Next, a simple prediction function is all we need for QA.

def predict(model, tokenizer, question, title, context, device='cpu'):
    """
    Single-passage prediction.
    :param model: the FiD (T5-based) model
    :param tokenizer: the tokenizer
    :param question: the question
    :param title: the title; pass an empty string if there is none
    :param context: the body text
    :param device: run on cpu or cuda
    ---------------
    ver: 2023-01-12
    by: changhongyu
    """
    if device.startswith('cuda'):
        model.to(device)
    # FiD inputs follow the format "question: ... title: ... context: ..."
    combined_text = "question: " + question + " title: " + title + " context: " + context
    inputs = tokenizer(combined_text, max_length=1024, truncation=True, return_tensors='pt')
    test_outputs = model.generate(
        # unsqueeze adds the passage dimension: FiD expects (batch, n_passages, seq_len)
        input_ids=inputs['input_ids'].unsqueeze(0).to(device),
        attention_mask=inputs['attention_mask'].unsqueeze(0).to(device),
        max_length=50,
    )
    answer = tokenizer.decode(test_outputs[0], skip_special_tokens=True)
    
    return answer

Let's test it:

predict(
    model, 
    tokenizer,
    "Who is Russia's new commander",
    "Russia Ukraine War Live Updates: Russia changes commanders again in Ukraine",
    """09:20 (IST) Jan 12 Ukrainian military analyst Oleh Zhdanov said the situation in Soledar was "approaching that of critical" "The Ukrainian armed forces are holding their positions. About one half of the town is under our control. Fierce fighting is going on near the town centre," he said on YouTube.However, Zhdanov told Ukrainian television that if Russian forces seized Soledar or nearby Bakhmut it would be more a political victory than military. 09:18 (IST) Jan 12 Russian private military firm Wagner Group said its capture of the salt mining town Soledar in eastern Ukraine was complete- a claim denied by Ukraine 09:08 (IST) Jan 12 Russia changes commanders again in Ukraine Moscow named a new commander for its invasion of Ukraine. Russian Defence Minister Sergei Shoigu on Wednesday appointed Chief of the General Staff Valery Gerasimov as overall commander for what Moscow calls its "special military operation" in Ukraine, now in its 11th month.The change effectively demoted General Sergei Surovikin, who was appointed only in October to lead the invasion and oversaw heavy attacks on Ukraine's energy infrastructure. 06:40 (IST) Jan 12 Russia, Ukraine agree new prisoner swap in Turkey Russia and Ukraine on Wednesday agreed a new prisoner swap during rare talks in Turkey during which they also discussed the creation of a "humanitarian corridor" in the war zone. Ukraine's human rights ombudsman Dmytro Lubinets met his Russian counterpart Tatyana Moskalkova on the sidelines of an international conference in Ankara attended by Turkish President Recep Tayyip Erdogan. 06:39 (IST) Jan 12 President Volodymyr Zelenskyy urged NATO on Wednesday to do more than just promise Ukraine its door is open at a July summit, saying Kyiv needs "powerful steps" as it tries to join the military alliance. 06:39 (IST) Jan 12 Russian forces shelled 13 settlements in and around Kharkiv region largely returned to Ukrainian hands in September and October, the Ukrainian military said. 06:38 (IST) Jan 12 Russia's war on Ukraine latest: Russia puts top general in charge of invasion Russia ordered its top general on Wednesday to take charge of its faltering invasion of Ukraine in the biggest shake-up yet of its malfunctioning military command structure after months of battlefield setbacks. 06:37 (IST) Jan 12 Zelenskyy says Russian war won't become WWIII Ukraine will stop Russian aggression and the conflict won't turn into World War III, President Volodymyr Zelenskiy said as his forces battled to keep control of Soledar and Bakhmut in the eastern Donetsk region. The Kremlin had positioned the most experienced units from the Wagner military-contracting company near Soledar, according to Ukrainian operational command spokesman Serhiy Cherevatyi."""
)

The model's answer matches expectations:

Valery Gerasimov

4. Knowledge-Augmented QA

The official code does not seem to include this part, so I implemented it simply, following the idea in the paper: after retrieval, the encoding of the target document is concatenated with the encoding of the retrieved reference document, and the decoder then decodes over the combined sequence.

def predict_with_reference(model, tokenizer, question, title, context, reference_title, reference_context, device='cpu'):
    """
    Prediction with one retrieved reference passage.
    :param model: the FiD (T5-based) model
    :param tokenizer: the tokenizer
    :param question: the question
    :param title: the title; pass an empty string if there is none
    :param context: the body text
    :param reference_title: title of the retrieved passage
    :param reference_context: body of the retrieved passage
    :param device: run on cpu or cuda
    ---------------
    ver: 2023-01-12
    by: changhongyu
    """
    if device.startswith('cuda'):
        model.to(device)
    combined_text = "question: " + question + " title: " + title + " context: " + context
    combined_refer = "question: " + question + " title: " + reference_title + " context: " + reference_context
    # Tokenize the two passages as a batch so they are padded to the same length,
    # then add a batch dimension: FiD expects (batch, n_passages, seq_len) and
    # encodes each passage independently before fusing them for the decoder
    inputs = tokenizer([combined_text, combined_refer], max_length=1024, truncation=True,
                       padding=True, return_tensors='pt')
    test_outputs = model.generate(
        input_ids=inputs['input_ids'].unsqueeze(0).to(device),
        attention_mask=inputs['attention_mask'].unsqueeze(0).to(device),
        max_length=50,
    )
    answer = tokenizer.decode(test_outputs[0], skip_special_tokens=True)
    
    return answer

Now let's test it:

Suppose we have a news article about the earthquake:

text = """The death toll in Syria and Turkey from the earthquake has passed 12,000, with the number of injured exceeding 100,000, while hundreds of thousands have been displaced. In Turkey, at least 9,000 have been killed and nearly 60,000 people have been injured, authorities said on Wednesday. The death toll in Syria stands at more than 3,000, according to the Syrian Observatory for Human Rights, while Syrian state media reported more than 298,000 people have been displaced."""

and an introduction to Syria retrieved from the knowledge base:

reference = """Syria (Arabic: سوريا‎, romanized: Sūriyā), officially the Syrian Arab Republic (Arabic: الجمهورية العربية السورية‎, romanized: al-Jumhūrīyah al-ʻArabīyah as-Sūrīyah), is a country in Western Asia, bordering Lebanon to the southwest, the Mediterranean Sea to the west, Turkey to the north, Iraq to the east, Jordan to the south, and Israel to the southwest. A country of fertile plains, high mountains, and deserts, Syria is home to diverse ethnic and religious groups, including Syrian Arabs, Kurds, Turkemens, Assyrians, Armenians, Circassians, Mandeans and Greeks. Religious groups include Sunnis, Christians, Alawites, Druze, Isma'ilis, Mandeans, Shiites, Salafis, Yazidis, and Jews. Arabs are the largest ethnic group, and Sunnis the largest religious group."""

Then run the QA:

predict_with_reference(
    model, 
    tokenizer,
    question="where is Syria.",
    title="Earthquake death toll exceeds 12,000 as Turkey, Syria seek help.",
    context=text,
    reference_title="Syria",
    reference_context=reference,
)

The model's answer is:

'Western Asia'

Again, the answer is as expected.

If you retrieve multiple documents, in principle you only need to turn the reference arguments of predict_with_reference into lists and stack all the passages when building the input; a sketch follows below for anyone who wants to try.
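
For reference, here is one way that multi-passage variant could look (my own untested sketch): batch-tokenize all passages with padding so they stack into FiD's (1, n_passages, seq_len) input.

def predict_with_references(model, tokenizer, question, title, context, references, device='cpu'):
    """
    Multi-passage prediction: `references` is a list of (title, context) pairs.
    """
    if device.startswith('cuda'):
        model.to(device)
    passages = [(title, context)] + list(references)
    texts = ["question: " + question + " title: " + t + " context: " + c
             for t, c in passages]
    # Padding makes every passage the same length, so the batch stacks into
    # (1, n_passages, seq_len) after unsqueeze
    inputs = tokenizer(texts, max_length=1024, truncation=True,
                       padding=True, return_tensors='pt')
    outputs = model.generate(
        input_ids=inputs['input_ids'].unsqueeze(0).to(device),
        attention_mask=inputs['attention_mask'].unsqueeze(0).to(device),
        max_length=50,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)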

That is all for this post. In the ChatGPT era, KBQA may feel a little "dated", but it is still very useful practice for basic NLP tasks and for understanding how attention operates. If this article helped you, please like, bookmark, and follow; questions and discussion are welcome in the comments or by direct message. See you next time.

Reposted from blog.csdn.net/weixin_44826203/article/details/128939386