Preprocesamiento de datos SQuAD (1)

Preprocesamiento de datos SQuAD


Prefacio

Comencé a aprender QANet hoy, vi la parte de procesamiento de datos por la tarde, traduje y leí los comentarios y códigos en github, y los grabé aquí.

Traducción automática + corrección, debe haber muchos errores. Si desea descargar este tutorial, puede hacer clic en este GitHub
https://github.com/kushalj001/pytorch-question-answering
, incluida la lectura y comprensión de los cuadernos de DrQA, BiDAF y modelos QANet, y la escritura Es muy clara, muy adecuada para investigadores que ya han aprendido a comprender Pytorch y han leído algunos de los artículos del modelo de comprensión de lectura para estudiar y reproducir el modelo.


一 、 Pipeline de preprocesamiento de PNL para control de calidad

Combinado con el preprocesamiento de datos SQuAD (2) puede comprender cómo funciona cada función. Preprocesamiento de datos SQuAD (2)

El propósito de este cuaderno

  1. La construcción de un sistema de preguntas y respuestas desde cero implica mucho código de preprocesamiento de datos. Este artículo tiene características que indican que estos pasos son comunes en muchas tareas de PNL. No uso bibliotecas de nivel superior para crear lotes, glosarios, conjuntos de datos, cargadores de datos, etc. Mi canalización de preprocesamiento solo usa espacios para la tokenización
  2. También probé el mismo paso de preprocesamiento usando torchtext. En el caso de BiDAF, el uso de torchtext convergerá más rápido. No sé por qué sucede esto, no tengo optimizaciones subyacentes para el código de Torchtext. Todavía estoy intentando actualizar
  3. La parte de preprocesamiento del cuaderno está contenida en el script preprocess.py

En segundo lugar, siga los pasos

1. Importa la biblioteca

import torch
import numpy as np
import pandas as pd
import pickle
import re, os, string, typing, gc, json
import spacy
from collections import Counter
nlp = spacy.load('en_core_web_sm')
# nlp = spacy.load('en')

2. Leer los datos

def load_json(path):
    '''
    Loads the JSON file of the Squad dataset.
    Returns the json object of the dataset.
    加载一个SQuAD的JSON文件
    返回数据集的json对象
    '''
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    print("Length of data: ", len(data['data']))
    print("Data Keys: ", data['data'][0].keys())
    print("Title: ", data['data'][0]['title'])
    
    return data

Estructura de archivos SQuAD
Inserte la descripción de la imagen aquí

3. Analizar los datos

def parse_data(data:dict)->list:
    '''
    Parses the JSON file of Squad dataset by looping through the
    keys and values and returns a list of dictionaries with
    context, query and label triplets being the keys of each dict.
    通过循环遍历键和值来解析Squad 数据集的JSON文件,
    并返回一个字典列表,其中context、query和label三元组是每个dict的键。
    '''
    data = data['data']    # data包含442个article有442个title
    qa_list = []

    for paragraphs in data:
        # 遍历title

        for para in paragraphs['paragraphs']:
            context = para['context']

            for qa in para['qas']:
                
                id = qa['id']
                question = qa['question']
                
                for ans in qa['answers']:
                    answer = ans['text']
                    ans_start = ans['answer_start']
                    ans_end = ans_start + len(answer)
                    # id,context,question,ans_start,ans_end,answer
                    # 信息全部存储到qa_dict
                    qa_dict = {
    
    }
                    qa_dict['id'] = id
                    qa_dict['context'] = context
                    qa_dict['question'] = question
                    qa_dict['label'] = [ans_start, ans_end]

                    qa_dict['answer'] = answer
                    qa_list.append(qa_dict)    

    return qa_list

4. Elimina documentos que sean demasiado largos.

def filter_large_examples(df):
    '''
    过滤太长的例子
    Returns ids of examples where context lengths, query lengths and answer lengths are
    above a particular threshold. These ids can then be dropped from the dataframe. 
    This is explicitly mentioned in QANet but can be done for other models as well.
    返回context长度、query长度和answer长度超过特定阈值的示例的id.
    然后可以从数据帧中删除这些id。
    QANet中明确提到了这一点,但其他模型也可以这样做。
    '''
    
    ctx_lens = []
    query_lens = []
    ans_lens = []
    for index, row in df.iterrows():    # 按照行进行遍历
        ctx_tokens = [w.text for w in nlp(row.context, disable=['parser','ner','tagger'])]
        if len(ctx_tokens)>400:
            ctx_lens.append(row.name)

        query_tokens = [w.text for w in nlp(row.question, disable=['parser','tagger','ner'])]
        if len(query_tokens)>50:
            query_lens.append(row.name)

        ans_tokens = [w.text for w in nlp(row.answer, disable=['parser','tagger','ner'])]
        if len(ans_tokens)>30:
            ans_lens.append(row.name)

        assert row.name == index
    
    return set(ans_lens + ctx_lens + query_lens)

5. Crea un vocabulario

def gather_text_for_vocab(dfs:list):
    '''
    Gathers text from contexts and questions to build a vocabulary.
    从context和questions中收集文本,建立词汇表。
    :param dfs: list of dataframes of SQUAD dataset.
    :returns: list of contexts and questions
    '''
    
    text = []
    total = 0
    for df in dfs:
        unique_contexts = list(df.context.unique())    # 去重
        unique_questions = list(df.question.unique())
        total += df.context.nunique() + df.question.nunique()    # 返回唯一值的个数
        text.extend(unique_contexts + unique_questions)    # 在列表末尾一次性追加另一个序列中的多个值
    
    assert len(text) == total
    
    return text
def build_word_vocab(vocab_text):
    '''
    Builds a word-level vocabulary from the given text.
    # 建立一个word-level水平的词汇表
    :param list vocab_text: list of contexts and questions
    :returns 
        dict word2idx: word to index mapping of words
        dict idx2word: integer to word mapping
        list word_vocab: list of words sorted by frequency
    :返回
        word2idx字典: 单词到索引
        idx2word字典: 索引到单词
        word_vocab:根据词频排序的单词列表
    '''
    words = []
    for sent in vocab_text:
        for word in nlp(sent, disable=['parser','tagger','ner']):
            words.append(word.text)
    
    # 所有单词,按照词频统计
    word_counter = Counter(words)
    word_vocab = sorted(word_counter, key=word_counter.get, reverse=True)
    print(f"raw-vocab: {len(word_vocab)}")
    #word_vocab = list(set(word_vocab).intersection(set(glove_words)))
    print(f"glove-vocab: {len(word_vocab)}")
    word_vocab.insert(0, '<unk>')
    word_vocab.insert(1, '<pad>')
    print(f"vocab-length: {len(word_vocab)}")
    word2idx = {
    
    word:idx for idx, word in enumerate(word_vocab)}
    print(f"word2idx-length: {len(word2idx)}")
    idx2word = {
    
    v:k for k,v in word2idx.items()}
    
    
    return word2idx, idx2word, word_vocab
def build_char_vocab(vocab_text):
    '''
    Builds a character-level vocabulary from the given text.
    
    :param list vocab_text: list of contexts and questions
    :returns 
        dict char2idx: character to index mapping of words
        list char_vocab: list of characters sorted by frequency
    '''
    
    chars = []
    for sent in vocab_text:
        for ch in sent:
            chars.append(ch)

    char_counter = Counter(chars)
    char_vocab = sorted(char_counter, key=char_counter.get, reverse=True)
    print(f"raw-char-vocab: {len(char_vocab)}")
    high_freq_char = [char for char, count in char_counter.items() if count>=20]
    char_vocab = list(set(char_vocab).intersection(set(high_freq_char)))
    print(f"char-vocab-intersect: {len(char_vocab)}")
    char_vocab.insert(0,'<unk>')
    char_vocab.insert(1,'<pad>')
    char2idx = {
    
    char:idx for idx, char in enumerate(char_vocab)}
    print(f"char2idx-length: {len(char2idx)}")
    
    return char2idx, char_vocab
def context_to_ids(text, word2idx):
    '''
    Converts context text to their respective ids by mapping each word
    using word2idx. Input text is tokenized using spacy tokenizer first.
    通过使用word2idx映射每个单词,将context text转换为它们各自的id。
    输入文本首先使用tokenize进行tokenizer。
    :param str text: context text to be converted
    :returns list context_ids: list of mapped ids
    
    :raises assertion error: sanity check
    
    '''

    context_tokens = [w.text for w in nlp(text, disable=['parser','tagger','ner'])]
    context_ids = [word2idx[word] for word in context_tokens]
    
    assert len(context_ids) == len(context_tokens)
    return context_ids
    
def question_to_ids(text, word2idx):
    '''
    Converts question text to their respective ids by mapping each word
    using word2idx. Input text is tokenized using spacy tokenizer first.
    通过使用word2idx映射每个单词,将question text转换为它们各自的id。
    输入文本首先使用tokenize进行tokenizer。
    :param str text: question text to be converted
    :returns list context_ids: list of mapped ids
    
    :raises assertion error: sanity check
    
    '''

    question_tokens = [w.text for w in nlp(text, disable=['parser','tagger','ner'])]
    question_ids = [word2idx[word] for word in question_tokens]
    
    assert len(question_ids) == len(question_tokens)
    return question_ids

6.Asegúrate de que la etiqueta sea correcta

No entiendo mucho de esta parte, por lo que puedo publicar el inglés directamente. El
propósito de esta prueba era garantizar que la etiqueta de cada ejemplo fuera correcta.
Hay dos formas de probar esto.

  • Primero es calcular lapsos del contexto y verificar si los índices de inicio y final de la etiqueta están presentes en los intervalos calculados. El valor de índice del ejemplo que no supera las pruebas de inicio y finalización se adjunta en listas separadas.
  • El segundo es obtener los índices de inicio y finalización de la respuesta en la lista context_ids. Obtenga los identificadores correspondientes a esas posiciones, conviértalos en cadenas word2idxy compárelos con los tokens de inicio y fin de la respuesta dada. Los ejemplos que no superan esta prueba tienen su posición agregada a una lista.
  • La razón por la que algunos ejemplos fallan aquí se debe en gran parte a la ausencia de un ' 'o un espacio antes y después de la respuesta en el contexto. Hay algunos intervalos que el tokenizador no logra capturar o es simplemente un caso en el que el ejemplo no se limpia.

Puntuación en
la prueba El propósito de esta prueba es asegurar que la etiqueta de cada ejemplo sea correcta.
Hay dos métodos de prueba.

  • La primera es calcular el intervalo del contexto y comprobar si los índices de inicio y finalización de la etiqueta existen en el intervalo calculado. Los valores de índice de los ejemplos que fallaron en las pruebas de inicio y finalización se agregarán a una lista separada.
  • El segundo paso es obtener el índice inicial y final de la respuesta en la lista context_ids. Obtenga los ID correspondientes a estas posiciones, use 'word2idx' para convertirlos en cadenas y compárelos con las etiquetas de inicio y final en la respuesta dada. Los ejemplos que no superen la prueba se agregarán a una lista.
  • La razón por la que algunos ejemplos fallan aquí es en gran parte porque no hay "'" o espacios antes y después de la respuesta en el contexto. Hay algunos marcadores de área que no se pueden capturar o simplemente no se limpiaron en el ejemplo.
def test_indices(df, idx2word):
    '''
    Performs the tests mentioned above. This method also gets the start and end of the answers
    with respect to the context_ids for each example.
    执行上面提到的测试。
    这个方法还获取与每个示例的context_id相关的答案的开始和结束。
    :param dataframe df: SQUAD df
    :returns
        list start_value_error: example idx where the start idx is not found in the start spans
                                of the text
                                示例idx,在文本的起始区域中找不到起始idx
        list end_value_error: example idx where the end idx is not found in the end spans
                              of the text
        list assert_error: examples that fail assertion errors. A majority are due to the above errors
        
    '''

    start_value_error = []
    end_value_error = []
    assert_error = []
    for index, row in df.iterrows():

        answer_tokens = [w.text for w in nlp(row['answer'], disable=['parser','tagger','ner'])]

        start_token = answer_tokens[0]
        end_token = answer_tokens[-1]
        
        context_span  = [(word.idx, word.idx + len(word.text)) 
                         for word in nlp(row['context'], disable=['parser','tagger','ner'])]
        # .idx 是按照字符进行分割,返回的是单词首字符的位置
        starts, ends = zip(*context_span)

        answer_start, answer_end = row['label']

        try:
            start_idx = starts.index(answer_start)
        except:
            start_value_error.append(index)
        try:
            end_idx  = ends.index(answer_end)
        except:
            end_value_error.append(index)

        try:
            assert idx2word[row['context_ids'][start_idx]] == answer_tokens[0]
            assert idx2word[row['context_ids'][end_idx]] == answer_tokens[-1]
        except:
            assert_error.append(index)


    return start_value_error, end_value_error, assert_error
def get_error_indices(df, idx2word):
    '''
    Gets error indices from the method above and returns a 
    set of those indices.
    从上面的方法中获取错误索引,并返回一组这些索引。
    '''
    
    start_value_error, end_value_error, assert_error = test_indices(df)
    err_idx = start_value_error + end_value_error + assert_error
    err_idx = set(err_idx)
    print(f"Error indices: {len(err_idx)}")
    
    return err_idx
def index_answer(row, idx2word):
    '''
    Takes in a row of the dataframe or one training example and
    returns a tuple of start and end positions of answer by calculating 
    spans.
    获取数据帧的一行或一个训练示例,并通过计算跨度返回答案的起始和结束位置的元组。
    '''
    
    context_span = [(word.idx, word.idx + len(word.text)) for word in nlp(row.context, disable=['parser','tagger','ner'])]
    starts, ends = zip(*context_span)
    
    answer_start, answer_end = row.label
    start_idx = starts.index(answer_start)
 
    end_idx  = ends.index(answer_end)
    
    ans_toks = [w.text for w in nlp(row.answer,disable=['parser','tagger','ner'])]
    ans_start = ans_toks[0]
    ans_end = ans_toks[-1]
    assert idx2word[row.context_ids[start_idx]] == ans_start
    assert idx2word[row.context_ids[end_idx]] == ans_end
    
    return [start_idx, end_idx]

para resumir

Se pega algo de código + traducción propia, la legibilidad no es buena, puedes descargarlo directamente en github.

Supongo que te gusta

Origin blog.csdn.net/qq_42388742/article/details/112104456
Recomendado
Clasificación