NLP Tools - A Multi-Model Voter for NER Tasks

0. Introduction

I have recently been working on Named Entity Recognition (NER) tasks. My model is an ensemble, so the predictions of several models have to be fused: I need a way to vote on the results predicted by multiple models to get the final result. Because the task is flat NER, the voting process must also avoid producing overlapping entities.

To do this, I wrote a voter class, and I am recording it here so that I can reuse it when needed in the future.

1. Data format

Assume that the predictions of all k models are stored in a list named result. The length of result is k, and each element is a dict that records the predictions of one model. Each key of the dict is a category name, and the value is the list of entities detected for that category, each given as a [start, end] pair.

result = [{'类别1': [],
  '类别2': [],
  '类别3': [[25, 31]],
  '类别4': [[118, 123]],
  '类别5': [[70, 71], [94, 99]],
  '类别6': []},
 {'类别1': [[182, 183]],
  '类别2': [],
  '类别3': [[25, 31], [44, 52], [79, 92]],
  '类别4': [[118, 123]],
  '类别5': [[70, 71], [94, 99]],
  '类别6': []},
  ……
 {'类别1': [],
  '类别2': [],
  '类别3': [[25, 31], [44, 52]],
  '类别4': [[118, 123]],
  '类别5': [[44, 52], [70, 71], [96, 99]],
  '类别6': []}]
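
To make the structure concrete, here is a minimal sketch of walking such a list. The helper name iter_entities is my own and is not part of the voter class below; the tiny result list is a shortened version of the example above.

```python
def iter_entities(result):
    """Yield (model_id, category, start, end) for every predicted entity."""
    for model_id, model_res in enumerate(result):
        for category, spans in model_res.items():
            for start, end in spans:
                yield model_id, category, start, end


# Shortened version of the example data: two models, two categories.
result = [
    {'类别3': [[25, 31]], '类别4': [[118, 123]]},
    {'类别3': [[25, 31], [44, 52]], '类别4': []},
]

entities = list(iter_entities(result))
print(len(entities))  # 4
```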

2. Voting rules

First, let's review how voting works in a bagging ensemble for ordinary classification tasks. The simplest rule is majority vote. For example, if 8 out of 10 models predict class A and 2 predict class B, the final result is class A. In NER, however, entities are spans, so a simple vote cannot label an entity directly: the models may agree that there is an entity near a certain position, yet we still have to decide ① where the entity starts and ends, and ② which category the entity belongs to.
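
As a reference point, the plain majority vote for classification can be sketched in a few lines. The function name majority_vote is only illustrative.

```python
from collections import Counter


def majority_vote(labels):
    """Return the label predicted by the most models.

    If several labels tie, the one that was seen first wins.
    """
    return Counter(labels).most_common(1)[0][0]


# 8 of 10 models say "A", 2 say "B".
votes = ["A"] * 8 + ["B"] * 2
print(majority_vote(votes))  # A
```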

For example, in a certain sentence, Model 1 recognizes "Pink Starfish Pai Daxing" as a person, Model 2 recognizes "Sea Star Pai Daxing" as a person, and Model 3 recognizes "Pink Starfish" as a person. How should the final voting result be determined?
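
A small sketch shows why voting on whole spans does not help in this case. The character offsets are made up for illustration; only their relative layout (shared start, shared end) follows the example.

```python
from collections import Counter

# Hypothetical (start, end) offsets for the three predictions above.
model_spans = [
    (10, 16),  # Model 1: the longest mention
    (12, 16),  # Model 2: same end, later start
    (10, 13),  # Model 3: same start, earlier end
]

span_votes = Counter(model_spans)
start_votes = Counter(start for start, _ in model_spans)
end_votes = Counter(end for _, end in model_spans)

print(max(span_votes.values()))        # 1: no exact span gets more than one vote
print(start_votes[10], end_votes[16])  # 2 2: the boundaries do have a majority
```

Voting on exact spans gives a three-way tie, while the individual start and end positions do show agreement. This is the motivation for counting boundaries instead of spans.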

So I designed a voting rule. It may still have unreasonable aspects, but it outputs a logically consistent and more reliable result.

Rules & process:
1. Initialization: read the results of all models, traverse every identified entity (regardless of type), and record all start and end positions to build an initial counting 'dictionary'. Its key is a position, and its value is the number of times that position appears as a start or an end. Since a Python dict cannot change size while it is being iterated over, a list is used to simulate the 'dictionary': the list index plays the role of the key, and mappings between position and index are built alongside it.
2. Counting: read the results again and, for every position in the initialized 'dictionary', record how many times it appears as the start and as the end of an entity of each type across all models, then fill in the value of the 'dictionary'. (I later changed the plain count to a weighted one, where the weight of each model is its F1 score.) At this point each position in the 'dictionary' corresponds to a p*2 array, where p is the number of entity categories.
3. Find the most significant position: in the counting 'dictionary' generated above, look for the entry with the highest count. If it exceeds the 'significance threshold', match it with the corresponding start or end: if the most significant position is a start, search to the right for the end of this entity; if it is an end, search to the left for the start of this entity. Once the most significant position has been found, its value in the counting 'dictionary' is set to 0.
4. Match the most significant position: taking the rightward search for an end as the example, the matching position must satisfy: (1) the generated span does not overlap any existing span; (2) it is the most significant candidate of that type (chosen in the same way as in step 3); (3) its count satisfies the 'significance threshold'. After a successful match, the count of the matching position in the counting 'dictionary' is set to 0, and the newly generated entity span is added to the existing spans.
5. Loop: repeat steps 3 and 4, searching the remaining positions for the most significant one and matching an entity for it, until the highest remaining count falls below the significance threshold; then exit the loop.
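
Steps 1 to 3 and the overlap test of step 4 can be sketched compactly as follows. This is a simplified illustration, not the implementation given below: it uses a plain dict instead of the list-plus-mapping trick, the function names and the "PER" category are my own, and it assumes that both ends of a span are inclusive.

```python
import numpy as np


def count_boundaries(results, categories, weights):
    """Return {position: array of shape (len(categories), 2)} of weighted votes.

    Column 0 counts the position as a start, column 1 as an end.
    """
    row = {name: i for i, name in enumerate(categories)}
    counts = {}
    for model_res, weight in zip(results, weights):
        for name, spans in model_res.items():
            if name not in row:
                continue  # category not tracked
            for start, end in spans:
                for pos, col in ((start, 0), (end, 1)):
                    cell = counts.setdefault(pos, np.zeros((len(categories), 2)))
                    cell[row[name], col] += weight
    return counts


def most_significant(counts):
    """Return (position, category_row, start_or_end, count) of the largest cell."""
    pos = max(counts, key=lambda p: counts[p].max())
    r, c = np.unravel_index(np.argmax(counts[pos]), counts[pos].shape)
    return pos, int(r), int(c), float(counts[pos][r, c])


def overlaps(span, accepted):
    """True if the closed interval `span` shares a position with any accepted span."""
    start, end = span
    return any(start <= e and s <= end for s, e in accepted)


# Three models with equal weight and hypothetical offsets.
results = [
    {"PER": [[10, 16]]},
    {"PER": [[12, 16]]},
    {"PER": [[10, 13]]},
]
counts = count_boundaries(results, ["PER"], weights=[1.0, 1.0, 1.0])
print(most_significant(counts))        # (10, 0, 0, 2.0)
print(overlaps([10, 16], [[16, 20]]))  # True
```

Here position 10 wins as a start with a weighted count of 2.0, so the next step would search to the right for an end that does not create an overlap.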

3. Code implementation

import numpy as np
import copy

class Voter():
    def __init__(self, threshold, results):
        self.threshold = threshold   # significance threshold
        self.results = results       # predictions of all models
        self.spans = []              # spans of the entities accepted so far
    
    
    def predicate2id(self, predicate):
        pr2id = {'类别1':0, '类别2':1, '类别3':2, '类别4':3}
        return pr2id[predicate]
    
    
    def id2predicate(self, id):
        id2pr = {0:'类别1', 1:'类别2', 2:'类别3', 3:'类别4'}
        return id2pr[id]
    
    
    def model_point(self, model_id):
        '''
        F1 score of every model, used as its voting weight. Remember to update these values.
        '''
        point = [0.6153846153847338, 0.6177606177607161, 0.6169014084508121, 0.5877318116976925, 0.573333333333447,
                0.6627043090639932, 0.630225080385971, 0.6635514018692636, 0.6210720887247242]
        return point[model_id]
    
    def sub_of(self, sub_inter, inter):
        '''
        Helper: check whether one interval is a sub-interval of another
        '''
        a1, a2 = sub_inter[0], sub_inter[1]
        # print(a1)
        # print(a2)
        if a1 > a2:
            return False
        if len(inter):
            b1, b2 = inter[0], inter[1]
            assert b1 < b2
        else:
            b1, b2 = 0, 0
        if a1 >= b1 and a2 <= b2:
            return True
        else:
            return False
        
        
    def find_all_spans_by_cls(self, cls):
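        '''
        Helper: collect the distinct spans that any model predicted for the given category
        '''
        all_spans_by_cls = []
        results = self.results
        for result in results:   # for each model's result
            for span in result[self.id2predicate(cls)]:   # for every span of this category in the current model's result
                if span not in all_spans_by_cls:   # only keep spans that have not been collected yet
                    all_spans_by_cls.append(span)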
        return all_spans_by_cls
    
    
    def generate_init(self):
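        '''
        Build the initial counting "dictionary".
        A dict cannot change size while it is being iterated over,
        so the counts are stored in a list instead,
        and a mapping from position to list index is built to play the role of the dict key.
        '''
        count_dict = []
        key2index = {}   # once built, these two mappings are never modified again
        index2key = {}
        i = 0
        for model_res in self.results:
            for key in model_res:   # for each category
                # print(model_res[key])   # entities of this category
                for v in model_res[key]:    # each entity of this category
                    # print(v)
                    for vv in v:             # the start and end of this entity
                        # print(vv)
                        # print(count_dict)
                        if str(vv) not in key2index.keys():
                            key2index[str(vv)] = i
                            index2key[i] = str(vv)
                            count_dict.append(np.zeros((4,2)))   # rows: categories, columns: start / end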
                            i += 1
        return count_dict, key2index, index2key
        

    def fill_count(self):
        '''
        Accumulate the weighted count of every position
        '''
        count_dict, key2index, index2key = self.generate_init()
        for model_id, model_res in enumerate(self.results):
            for key in model_res:
                try:
                    cls = self.predicate2id(key)   # row index of this category
                except KeyError:
                    continue                       # category not handled by this voter
                weight = self.model_point(model_id)
                for v in model_res[key]:  # v is the [start, end] list of one entity
                    count_dict[key2index[str(v[0])]][cls][0] += weight  # column 0: v[0] seen as the start of this category
                    count_dict[key2index[str(v[1])]][cls][1] += weight  # column 1: v[1] seen as the end of this category
        return count_dict
    
    
    def search_first(self, count_dict, key2index, index2key):
        '''
        Find the position with the largest count in count_dict.
        Return whether it is a start or an end, its category id, and its count,
        and reset that position in count_dict to 0.
        '''
        print('searching first...')
        
        max_pos = 0   # position holding the current maximum count
        max_count = 0  # current maximum count

        for i in range(len(count_dict)):
            pos = index2key[i]
            cur_count = np.max(count_dict[i])
            if cur_count > max_count:
                mx = np.where(count_dict[i] == cur_count)
                cls = int(mx[0])        # category id (raises if several cells tie for the maximum)
                se = int(mx[1])         # 0 = start, 1 = end
                max_pos = pos
                max_count = cur_count
                
        print('got max_pos: %s' %max_pos)
        print('current max_count is %s' % max_count)
        # print('remove pos: %s' %max_pos)
        count_dict[key2index[max_pos]] = np.zeros((4,2))  # reset this position to 0
        return se, cls, int(max_pos), count_dict, max_count

    
    def search_backward(self, cls, base_pos, count_dict, spans, key2index, index2key):
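        '''
        Called when search_first found an end (se == 1): search backward for the matching start
        cls: the category id found by search_first
        base_pos: the reference position
        Returns: the best matching position (-1 if none), plus the updated count_dict and spans
        '''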
        print('----------')
        print('searching backward...')
        max_pos = -1
        max_count = 0
        base_pos = int(base_pos)
        print('match for pos: %s' %base_pos)
        # print(spans)
        span_to_append = []
        
        for i in range(len(count_dict)):
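            '''
            Rules:
            1. The candidate lies before base
            2. The candidate is in the set of potential positions (already satisfied)
            3. Every point between the candidate and base is covered by an entity predicted by at least one model
            4. The candidate lies after the end of the previous existing span (only checked when the current span is not the first one)
            '''
            pos = index2key[i]
            
            # tmp_span is used to locate base among the existing spans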
            tmp_span = copy.copy(spans)
            if [base_pos, base_pos] not in tmp_span:
                tmp_span.append([base_pos, base_pos])
            # print([base_pos,base_pos])
            # print(tmp_span)
            tmp_span.sort()
            
            # check rule 3
            all_spans_by_cls = self.find_all_spans_by_cls(cls)
            prncp3 = False
            for span in all_spans_by_cls:     # for each predicted entity of this category, check whether the candidate interval is contained in it
                prncp3 = prncp3 or self.sub_of([int(pos), base_pos], span)
             
            if len(spans):       # spans is already non-empty
                # print('span right before base_pos in tmp_span: %s' %(tmp_span[tmp_span.index([base_pos, base_pos])-1]))
                
                if tmp_span.index([base_pos, base_pos]) == 0:
                    # base_pos is already the first one in tmp_span, so any earlier position may be chosen
                    if int(pos) < base_pos and prncp3:
                        cur_count = count_dict[i][cls][0]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
                elif tmp_span.index([base_pos, base_pos]) > 0:
                    # base is not the first one in tmp_span, so the match must lie after the previous span (prncp4)
                    prncp4 = tmp_span[tmp_span.index([base_pos, base_pos])-1][1] < int(pos)
                    if int(pos) < base_pos and prncp3 and prncp4:   # search to the left, staying outside the existing spans
                        cur_count = count_dict[i][cls][0]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
            else:                                   # spans is still empty, so there is no existing span to check against
                if int(pos) < base_pos and prncp3:
                    cur_count = count_dict[i][cls][0]
                    if cur_count > max_count:
                        max_count = cur_count
                        max_pos = int(pos)
                        # print(max_pos)
        if max_pos >= 0:
            print('got max_pos at %s' % max_pos)
            count_dict[key2index[str(max_pos)]] = np.zeros((4,2))   # reset to 0
            # print('remove pos: %s' % max_pos)
            span_to_append = [max_pos, base_pos]   # the span to be appended
            # print(span_to_append)
                    
        if span_to_append not in spans and len(span_to_append):
            print('doing backward append...')
            if len(spans):
                spans.sort()
                for span in spans:
                    if span[0] == span_to_append[1]+1 and span != span_to_append:      # adjacent to the next span
                        span_to_append = [span_to_append[0], span[1]]   # keep only the outer start and end
                        spans.append(span_to_append)
                        spans.remove(span)                             # drop the original span
                    elif span[1] == span_to_append[0]-1 and span != span_to_append:     # adjacent to the previous span
                        span_to_append = [span[0], span_to_append[1]]    # keep only the outer start and end
                        spans.append(span_to_append)
                        spans.remove(span)
                    else:
                        if span != span_to_append:
                            spans.append(span_to_append)                  # not adjacent to anything: append directly
            elif len(spans) == 0:
                spans.append(span_to_append)
        # print('spans after searched backward: %s' % spans)
        return int(max_pos), count_dict, spans

    
    def search_forward(self, cls, base_pos, count_dict, spans, key2index, index2key):
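        '''
        Called when search_first found a start (se == 0): search forward for the matching end
        cls: the category id found by search_first
        base_pos: the reference position
        Returns: the best matching position (-1 if none), plus the updated count_dict and spans
        '''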
        print('----------')
        print('searching forward...')
        max_pos = -1
        max_count = 0
        base_pos = int(base_pos)
        # print(spans)
        print('match for pos: %s' %base_pos)
        span_to_append = []
        
        for i in range(len(count_dict)):
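            '''
            Rules:
            1. The candidate lies after base
            2. The candidate is in the set of potential positions (already satisfied)
            3. Every point between the candidate and base is covered by an entity predicted by at least one model
            4. The candidate lies before the start of the next existing span (only checked when the current span is not the last one)
            '''
            pos = index2key[i]  # a potential position (stored as str); the loop visits every one of them
            
            tmp_span = copy.copy(spans)      # copy spans and add the current position to find its neighbouring span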
            if [base_pos, base_pos] not in tmp_span:
                tmp_span.append([base_pos, base_pos])
            # print(spans)
            # print([base_pos,base_pos])
            # print(tmp_span)
            tmp_span.sort()
            
            # check rule 3
            all_spans_by_cls = self.find_all_spans_by_cls(cls)
            prncp3 = False
            for span in all_spans_by_cls:     # for each predicted entity of this category, check whether the candidate interval is contained in it
                prncp3 = prncp3 or self.sub_of([base_pos, int(pos)], span)
            
            if len(spans):       # spans is already non-empty
                # print(spans)
                # print('tmp_span:%s' %tmp_span)
                # print(tmp_span.index([base_pos, base_pos]))
                # print(len(tmp_span))
                if tmp_span.index([base_pos, base_pos])+1 == len(tmp_span):   
                    # base_pos is the last one in tmp_span, so any later position may be chosen
                    # print('nothing after it')
                    if int(pos) > base_pos and prncp3:
                        cur_count = count_dict[i][cls][1]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
                elif tmp_span.index([base_pos, base_pos])+1 < len(tmp_span):
                    # another entity follows base_pos, so the match must lie before that entity
                    # print('span right after base_pos in tmp_span: %s' %(tmp_span[tmp_span.index([base_pos, base_pos])+1]))
                    prncp4 = tmp_span[tmp_span.index([base_pos, base_pos])+1][0] > int(pos)
                    if int(pos) > base_pos and prncp3 and prncp4:   # search to the right, staying outside the existing spans
                        cur_count = count_dict[i][cls][1]
                        if cur_count > max_count:
                            max_count = cur_count
                            max_pos = int(pos)
            else:                                   # spans is still empty, so there is no existing span to check against
                if int(pos) > base_pos and prncp3:
                    cur_count = count_dict[i][cls][1]
                    if cur_count > max_count:
                        max_count = cur_count
                        max_pos = int(pos)
                        # print(max_pos)
        if max_pos >= 0:
            print('got max_pos at %s' % max_pos)
            count_dict[key2index[str(max_pos)]] = np.zeros((4,2))
            # print('remove pos: %s' % max_pos)
            span_to_append = [base_pos, max_pos]
            # print(span_to_append)
                    
        if span_to_append not in spans and len(span_to_append): # the span to append is not in spans yet
            if len(spans):   # spans already has content
                print('doing forward append...')
                spans.sort()
                for span in spans:
                    if span[0] == span_to_append[1]+1 and span != span_to_append:     # adjacent to the next span
                        span_to_append = [span_to_append[0], span[1]]   # keep only the outer start and end
                        spans.append(span_to_append)
                        spans.remove(span)                             # drop the original span
                    elif span[1] == span_to_append[0]-1 and span != span_to_append:     # adjacent to the previous span
                        span_to_append = [span[0], span_to_append[1]]    # keep only the outer start and end
                        spans.append(span_to_append)
                        spans.remove(span)
                    else:
                        if span != span_to_append:
                            spans.append(span_to_append)                  # not adjacent to anything: append directly
            elif len(spans) == 0:      # spans is still empty, but there is a span to add
                spans.append(span_to_append)
        # print('spans after searched forward: %s' % spans)
        return int(max_pos), count_dict, spans
    
    
    def generate_res(self):
        '''
        Generate the final result
        '''
        res = {'类别1':[], '类别2':[], '类别3':[], '类别4':[]}
        spans = self.spans
        threshold = self.threshold
        print('=======================')
        print('set threshold: %s' % threshold)
        print('=======================')
        _, key2index, index2key = self.generate_init()  # only needed to keep the two mapping dicts
        count_dict = self.fill_count()  # initialize the counts
        
        while True:                   # keep looping while the threshold condition holds; break otherwise
            # cur_se, cur_cls, cur_pos, self.count_dict, max_count = self.search_first(count_dict, key2index, index2key)
            try:
                cur_se, cur_cls, cur_pos, self.count_dict, max_count = self.search_first(count_dict, key2index, index2key)
            except Exception as e:
                print(e)
                break
            if max_count < threshold:
                break
            if cur_se == 0:    # a start was found, so look for its matching end
                cur_end, count_dict, spans = self.search_forward(cls=cur_cls, base_pos=cur_pos, count_dict=count_dict, spans=spans, key2index=key2index, index2key=index2key)
                if cur_end != -1:
                    res[self.id2predicate(cur_cls)].append([cur_pos, cur_end])       # store the result; what is returned in the end is res, not spans
            elif cur_se == 1:    # an end was found, so look for its matching start
                cur_start, count_dict, spans = self.search_backward(cls=cur_cls, base_pos=cur_pos, count_dict=count_dict, spans=spans, key2index=key2index, index2key=index2key)
                if cur_start != -1:
                    res[self.id2predicate(cur_cls)].append([cur_start, cur_pos])
            
        return res
            

4. How to use

First update the F1 scores in the model_point function to those of your own models, then adjust the number and names of the categories (for example in predicate2id, id2predicate and generate_res) to match your own data set.
The shape of the arrays created with np.zeros((4,2)) must match as well: the first dimension is the number of categories.

V = Voter(threshold, results)
final_res = V.generate_res()

5. Other situations

There is one situation that this voting rule cannot handle: when search_first is looking for the most significant (highest-count) position and two candidates have the same count, the code cannot continue. When this happens, I simply use the predictions of the single model with the highest F1 as the final result.
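
One possible alternative, which I have not used above, is to break the tie deterministically instead of giving up. The sketch below relies on np.argmax returning the first maximal entry; the helper name argmax_cell is hypothetical. Note that this changes the behaviour: on a tie, the lowest category row wins, and a start wins over an end.

```python
import numpy as np


def argmax_cell(cell):
    """Return (row, col) of the first maximal entry, even when several entries tie."""
    row, col = np.unravel_index(np.argmax(cell), cell.shape)
    return int(row), int(col)


# Two cells of the same position share the maximum count.
cell = np.array([[0.6, 0.6],
                 [0.0, 0.0]])
print(argmax_cell(cell))  # (0, 0)
```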

This post is mainly written for my own reference. If you know a better voting method, or think that my method has obvious bugs, please leave a comment. If this article helped you, please give it a like.

Origin blog.csdn.net/weixin_44826203/article/details/108347693