SQuAD 数据预处理(1)

SQuAD 数据预处理


前言

今天开始学习QANet,下午看了数据处理部分,翻译和看了github的注释和代码,在这里记录下来

机器翻译+纠正,一定有很多错误的,想要下载这个教程的可以点这个GitHub
https://github.com/kushalj001/pytorch-question-answering
,包括阅读理解DrQA,BiDAF和QANet模型的notebook,写的很清楚,很适合已经学习了解pytorch和读了部分阅读理解模型论文的研究者,进一步学习和复现模型


一、NLP Preprocessing Pipeline for QA

结合SQuAD 数据预处理(2)可以理解每个函数怎么样起作用SQuAD 数据预处理(2)

这篇notebook的目的

  1. 从头开始搭建一个问答系统,涉及到很多数据预处理代码。这篇具有的功能有,这些步骤在许多NLP任务都是常见的。我不使用任何更高级别的库来创建批,词汇表,数据集,数据加载器等。我的预处理管道只是使用spacy进行tokenization
  2. 我还尝试了使用torchtext进行相同的预处理步骤。在BiDAF的情况下,使用torchtext会更快收敛。我不知道为什么会发生这种情况,我没有torchtext的代码使用了什么底层优化。我还在努力更新
  3. 该notebook的预处理部分包含在脚本preprocess.py中

二、使用步骤

1.引入库

import torch
import numpy as np
import pandas as pd
import pickle
import re, os, string, typing, gc, json
import spacy
from collections import Counter
nlp = spacy.load('en_core_web_sm')
# nlp = spacy.load('en')

2.读入数据

def load_json(path):
    '''
    Loads the JSON file of the Squad dataset.
    Returns the json object of the dataset.
    加载一个SQuAD的JSON文件
    返回数据集的json对象
    '''
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    print("Length of data: ", len(data['data']))
    print("Data Keys: ", data['data'][0].keys())
    print("Title: ", data['data'][0]['title'])
    
    return data

SQuAD的文件结构
在这里插入图片描述

3.解析数据

def parse_data(data:dict)->list:
    '''
    Parses the JSON file of Squad dataset by looping through the
    keys and values and returns a list of dictionaries with
    context, query and label triplets being the keys of each dict.
    通过循环遍历键和值来解析Squad 数据集的JSON文件,
    并返回一个字典列表,其中context、query和label三元组是每个dict的键。
    '''
    data = data['data']    # data包含442个article有442个title
    qa_list = []

    for paragraphs in data:
        # 遍历title

        for para in paragraphs['paragraphs']:
            context = para['context']

            for qa in para['qas']:
                
                id = qa['id']
                question = qa['question']
                
                for ans in qa['answers']:
                    answer = ans['text']
                    ans_start = ans['answer_start']
                    ans_end = ans_start + len(answer)
                    # id,context,question,ans_start,ans_end,answer
                    # 信息全部存储到qa_dict
                    qa_dict = {
    
    }
                    qa_dict['id'] = id
                    qa_dict['context'] = context
                    qa_dict['question'] = question
                    qa_dict['label'] = [ans_start, ans_end]

                    qa_dict['answer'] = answer
                    qa_list.append(qa_dict)    

    return qa_list

4.删除过长的文档

def filter_large_examples(df):
    '''
    过滤太长的例子
    Returns ids of examples where context lengths, query lengths and answer lengths are
    above a particular threshold. These ids can then be dropped from the dataframe. 
    This is explicitly mentioned in QANet but can be done for other models as well.
    返回context长度、query长度和answer长度超过特定阈值的示例的id.
    然后可以从数据帧中删除这些id。
    QANet中明确提到了这一点,但其他模型也可以这样做。
    '''
    
    ctx_lens = []
    query_lens = []
    ans_lens = []
    for index, row in df.iterrows():    # 按照行进行遍历
        ctx_tokens = [w.text for w in nlp(row.context, disable=['parser','ner','tagger'])]
        if len(ctx_tokens)>400:
            ctx_lens.append(row.name)

        query_tokens = [w.text for w in nlp(row.question, disable=['parser','tagger','ner'])]
        if len(query_tokens)>50:
            query_lens.append(row.name)

        ans_tokens = [w.text for w in nlp(row.answer, disable=['parser','tagger','ner'])]
        if len(ans_tokens)>30:
            ans_lens.append(row.name)

        assert row.name == index
    
    return set(ans_lens + ctx_lens + query_lens)

5.创建词表

def gather_text_for_vocab(dfs:list):
    '''
    Gathers text from contexts and questions to build a vocabulary.
    从context和questions中收集文本,建立词汇表。
    :param dfs: list of dataframes of SQUAD dataset.
    :returns: list of contexts and questions
    '''
    
    text = []
    total = 0
    for df in dfs:
        unique_contexts = list(df.context.unique())    # 去重
        unique_questions = list(df.question.unique())
        total += df.context.nunique() + df.question.nunique()    # 返回唯一值的个数
        text.extend(unique_contexts + unique_questions)    # 在列表末尾一次性追加另一个序列中的多个值
    
    assert len(text) == total
    
    return text
def build_word_vocab(vocab_text):
    '''
    Builds a word-level vocabulary from the given text.
    # 建立一个word-level水平的词汇表
    :param list vocab_text: list of contexts and questions
    :returns 
        dict word2idx: word to index mapping of words
        dict idx2word: integer to word mapping
        list word_vocab: list of words sorted by frequency
    :返回
        word2idx字典: 单词到索引
        idx2word字典: 索引到单词
        word_vocab:根据词频排序的单词列表
    '''
    words = []
    for sent in vocab_text:
        for word in nlp(sent, disable=['parser','tagger','ner']):
            words.append(word.text)
    
    # 所有单词,按照词频统计
    word_counter = Counter(words)
    word_vocab = sorted(word_counter, key=word_counter.get, reverse=True)
    print(f"raw-vocab: {len(word_vocab)}")
    #word_vocab = list(set(word_vocab).intersection(set(glove_words)))
    print(f"glove-vocab: {len(word_vocab)}")
    word_vocab.insert(0, '<unk>')
    word_vocab.insert(1, '<pad>')
    print(f"vocab-length: {len(word_vocab)}")
    word2idx = {
    
    word:idx for idx, word in enumerate(word_vocab)}
    print(f"word2idx-length: {len(word2idx)}")
    idx2word = {
    
    v:k for k,v in word2idx.items()}
    
    
    return word2idx, idx2word, word_vocab
def build_char_vocab(vocab_text):
    '''
    Builds a character-level vocabulary from the given text.
    
    :param list vocab_text: list of contexts and questions
    :returns 
        dict char2idx: character to index mapping of words
        list char_vocab: list of characters sorted by frequency
    '''
    
    chars = []
    for sent in vocab_text:
        for ch in sent:
            chars.append(ch)

    char_counter = Counter(chars)
    char_vocab = sorted(char_counter, key=char_counter.get, reverse=True)
    print(f"raw-char-vocab: {len(char_vocab)}")
    high_freq_char = [char for char, count in char_counter.items() if count>=20]
    char_vocab = list(set(char_vocab).intersection(set(high_freq_char)))
    print(f"char-vocab-intersect: {len(char_vocab)}")
    char_vocab.insert(0,'<unk>')
    char_vocab.insert(1,'<pad>')
    char2idx = {
    
    char:idx for idx, char in enumerate(char_vocab)}
    print(f"char2idx-length: {len(char2idx)}")
    
    return char2idx, char_vocab
def context_to_ids(text, word2idx):
    '''
    Converts context text to their respective ids by mapping each word
    using word2idx. Input text is tokenized using spacy tokenizer first.
    通过使用word2idx映射每个单词,将context text转换为它们各自的id。
    输入文本首先使用tokenize进行tokenizer。
    :param str text: context text to be converted
    :returns list context_ids: list of mapped ids
    
    :raises assertion error: sanity check
    
    '''

    context_tokens = [w.text for w in nlp(text, disable=['parser','tagger','ner'])]
    context_ids = [word2idx[word] for word in context_tokens]
    
    assert len(context_ids) == len(context_tokens)
    return context_ids
    
def question_to_ids(text, word2idx):
    '''
    Converts question text to their respective ids by mapping each word
    using word2idx. Input text is tokenized using spacy tokenizer first.
    通过使用word2idx映射每个单词,将question text转换为它们各自的id。
    输入文本首先使用tokenize进行tokenizer。
    :param str text: question text to be converted
    :returns list context_ids: list of mapped ids
    
    :raises assertion error: sanity check
    
    '''

    question_tokens = [w.text for w in nlp(text, disable=['parser','tagger','ner'])]
    question_ids = [word2idx[word] for word in question_tokens]
    
    assert len(question_ids) == len(question_tokens)
    return question_ids

6.确保标签正确

这部分很多没看懂,也直接贴英文的过来吧
Purpose behind this test was to ensure that the label for each example was correct.
Two ways to test this.

  • First is to calculate spans of the context and check if start and end indices from the label are present in the calculated spans. Index value of the example which fails the start and end tests are appended in separate lists.
  • Second is to get the start and end indexes of the answer in the context_ids list. Get the ids corresponding to those positions, convert them to string using word2idx and compare them with the start and end tokens from the given answer. Examples which fail this test have their position added to a list.
  • The reason why some examples fail here is largely due to the absence of a ' ' or a space before and after the answer in the context. There are some spans that the tokenizer fails to capture or is simply a case where the example is not cleaned.

在测试评分
这个测试的目的是确保每个例子的标签是正确的。
有两种测试方法。

  • 首先是计算上下文的跨度,并检查标签的开始和结束索引是否存在于计算的跨度中。未通过开始和结束测试的示例的索引值将追加到单独的列表中。
  • 第二步是在context_ids列表中获取答案的开始和结束索引。获取与这些位置对应的id,使用’word2idx’将它们转换为字符串,并与给定答案中的开始和结束标记进行比较。未通过测试的例子会被添加到一个列表中。
  • 一些例子在这里失败的原因很大程度上是因为在context中答案的前后没有“’”或空格。有一些区域标记器无法捕获,或者只是例子中没有被清理的情况。
def test_indices(df, idx2word):
    '''
    Performs the tests mentioned above. This method also gets the start and end of the answers
    with respect to the context_ids for each example.
    执行上面提到的测试。
    这个方法还获取与每个示例的context_id相关的答案的开始和结束。
    :param dataframe df: SQUAD df
    :returns
        list start_value_error: example idx where the start idx is not found in the start spans
                                of the text
                                示例idx,在文本的起始区域中找不到起始idx
        list end_value_error: example idx where the end idx is not found in the end spans
                              of the text
        list assert_error: examples that fail assertion errors. A majority are due to the above errors
        
    '''

    start_value_error = []
    end_value_error = []
    assert_error = []
    for index, row in df.iterrows():

        answer_tokens = [w.text for w in nlp(row['answer'], disable=['parser','tagger','ner'])]

        start_token = answer_tokens[0]
        end_token = answer_tokens[-1]
        
        context_span  = [(word.idx, word.idx + len(word.text)) 
                         for word in nlp(row['context'], disable=['parser','tagger','ner'])]
        # .idx 是按照字符进行分割,返回的是单词首字符的位置
        starts, ends = zip(*context_span)

        answer_start, answer_end = row['label']

        try:
            start_idx = starts.index(answer_start)
        except:
            start_value_error.append(index)
        try:
            end_idx  = ends.index(answer_end)
        except:
            end_value_error.append(index)

        try:
            assert idx2word[row['context_ids'][start_idx]] == answer_tokens[0]
            assert idx2word[row['context_ids'][end_idx]] == answer_tokens[-1]
        except:
            assert_error.append(index)


    return start_value_error, end_value_error, assert_error
def get_error_indices(df, idx2word):
    '''
    Gets error indices from the method above and returns a 
    set of those indices.
    从上面的方法中获取错误索引,并返回一组这些索引。
    '''
    
    start_value_error, end_value_error, assert_error = test_indices(df)
    err_idx = start_value_error + end_value_error + assert_error
    err_idx = set(err_idx)
    print(f"Error indices: {len(err_idx)}")
    
    return err_idx
def index_answer(row, idx2word):
    '''
    Takes in a row of the dataframe or one training example and
    returns a tuple of start and end positions of answer by calculating 
    spans.
    获取数据帧的一行或一个训练示例,并通过计算跨度返回答案的起始和结束位置的元组。
    '''
    
    context_span = [(word.idx, word.idx + len(word.text)) for word in nlp(row.context, disable=['parser','tagger','ner'])]
    starts, ends = zip(*context_span)
    
    answer_start, answer_end = row.label
    start_idx = starts.index(answer_start)
 
    end_idx  = ends.index(answer_end)
    
    ans_toks = [w.text for w in nlp(row.answer,disable=['parser','tagger','ner'])]
    ans_start = ans_toks[0]
    ans_end = ans_toks[-1]
    assert idx2word[row.context_ids[start_idx]] == ans_start
    assert idx2word[row.context_ids[end_idx]] == ans_end
    
    return [start_idx, end_idx]

总结

这就是粘贴了一些代码+自己的翻译,可读性不好,想要看的可以直接在github下载看

猜你喜欢

转载自blog.csdn.net/qq_42388742/article/details/112104456
今日推荐