NLP Tools: A Home-made Zero-shot Event Extractor

0. Introduction

In event extraction, data acquisition is a critical piece of work. Annotation is expensive and high-value data is hard to obtain, so few-shot and zero-shot settings have long been a research focus in the event extraction community.

The tool introduced today is something I built on top of stanza's dependency parsing quite a while ago. I am not presenting it as a piece of zero-shot or few-shot learning research; it is just a simple, purely rule-based utility with no real technical novelty, and whatever effectiveness it has comes entirely from stanza's syntactic parsing.

In practice, however, we can exploit these syntactic structures as a weak-supervision strategy to generate a batch of silver data, then bring in human annotators to turn the silver data into gold data, which substantially reduces the annotation cost.

If you are not very familiar with stanza, you can take a look at an earlier blog post of mine (the tool described here was in fact implemented around the same time):
NLP工具——Stanza依存关系含义详解

1. Extracting all candidate events

This extraction essentially reassembles the sentence's whole subject-predicate-object structure together with its attributive, adverbial, and complement modifiers. Without further ado, here is the code:

class DepEventExtractor:
    """
    【依存关系事件抽取】
    约定变量命名规则:
    dep_res: self.get_dep_res()的结果, 具有结构:
        [{'sent': xxx, 'sent_id': 1, 'dep': []}, {'sent': xxx, 'sent_id': 2, 'dep': []}]
        其中dep在_extract系列函数中, 记作dep_tree, 其中的元素的id, 即第几个token, 记作cur_root_id
    ---------------
    Upgrade(v2):
    1. 考虑被动语态
    2. 考虑弱动词占位,强动词后移的情况
    3. 考虑从句结构
    ---------------
    Upgrade(v3):
    1. 加入词性判断
    2. 加入时态词和否定词
    3. 增加无连词并列结构
    ---------------
    Upgrade(v4):
    1. 对状语进一步区分时间地点
    ---------------
    ver: 2021-12-07
    by: changhongyu
    """
    def __init__(self, stanza_model):
        """
        :param stanza_model: stanza.Pipeline: 实例化的stanza模型
        """
        self.stanza_model = stanza_model
        self.debug_pat_name = 'xxx'
        
    @staticmethod
    def get_depparse_res(doc_res):
        final_res = []
        for i, item in enumerate(doc_res.sentences):
            sent = item.text
            dep_array = []
            for word in item.dependencies:
                # print(word)
                if not word[2].head:
                    head_id = -1
                else:
                    head_id = word[2].head - 1
                dep_array.append({"head": word[0].text, "dep_relation": word[1], "words": word[2].text,
                                  "id": word[2].id - 1, "head_id": head_id, "upos": word[2].upos,
                                  "lemma": word[2].lemma, "ner": item.tokens[word[2].id - 1].ner})
            final_res.append({"sent": sent, "sent_id": i + 1, "dep": dep_array})
            
        return final_res
    
    def get_dep_res(self, text):
        """
        Get the dependency parsing result.
        """
        doc_res = self.stanza_model(text)
        dep_res = self.get_depparse_res(doc_res)
        
        return dep_res
        
    @staticmethod
    def _extract_trigger(dep_tree):
        """
        Extract the trigger: ROOT -root-> trigger
        """
        for i, dep in enumerate(dep_tree):
            if dep['head'] == 'ROOT':
                dep['token_start'] = i  # index of this token in the original sentence
                dep['token_end'] = i
                return i, dep
            
        return None, None
            
    def _extract_pattern(self, dep_tree, pattern_name, cur_root_id, reverse=False):
        """
        Extract a given dependency pattern.
        If there are several subjects, the later ones attach to the first nsubj via a conj dependency.
        :param reverse: bool: if True, extract in the reverse direction (match on id instead of head_id)
        """
        patterns = []
        if not reverse:
            id_to_extract = 'head_id'
        else:
            id_to_extract = 'id'
        # root_head = dep_tree[cur_root_id]['dep'].split()[0]  # take index 0 because the dep content may grow
            
        for i, dep in enumerate(dep_tree):
            if pattern_name.endswith('*'):
                case = dep['dep_relation'].startswith(pattern_name[:-1]) and dep[id_to_extract] == cur_root_id
            else:
                case = dep['dep_relation'] == pattern_name and dep[id_to_extract] == cur_root_id
                
            if pattern_name == self.debug_pat_name:
                # for debugging
                print(dep['dep_relation'], pattern_name)
                print(dep['head_id'], cur_root_id)
                print('---')
            # if dep['dep_relation'] == pattern_name and dep['head'] == root_head:
            if case:
                dep['token_start'] = i
                dep['token_end'] = i
                patterns.append([i, dep])
            
        return patterns
    
    def _fill_trigger(self, dep_tree, trigger_id, trigger):
        """
        Complete the trigger (merge its compound parts).
        """
        trigger_x_patterns = self._extract_pattern(dep_tree, 'compound*', trigger_id)
        if len(trigger_x_patterns) > 1:
            print("Warning: More than 1 nummod pattern occurred at trigger: ", trigger)
        for trigger_x_id, trigger_x in trigger_x_patterns:
            trigger['token_start'] = min(trigger['token_start'], trigger_x['token_start'])
            trigger['token_end'] = max(trigger['token_end'], trigger_x['token_end'])
            trigger['words'] = ''.join(tok + ' ' for tok in self.tokens[trigger['token_start']: trigger['token_end']+1])[: -1]
        # add auxiliary words
        trigger = self._fill_with_aux(dep_tree, trigger_id, trigger)
            
        return trigger
    
    def _fill_with_flat(self, dep_tree, node_id, node):
        """
        Complete flat multi-word structures.
        Calls: self._extract_pattern()
        Example: Donald --> Donald J. Trump
        """
        flat_patterns = self._extract_pattern(dep_tree, 'flat', node_id)
        for flat_pat_id, flat_pat in flat_patterns:
            node['words'] += ' '
            node['words'] += flat_pat['words']
            node['token_end'] = flat_pat_id
            
        return node
    
    def _fill_with_compound(self, dep_tree, node_id, node):
        """
        Complete noun compound structures, keeping only adjacent compounds.
        A compound can be a single head-modifier pair or a chain of compounds.
        Calls: self._extract_pattern()
        Example: cream --> ice cream
        """
        compound_patterns = self._extract_pattern(dep_tree, 'compound', node_id, reverse=False)
        for compound_pat_id, compound_pat in compound_patterns:
            # recursively complete the compound's own compounds
            compound_pat = self._fill_with_compound(dep_tree, compound_pat_id, compound_pat)
            node['token_start'] = min(node['token_start'], compound_pat['token_start'])
            node['token_end'] = max(node['token_end'], compound_pat['token_end'])
            node['words'] = ''.join(tok + ' ' for tok in self.tokens[node['token_start']: node['token_end']+1])[: -1]
        
        return node
    
    def _fill_with_amod(self, dep_tree, node_id, node):
        """
        Complete adjectival modifiers.
        Calls: self._extract_pattern()
        Example: apple --> big red apple
        """
        amod_patterns = self._extract_pattern(dep_tree, 'amod', node_id)
        for amod_pat_id, amod_pat in amod_patterns:
            # complete the modifier itself
            amod_pat = self._fill_a_node(dep_tree, amod_pat_id, amod_pat)
            node['token_start'] = min(node['token_start'], amod_pat['token_start'])
            node['token_end'] = max(node['token_end'], amod_pat['token_end'])
            node['words'] = ''.join(tok + ' ' for tok in self.tokens[node['token_start']: node['token_end']+1])[: -1]
            
        return node
    
    def _fill_with_nummod(self, dep_tree, node_id, node):
        """
        Complete numeric modifiers.
        A numeric modifier is adjacent to the node, is a single token, and does not co-occur with flat structures.
        Calls: self._extract_pattern()
        Example: dollars --> forty dollars
        """
        nummod_patterns = self._extract_pattern(dep_tree, 'nummod', node_id)
        if len(nummod_patterns) > 1:
            print("Warning: More than 1 nummod pattern occurred at node: ", node)
        for nummod_pat_id, nummod_pat in nummod_patterns:
            node['token_start'] = min(node['token_start'], nummod_pat['token_start'])
            node['token_end'] = max(node['token_end'], nummod_pat['token_end'])
            node['words'] = ''.join(tok + ' ' for tok in self.tokens[node['token_start']: node['token_end']+1])[: -1]
            
        return node
    
    def _fill_with_det(self, dep_tree, node_id, node):
        """
        Complete determiners.
        Determiners include articles and interrogative pronouns; they do not co-occur with numeric modifiers or flat structures.
        Calls: self._extract_pattern()
        Example: apple --> an apple
        """
        det_patterns = self._extract_pattern(dep_tree, 'det', node_id)
        if len(det_patterns) > 1:
            print("Warning: More than 1 det pattern occurred at node: ", node)
        for det_pat_id, det_pat in det_patterns:
            node['words'] = det_pat['words'] + ' ' + node['words']
            node['token_start'] = det_pat_id
            
        return node
    
    def _fill_with_nmod(self, dep_tree, node_id, node):
        """
        Complete nominal modifier (nmod) structures.
        Extend the span directly to the end of the modifier.
        Calls: self._extract_pattern()
               self._fill_a_node()
        Example: apple --> a couple of apples
        """
        nmod_patterns = self._extract_pattern(dep_tree, 'nmod*', node_id)
        if len(nmod_patterns) > 1:
            print("Warning: More than 1 nmod pattern occurred at node: ", node)
        for nmod_pat_id, nmod_pat in nmod_patterns:
            # 对修饰语补全
            nmod_pat = self._fill_a_node(dep_tree, nmod_pat_id, nmod_pat)
            node['token_start'] = min(node['token_start'], nmod_pat['token_start'])
            node['token_end'] = max(node['token_end'], nmod_pat['token_end'])
            node['words'] = ''.join(tok + ' ' for tok in self.tokens[node['token_start']: node['token_end']+1])[: -1]
            
        return node
    
    def _fill_with_case(self, dep_tree, node_id, node):
        """
        Add the case marker (preposition) to an adverbial.
        Example: last week --> during last week
        """
        case_patterns = self._extract_pattern(dep_tree, 'case', node_id)
        if len(case_patterns) > 1:
            print("Warning: More than 1 case pattern occurred at node: ", node)
        for case_pat_id, case_pat in case_patterns:
            node['words'] = case_pat['words'] + ' ' + node['words']
            node['token_start'] = case_pat_id
            
        return node
    
    def _fill_with_aux(self, dep_tree, node_id, node):
        """
        Add tense auxiliaries to a verb.
        Example: go --> will go
        """
        aux_patterns = self._extract_pattern(dep_tree, 'aux', node_id)
        if len(aux_patterns) > 1:
            print("Warning: More than 1 aux pattern occurred at node: ", node)
        for aux_pat_id, aux_pat in aux_patterns:
            node['words'] = aux_pat['words'] + ' ' + node['words']
            node['token_start'] = aux_pat_id
            
        return node
    
    def _fill_a_node(self, dep_tree, node_id, node, is_obl=False):
        """
        Complete a node (usually a noun node).
        Calls: self._fill_with_flat()
              self._fill_with_nummod()
              self._fill_with_det()
        """
        # complete flat multi-word structures
        node = self._fill_with_flat(dep_tree, node_id, node)
        # complete noun compounds
        node = self._fill_with_compound(dep_tree, node_id, node)
        # complete adjectival modifiers
        node = self._fill_with_amod(dep_tree, node_id, node)
        # complete numeric modifiers
        node = self._fill_with_nummod(dep_tree, node_id, node)
        # complete determiners
        node = self._fill_with_det(dep_tree, node_id, node)
        # complete nominal (nmod) modifiers
        node = self._fill_with_nmod(dep_tree, node_id, node)
        if is_obl:
            node = self._fill_with_case(dep_tree, node_id, node)
        
        return node
    
    def _get_conj_patterns(self, dep_tree, node_id):
        """
        Get the conjunct (conj) structures of a node.
        Calls: self._extract_pattern()
              self._fill_a_node()
        """
        conj_patterns = self._extract_pattern(dep_tree, 'conj', node_id)
        for conj_pat_id, conj_pat in conj_patterns:
            # complete the node
            conj_pat = self._fill_a_node(dep_tree, conj_pat_id, conj_pat)
        
        return conj_patterns
    
    def _adjust_obl_type_with_ner(self, node):
        """
        Use the NER result to refine the adverbial type into time / location / person adverbials.
        :param node: dict: {'head', 'dep_relation', 'words', 'id', 'head_id', 'upos', 'lemma', 'ner'}
        :return type_: str: 'obl', 'obl:tmod', 'obl:lmod', or 'obl:pmod'
        """
        if 'DATE' in node['ner']:
            return 'obl:tmod'
        elif 'GPE' in node['ner']:
            return 'obl:lmod'
        elif 'PERSON' in node['ner']:
            return 'obl:pmod'
        else:
            return 'obl'
    
    def _get_core_arguments(self, dep_tree, argument_type, trigger_id):
        """
        Get the core arguments of an event.
        Calls: self._extract_pattern()
               self._fill_a_node()
               self._get_conj_patterns()
        Notes: although a role may have several core arguments, only one of them is
               attached directly to the parent node in the dependency structure.
        ---------------
        :param dep_tree:
        :param argument_type: which argument to extract: nsubj, obj, or iobj
        :param trigger_id: position of this event's trigger
        :return argument: the core argument attached directly to the parent node
        :return conj_arguments: list of the other core arguments conjoined with argument
        """
        assert argument_type in ['nsubj', 'obj', 'iobj', 'nsubj:pass']
        arguments = self._extract_pattern(dep_tree, argument_type, trigger_id)
        if len(arguments):
            argument_id, argument = arguments[0]
            # complete the node
            argument = self._fill_a_node(dep_tree, argument_id, argument)
            
            # get conjunct structures
            conj_arguments = self._get_conj_patterns(dep_tree, argument_id)          
            return argument, conj_arguments
        
        return None, None
    
    def _get_none_core_arguments(self, dep_tree, argument_type, trigger_id):
        """
        Get non-core arguments.
        Several non-core arguments may coexist.
        """
        assert argument_type in ['obl', 'obl:tmod']
        arguments = self._extract_pattern(dep_tree, argument_type, trigger_id)
        all_arguments = []
        for argument_id, argument in arguments:
            # complete the node
            argument = self._fill_a_node(dep_tree, argument_id, argument, is_obl=True)
            
            # get conjunct structures
            conj_arguments = [item[1] for item in self._get_conj_patterns(dep_tree, argument_id)]
            
            all_arguments.append(argument)
            all_arguments.extend(conj_arguments)
        
        return all_arguments
    
    def _extract_event(self, dep_tree, trigger, trigger_id):
        """
        Extract one event given the trigger position.
        """
        event = dict()
        # nominal subject
        nsubj, conj_patterns_nsubj = self._get_core_arguments(dep_tree, 'nsubj', trigger_id)
        # direct object
        obj, conj_patterns_obj = self._get_core_arguments(dep_tree, 'obj', trigger_id)
        # indirect object
        iobj, conj_patterns_iobj = self._get_core_arguments(dep_tree, 'iobj', trigger_id)
        # general adverbials; time adverbials may also show up here
        obls = self._get_none_core_arguments(dep_tree, 'obl', trigger_id)
        # time adverbials
        oblts = self._get_none_core_arguments(dep_tree, 'obl:tmod', trigger_id)
        if not nsubj:
            # passive-voice subject
            nsubj_pass, conj_patterns_nsubj_pass = self._get_core_arguments(dep_tree, 'nsubj:pass', trigger_id)
        
        # assemble the event
        event['trigger'] = trigger
        if nsubj:
            event['nsubj'] = [nsubj] + [pat[1] for pat in conj_patterns_nsubj]
        elif nsubj_pass:
            event['nsubj:pass'] = [nsubj_pass] + [pat[1] for pat in conj_patterns_nsubj_pass]
        if obj:
            event['obj'] = [obj] + [pat[1] for pat in conj_patterns_obj]
        if iobj:
            event['iobj'] = [iobj] + [pat[1] for pat in conj_patterns_iobj]
        if obls:
            for obl in obls:
                type_ = self._adjust_obl_type_with_ner(obl)
                if type_ not in event:
                    event[type_] = [obl]
                else:
                    event[type_].append(obl)
        if oblts:
            if 'obl:tmod' in event:
                event['obl:tmod'].extend(oblts)
            else:
                event['obl:tmod'] = oblts
            
        return event
    
    def _extract_clausal_event(self, dep_tree, trigger, trigger_id, sent_id, clausal_type='ccomp'):
        """
        Extract clausal (complement) events.
        """
        clausal_events = []
        assert clausal_type in ['ccomp', 'xcomp']
        comp_patterns = self._extract_pattern(dep_tree, clausal_type, trigger_id)
        for comp_trigger_id, comp_trigger in comp_patterns:
            comp_trigger = self._fill_trigger(dep_tree, comp_trigger_id, comp_trigger)
            clausal_event = self._extract_event(dep_tree, comp_trigger, comp_trigger_id)
            clausal_event['head_trigger'] = trigger
            clausal_event['sent_id'] = sent_id
            
            clausal_events.append(clausal_event)
        
        return clausal_events
    
    def __call__(self, text, debug=None, return_dict=False, del_useless=True):
        """
        Get the list of events.
        :param text: str
        :param debug: str: dependency relation type to print debug info for
        :param return_dict: bool: whether to return the raw (sents, sent_tokens, event_list) tuple instead of an EventResult
        :param del_useless: bool: whether to delete useless keys
        :return event_list: list of dict
        """
        if debug:
            self.debug_pat_name = debug
        event_list = []
        dep_res = self.get_dep_res(text)
        sents = [dep_tree['sent'] for dep_tree in dep_res]
        sent_tokens = []
        
        # extract from each sentence; ignoring clauses, each sentence has one trigger and yields one event
        for sent_id, dep_tree in enumerate(dep_res):
            dep_tree = dep_tree['dep']
            self.tokens = [node['words'] for node in dep_tree]
            sent_tokens.append(self.tokens)
            
            # take the main verb as the trigger
            trigger_id, trigger = self._extract_trigger(dep_tree)
            if not trigger:
                continue
            trigger['tri_dep_id'] = trigger_id
            trigger = self._fill_trigger(dep_tree, trigger_id, trigger)
            event = self._extract_event(dep_tree, trigger, trigger_id)
            event['sent_id'] = sent_id
            
            # find clausal events
            clausal_comp = self._extract_clausal_event(dep_tree, trigger, trigger_id, sent_id, 'ccomp')
            open_clausal_comp = self._extract_clausal_event(dep_tree, trigger, trigger_id, sent_id, 'xcomp')
            if len(clausal_comp):
                event['clausal_comp'] = clausal_comp
            if len(open_clausal_comp):
                event['open_clausal_comp'] = open_clausal_comp
                
            event_list.append(event)
            
            # find events conjoined with the trigger
            conj_patterns = self._extract_pattern(dep_tree, 'conj', trigger_id)
            # parataxis: juxtaposed events without a conjunction
            parataxis_patterns = self._extract_pattern(dep_tree, 'parataxis', trigger_id)
            conj_patterns.extend(parataxis_patterns)
            for conj_trigger_id, conj_trigger in conj_patterns:
                event = self._extract_event(dep_tree, conj_trigger, conj_trigger_id)
                event['sent_id'] = sent_id
                # find clausal events of the conjoined event
                clausal_comp_conj = self._extract_clausal_event(dep_tree, conj_trigger, conj_trigger_id, sent_id, 'ccomp')
                open_clausal_comp_conj = self._extract_clausal_event(dep_tree, conj_trigger, conj_trigger_id, sent_id, 'xcomp')
                if len(clausal_comp_conj):
                    event['clausal_comp'] = clausal_comp_conj
                if len(open_clausal_comp_conj):
                    event['open_clausal_comp'] = open_clausal_comp_conj
                    
                event_list.append(event)
                
        if del_useless:
            event_list = self.del_useless_key(event_list)
        event_list = [event for event in event_list if len(event) > 2]
        
        if return_dict:
            return sents, sent_tokens, event_list
        
        else:
            return EventResult(sents, sent_tokens, event_list, dep_res)
                
    def del_useless_key(self, event_list):
        """
        Delete irrelevant key-value pairs.
        """
        for event in event_list:
            event['trigger'].pop('head')
            event['trigger'].pop('dep_relation')
            event['trigger'].pop('id')
            for k in event:
                if k == 'clausal_comp':
                    event[k] = self.del_useless_key(event[k])
                if k == 'sent_id':
                    continue
                elif k == 'trigger':
                    event[k].pop('head_id')
                    event[k].pop('ner')
                    continue
                elif k == 'head_trigger':
                    continue
                for node in event[k]:
                    for useless_key in ['head', 'dep_relation', 'id', 'lemma', 'upos', 'ner']:  # 'head_id'
                        if useless_key in node:
                            node.pop(useless_key)
        
        return event_list

Before instantiating this class, we first need to instantiate a stanza model. If stanza is not installed, install it first; this post will not cover the details.

import stanza
nlp = stanza.Pipeline(lang='en', dir='Stanza_En_v1/Stanza_En_model/models', use_gpu=True)
# the `dir` path here is the directory of the downloaded stanza models
# `use_gpu` controls whether to use a GPU; stanza's own wrapper does not let you choose which card, feel free to ask me if you want to pin a specific one
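
Since stanza's Pipeline only exposes a `use_gpu` switch, a common workaround for pinning the pipeline to one particular card is to restrict the visible CUDA devices before the pipeline is built. The snippet below is only a minimal sketch of that idea (it reuses the same placeholder model path as above and arbitrarily picks GPU 1) and is not part of the original code:

import os
# make only GPU 1 visible to this process; this must happen before any CUDA context is created
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import stanza
# stanza.download('en')  # fetch the English models first if they have not been downloaded yet
nlp = stanza.Pipeline(lang='en', dir='Stanza_En_v1/Stanza_En_model/models', use_gpu=True)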

To display the event extraction results more clearly, I defined a result class:

class EventResult:
    """
    [Event extraction result display]
    """
    def __init__(self, sents, sent_tokens, event_list, dep_res):
        self.sents = sents
        self.sent_tokens = sent_tokens
        self.dict = event_list
        self.dep_res = dep_res
        self.label2color = {
            'nsubj': '\033[1;34m',
            'nsubj:pass': '\033[1;34m',
            'trigger': '\033[1;35m',
            'obj': '\033[1;33m',
            'iobj': '\033[1;33m',
            'obl': '\033[1;32m',
            'obl:tmod': '\033[1;32m',
        }
        self.label2color4 = {
            'nsubj': '\033[4;34m',
            'nsubj:pass': '\033[4;34m',
            'trigger': '\033[4;35m',
            'obj': '\033[4;33m',
            'iobj': '\033[4;33m',
            'obl': '\033[4;32m',
            'obl:tmod': '\033[4;32m',
        }
        
    def to_dict(self):
        return self.dict
    
    def __len__(self):
        return len(self.dict)
    
    def _to_triples(self, event, is_clausal=False):
        argument_triples = []
        for k in event:
            if k == 'sent_id':
                continue
            elif k == 'head_trigger':
                continue
            elif k == 'clausal_comp' or k == 'open_clausal_comp':
                for clausal in event[k]:
                    argument_triples += self._to_triples(clausal, is_clausal=True)
            elif k == 'trigger':
                argument_triples.append([k, event[k]['token_start'], event[k]['token_end'], event['sent_id'], is_clausal])
            else:
                try:
                    for arg in event[k]:
                        if type(arg) == str:
                            continue
                        try:
                            argument_triples.append([k, arg['token_start'], arg['token_end'], event['sent_id'], is_clausal])
                        except:
                            print(arg)
                except:
                    print(event)
                    continue
            
        return sorted(argument_triples, key=lambda x: x[1])
        
    def __repr__(self):
        res = ''
        all_triples = []
        for event in self.dict:
            triples = self._to_triples(event)
            all_triples += triples
            
        pointer = 0
        if len(all_triples):
            prev_sent_id = all_triples[0][3]
            
        for triple in all_triples:
            if triple[0] not in self.label2color:
                continue
            if not triple[4]:
                # main clause
                color = self.label2color[triple[0]]
            else:
                color = self.label2color4[triple[0]]
            # flush the rest of the previous sentence
            cur_sent_id = triple[3]
            if cur_sent_id != prev_sent_id:
                # sentences skipped in between
                # for i in range(prev_sent_id+1, cur_sent_id):
                #     res += self.sents[i]
                #     res += '\n\n'
                for i in range(pointer, len(self.sent_tokens[prev_sent_id])):
                    res += self.sent_tokens[prev_sent_id][i]
                    res += ' '
                res += '\n\n'
                # restart the token pointer for the new sentence
                pointer = 0
            # from the end of the previous argument to the start of the current one
            for i in range(pointer, triple[1]):
                res += (self.sent_tokens[triple[3]][i])
                res += ' '
            res += (color)
            # the current argument
            for i in range(triple[1], triple[2]+1):
                res += (self.sent_tokens[triple[3]][i])
                res += ' '
            res += ('\033[0m')
            pointer = triple[2]+1
            prev_sent_id = cur_sent_id
        
        # complete everything after the last event
        for i in range(pointer, len(self.sent_tokens[-1])):
            res += (self.sent_tokens[-1][i])
            res += ' '
            
        return res

(This class seems to have a small bug when rendering the output; I never got around to fixing it.)

Then we can use this stanza model to instantiate our event extraction tool:

dee = DepEventExtractor(nlp)

Now we can extract events:

text = '"I will make America great again!", he said.'
dee_res = dee(text, del_useless=False)
print(dee(text))

Here is what the rendered result looks like:
[Example: screenshot of the color-highlighted console output]
If you want structured data, you can simply call to_dict() on the result:

dee_res.to_dict()

'''
[{'trigger': {'head': 'ROOT',
   'dep_relation': 'root',
   'words': 'said',
   'id': 11,
   'head_id': -1,
   'upos': 'VERB',
   'lemma': 'say',
   'ner': 'O',
   'token_start': 11,
   'token_end': 11,
   'tri_dep_id': 11},
  'nsubj': [{'head': 'said',
    'dep_relation': 'nsubj',
    'words': 'he',
    'id': 10,
    'head_id': 11,
    'upos': 'PRON',
    'lemma': 'he',
    'ner': 'O',
    'token_start': 10,
    'token_end': 10}],
  'sent_id': 0,
  'clausal_comp': [{'trigger': {'head': 'said',
     'dep_relation': 'ccomp',
     'words': 'will make',
     'id': 3,
     'head_id': 11,
     'upos': 'VERB',
     'lemma': 'make',
     'ner': 'O',
     'token_start': 2,
     'token_end': 3},
    'nsubj': [{'head': 'make',
      'dep_relation': 'nsubj',
      'words': 'I',
      'id': 1,
      'head_id': 3,
      'upos': 'PRON',
      'lemma': 'I',
      'ner': 'O',
      'token_start': 1,
      'token_end': 1}],
    'obj': [{'head': 'make',
      'dep_relation': 'obj',
      'words': 'America',
      'id': 4,
      'head_id': 3,
      'upos': 'PROPN',
      'lemma': 'America',
      'ner': 'S-GPE',
      'token_start': 4,
      'token_end': 4}],
    'head_trigger': {'head': 'ROOT',
     'dep_relation': 'root',
     'words': 'said',
     'id': 11,
     'head_id': -1,
     'upos': 'VERB',
     'lemma': 'say',
     'ner': 'O',
     'token_start': 11,
     'token_end': 11,
     'tri_dep_id': 11},
    'sent_id': 0,
    'words': 'I will make America great again ! " ,',
    'token_start': 1,
    'token_end': 9}],
  'grown_pattern': [{'trigger': {'head': 'said',
     'dep_relation': 'ccomp',
     'words': 'will make',
     'id': 3,
     'head_id': 11,
     'upos': 'VERB',
     'lemma': 'make',
     'ner': 'O',
     'token_start': 2,
     'token_end': 3},
    'nsubj': [{'head': 'make',
      'dep_relation': 'nsubj',
      'words': 'I',
      'id': 1,
      'head_id': 3,
      'upos': 'PRON',
      'lemma': 'I',
      'ner': 'O',
      'token_start': 1,
      'token_end': 1}],
    'obj': [{'head': 'make',
      'dep_relation': 'obj',
      'words': 'America',
      'id': 4,
      'head_id': 3,
      'upos': 'PROPN',
      'lemma': 'America',
      'ner': 'S-GPE',
      'token_start': 4,
      'token_end': 4}],
    'head_trigger': {'head': 'ROOT',
     'dep_relation': 'root',
     'words': 'said',
     'id': 11,
     'head_id': -1,
     'upos': 'VERB',
     'lemma': 'say',
     'ner': 'O',
     'token_start': 11,
     'token_end': 11,
     'tri_dep_id': 11},
    'sent_id': 0,
    'words': 'I will make America great again ! " ,',
    'token_start': 1,
    'token_end': 9}]}]
'''
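
Coming back to the silver-data idea from the introduction: a result dict like the one above can be flattened into token-span annotations and handed to annotators for correction. The helper below is just a minimal sketch of mine (not part of the original tool), assuming the to_dict() format and the `dee_res` object shown above:

def events_to_silver_spans(dee_res):
    """Flatten the extractor output into token-span records (silver data) for later human review."""
    records = []
    for event in dee_res.to_dict():
        sent_id = event['sent_id']
        tokens = dee_res.sent_tokens[sent_id]
        for role, value in event.items():
            # skip bookkeeping keys and nested clausal events
            if role in ('sent_id', 'clausal_comp', 'open_clausal_comp', 'head_trigger', 'grown_pattern'):
                continue
            nodes = [value] if isinstance(value, dict) else value
            if not isinstance(nodes, list):
                continue
            for node in nodes:
                if not isinstance(node, dict) or 'token_start' not in node:
                    continue
                records.append({
                    'sent_id': sent_id,
                    'role': role,
                    'token_start': node['token_start'],
                    'token_end': node['token_end'],
                    'text': ' '.join(tokens[node['token_start']: node['token_end'] + 1]),
                })
    return records

# silver = events_to_silver_spans(dee_res)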

2. Extracting events of a specific type

If you only want to extract certain types of events, you can do so by specifying trigger words.

First, write an event type class:

class EventType:
    """
    [Event type]
    """
    def __init__(self, event_name, domain_name, trigger_words):
        """
        :param event_name: str: name of the event type
        :param domain_name: str: name of the event domain
        :param trigger_words: list: list of trigger words for this event
        """
        self.event_name = event_name
        self.domain_name = domain_name
        self.trigger_words = trigger_words
        
    def add_trigger(self, trigger):
        self.trigger_words.append(trigger)
        
    def remove_trigger(self, trigger):
        if trigger not in self.trigger_words:
            print('Trigger {} not in `trigger_words`.'.format(trigger))
        else:
            self.trigger_words.remove(trigger)
            
    def get_trigger_lemmas(self):
        pass

Taking speech events as an example, instantiate this class with a few trigger words:

speech_words = ['say', 'speak', 'talk', 'stress', 'report', 'tell', 'ask', 'warn', 'note', 'state', 'deny', 
                'explain', 'insist', 'suggest', 'acknowledge', 'believe', 'confirm', 'think', 'vow']
speech_event = EventType(event_name='speech', domain_name='common', trigger_words=speech_words)
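
A side note: `get_trigger_lemmas` is left as a stub in the class above, and `get_spcf_events` below matches events on the trigger's lemma, so the trigger words you supply should already be lemmas. If you want to pass in arbitrary inflected forms, one possible helper could lemmatize them with the same stanza pipeline; the function below is purely my own sketch and not part of the original code:

def get_trigger_lemmas(trigger_words, stanza_model):
    """Map each trigger word form to its lemma using the stanza pipeline."""
    lemmas = set()
    for word in trigger_words:
        doc = stanza_model(word)
        for sent in doc.sentences:
            for w in sent.words:
                lemmas.add(w.lemma)
    return sorted(lemmas)

# e.g. replace the configured word list with its lemmatized version:
# speech_event.trigger_words = get_trigger_lemmas(speech_event.trigger_words, nlp)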

Then write an extraction class for the specified event type:

class SpecificEventExtractor:
    """
    [Specific event type extraction]
    Extract events of a specific type from the dependency-based extraction result.
    Intended to be used after the dep-based extraction has finished.
    Example:
        # (1) Create a dep model and get the dep result
        dee = DepEventExtractor(nlp)
        text = "China's defense minister says resolving 'Taiwan question' is national priority"
        dee_res = dee(text, del_useless=False)
        # (2) Define a specific event type, for example, speech
        speech_event = EventType(event_name='speech', domain_name='common', trigger_words=speech_words)
        # (3) Create a specific extractor and get all events of this type
        see = SpecificEventExtractor(speech_event, dee_res)
        spcf_events = see.get_spcf_events()
        spcf_events_exp = see.get_expanded_events()
        # (4) Show the result
        print(spcf_events_exp.to_dict()[0]['grown_pattern'][0]['words'])
    ---------------
    ver: 2021-11-16
    by: changhongyu
    """
    def __init__(self, event_type, dee_result):
        """
        :param event_type: EventType: the event type object
        :param dee_result: EventResult: the event extraction result
        """
        self.event_type = event_type
        self.dee_result = dee_result
        
    def get_spcf_events(self):
        """
        Get events of the specified type.
        :return: EventResult wrapping the filtered event list
        """
        spcf_event_list = []
        for event in self.dee_result.to_dict():
            if event['trigger']['lemma'] in self.event_type.trigger_words:
                spcf_event_list.append(event)
                    
        return EventResult(self.dee_result.sents, self.dee_result.sent_tokens, spcf_event_list, self.dee_result.dep_res)
    
    def get_expanded_events(self):
        """
        Expand the event object, using the event trigger as the root node.
        """
        spcf_events = self.get_spcf_events()
        for event in spcf_events.to_dict():
            if 'clausal_comp' not in event:
                continue
            event['grown_pattern'] = [self._node_grow(ccomp, ccomp['trigger']['id'], event['sent_id']) \
                            for ccomp in event['clausal_comp']]
            # print(event['trigger'])
            
        return spcf_events
            
    def _node_grow(self, node, node_id, sent_id):
        """
        Grow a node to cover all of its child nodes.
        """
        if 'words' not in node:
            node['words'] = node['trigger']['words']
        if 'token_start' not in node:
            node['token_start'] = node['trigger']['token_start']
        if 'token_end' not in node:
            node['token_end'] = node['trigger']['token_end']
            
        cur_sent = self.dee_result.sents[sent_id]
        for child_node in self.dee_result.dep_res[sent_id]['dep']:
            if child_node['head_id'] == node_id and child_node['words'] != node['words']:
                
                if 'token_start' not in child_node:
                    child_node['token_start'] = child_node['id']
                if 'token_end' not in child_node:
                    child_node['token_end'] = child_node['id']
                
                child_node = self._node_grow(child_node, child_node['id'], sent_id)
                
                node['token_start'] = min(node['token_start'], child_node['token_start'])
                node['token_end'] = max(node['token_end'], child_node['token_end'])
               
                node['words'] = ''.join(tok + ' ' for tok in self.dee_result.sent_tokens[sent_id][node['token_start']: node['token_end']+1])[: -1]
                
        return node

When using this class, you pass in the extraction result produced by the class from Section 1, and it filters that result:

text = "China's defense minister says resolving 'Taiwan question' is national priority"
dee_res = dee(text, del_useless=False)
print(dee(text))   # the rendered display here may be off; the final to_dict() output is what counts
see = SpecificEventExtractor(speech_event, dee_res)
spcf_events = see.get_expanded_events()

# inspect the content of the speech:
spcf_events.to_dict()[0]['grown_pattern'][0]['words']
# "resolving ' Taiwan question ' is national priority"

3. Conclusion

This post introduced a syntax-rule-based, weakly supervised event extraction tool that I implemented myself. It can be used to generate some training data, but I do not recommend using it directly as a mature extraction model. If you have any questions about how to use it, leave me a comment, and if you want to build on it, feel free to adapt it.

That is it for this post. I have plenty more original write-ups planned for future blog posts, so if you are looking forward to them, please support me with likes and coins.


Reposted from blog.csdn.net/weixin_44826203/article/details/126653155