SQuAD Data Preprocessing (3)




Preface

1. Save the processed data to pickle files so it can be reloaded directly next time.
2. Create a dataloader class that returns everything the model needs during training.
3. Build the GloVe dictionary and create a weight matrix for the vocabulary shared between GloVe and the dataset.

1. Dumping the Data

Python's pickle module can persist objects to disk as files.
pickle.dump(obj, file[, protocol]): serializes an object to a file
pickle.load(file): deserializes an object from a file
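
As a quick illustration of the round trip (a toy dictionary, not the project data):

import pickle

# dump any picklable object to disk, then load it back unchanged
with open('example.pickle', 'wb') as handle:
    pickle.dump({'hello': 1}, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('example.pickle', 'rb') as handle:
    restored = pickle.load(handle)

assert restored == {'hello': 1}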

Dumping data to pickle files

This ensures that next time we can load the preprocessed dataframes directly.

import pickle
# save the training and validation set dataframes
train_df.to_pickle('qanettrain.pkl')
valid_df.to_pickle('qanetvalid.pkl')

with open('qanetw2id.pickle','wb') as handle:
    pickle.dump(word2idx, handle)

with open('qanetc2id.pickle','wb') as handle:
    pickle.dump(char2idx, handle)

Reading data from the pickle files

This way the preprocessing only has to run once. Some of the preprocessing steps can take several minutes, so pickling saves a lot of data-cleaning time.

import pickle
import pandas as pd

# load the word2idx and char2idx dictionaries
with open('qanetw2id.pickle', 'rb') as handle:
    word2idx = pickle.load(handle)
with open('qanetc2id.pickle','rb') as handle:
    char2idx = pickle.load(handle)

train_df = pd.read_pickle('qanettrain.pkl')
valid_df = pd.read_pickle('qanetvalid.pkl')

# invert the mappings: id -> word and id -> character
idx2word = {v: k for k, v in word2idx.items()}
idx2char = {v: k for k, v in char2idx.items()}

word_vocab = list(word2idx)
char_vocab = list(char2idx)
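
A quick sanity check of the inverted mappings (assuming word2idx maps each token to a unique integer id):

sample_word = word_vocab[0]
assert idx2word[word2idx[sample_word]] == sample_word
print(len(word_vocab), len(char_vocab))    # vocabulary sizes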

2. Creating the Dataloader

This class handles batching, builds the character vectors, and returns everything needed during training:
padded_context: padded tensor of contexts for each batch
padded_question: padded tensor of questions for each batch
char_ctx & char_ques: character-level ids for context and question
label: start and end index with respect to context_ids
context_text, answer_text: used during validation to calculate metrics
ids: question ids for evaluation

Dataloader code

import torch
# `nlp` is the spaCy pipeline loaded earlier in this series
# (e.g. nlp = spacy.load('en_core_web_sm'); an assumption here -- it must
# match the tokenizer used when context_ids/question_ids were built).

class SquadDataset:
    '''
    - Creates batches dynamically by padding to the length of the largest
      example in a given batch.
    - Calculates character vectors for contexts and questions.
    - Returns tensors for training.
    '''
    def __init__(self, data, batch_size):
        '''
        data: dataframe
        batch_size: int
        '''
        self.batch_size = batch_size
        data = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
        self.data = data
        # self.data is now [batch1, batch2, ...], each batch holding batch_size rows

    def __len__(self):
        # Implementing __len__() lets len() report the number of batches,
        # so the class behaves like a list.
        return len(self.data)

    def make_char_vector(self, max_sent_len, sentence, max_word_len=16):
        # Build a max_sent_len x max_word_len LongTensor; each row holds the
        # character ids of one word, truncated to max_word_len characters.
        char_vec = torch.zeros(max_sent_len, max_word_len).type(torch.LongTensor)
        for i, word in enumerate(nlp(sentence, disable=['parser', 'tagger', 'ner'])):
            for j, ch in enumerate(word.text):
                if j == max_word_len:
                    break
                char_vec[i][j] = char2idx.get(ch, 0)    # id if known, else 0
        # char_vec[i][j] is the id of the j-th character of the i-th word;
        # padding positions stay 0
        return char_vec

    def get_span(self, text):
        # w.idx is the character offset of the token; len(w.text) is its length
        text = nlp(text, disable=['parser', 'tagger', 'ner'])
        span = [(w.idx, w.idx + len(w.text)) for w in text]
        # list of (start, end) character positions per word,
        # e.g. get_span("New York") -> [(0, 3), (4, 8)]
        return span

    def __iter__(self):
        '''
        Creates batches of data and yields them.
        (For an explanation of yield, see
        https://blog.csdn.net/mieleizhi0522/article/details/82142856/)

        Each yield comprises:
        :padded_context: padded tensor of contexts for each batch
        :padded_question: padded tensor of questions for each batch
        :char_ctx & char_ques: character-level ids for context and question
        :label: start and end index w.r.t. context_ids
        :context_text, answer_text: used during validation to calculate metrics
        :ids: question ids for evaluation
        '''
        for batch in self.data:
            spans = []
            ctx_text = []      # raw context strings of this batch
            answer_text = []

            for ctx in batch.context:
                # collect each context string and its per-word character spans
                ctx_text.append(ctx)
                spans.append(self.get_span(ctx))

            for ans in batch.answer:
                answer_text.append(ans)

            max_context_len = max([len(ctx) for ctx in batch.context_ids])
            # pad every context in the batch to the longest one, using 1 as the pad id
            padded_context = torch.LongTensor(len(batch), max_context_len).fill_(1)

            for i, ctx in enumerate(batch.context_ids):
                padded_context[i, :len(ctx)] = torch.LongTensor(ctx)

            max_word_ctx = 16
            # batch x sentence length x word length; the last dimension holds character ids
            char_ctx = torch.zeros(len(batch), max_context_len, max_word_ctx).type(torch.LongTensor)
            for i, context in enumerate(batch.context):
                char_ctx[i] = self.make_char_vector(max_context_len, context)

            # the questions get the same treatment
            max_question_len = max([len(ques) for ques in batch.question_ids])
            padded_question = torch.LongTensor(len(batch), max_question_len).fill_(1)

            for i, ques in enumerate(batch.question_ids):
                padded_question[i, :len(ques)] = torch.LongTensor(ques)

            max_word_ques = 16
            char_ques = torch.zeros(len(batch), max_question_len, max_word_ques).type(torch.LongTensor)
            for i, question in enumerate(batch.question):
                char_ques[i] = self.make_char_vector(max_question_len, question)

            label = torch.LongTensor(list(batch.label_idx))
            ids = list(batch.id)

            yield (padded_context, padded_question, char_ctx, char_ques, label, ctx_text, answer_text, ids)
 
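Note that SquadDataset serves batches in the dataframe's stored row order. Below is a minimal per-epoch shuffling sketch (an addition of mine, not from the original post; rebuilding the dataset each epoch is cheap since __init__ only slices the dataframe):

def epoch_batches(df, batch_size=16):
    # shuffle the rows, then rebuild the lightweight batch view for this epoch
    shuffled = df.sample(frac=1).reset_index(drop=True)
    return SquadDataset(shuffled, batch_size)

for epoch in range(2):
    for batch in epoch_batches(train_df):
        padded_context, padded_question, char_ctx, char_ques, label, ctx_text, answer_text, ids = batch
        # ... training step goes here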

Inspecting the tensor shapes

# create dataloaders

train_dataset = SquadDataset(train_df,16)
valid_dataset = SquadDataset(valid_df,16)

# inspect the shapes of the tensors returned by the loader
a = next(iter(train_dataset))
for i in range(len(a)):
    try:
        print(a[i].shape)
    except AttributeError:
        print(len(a[i]))
-------------------------------------------------------------------
out:
torch.Size([16, 253])
torch.Size([16, 16])
torch.Size([16, 253, 16])
torch.Size([16, 16, 16])
torch.Size([16, 2])
16
16
16

In order: padded_context (batch of 16, longest context 253 tokens), padded_question (longest question 16 tokens), char_ctx and char_ques (with the 16-character word dimension), label (start and end index), and the three plain lists ctx_text, answer_text and ids, each of length 16.

3. Loading GloVe and Building the Weight Matrix

Parsing the GloVe word-vector file

import numpy as np

def get_glove_dict():
    '''
    Parses the GloVe word-vectors text file and returns a dictionary with the
    words as keys and their respective pretrained word vectors as values.
    '''
    glove_dict = {}
    with open("./glove.840B.300d/glove.840B.300d.txt", "r", encoding="utf-8") as f:
        for line in f:
            values = line.split(' ')
            word = values[0]
            vector = np.asarray(values[1:], dtype="float32")
            glove_dict[word] = vector
    # the with-block closes the file automatically, no f.close() needed

    return glove_dict
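
The weight-matrix step below needs the parsed dictionary, so build it once here (parsing the 840B.300d file can take a few minutes):

glove_dict = get_glove_dict()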

Creating the weight matrix

def create_weights_matrix(glove_dict):
    '''
    Creates a weight matrix for the words that are common to the GloVe vocab
    and the dataset's vocab. Initializes OOV words with a zero vector.
    '''
    weights_matrix = np.zeros((len(word_vocab), 300))
    words_found = 0
    for i, word in enumerate(word_vocab):
        try:
            weights_matrix[i] = glove_dict[word]
            words_found += 1
        except KeyError:
            pass    # OOV word: keep the zero vector

    return weights_matrix, words_found

weights_matrix, words_found = create_weights_matrix(glove_dict)
print("Words found in the GloVe vocab: ", words_found)
----------------------------------------------------------------------------
out:
Words found in the GloVe vocab:  91194

# Save the weight matrix for future loading.
# This matrix becomes nn.Embedding's weight matrix.
np.save('qanetglove_vt.npy', weights_matrix)
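
A minimal sketch of loading the saved matrix back into an embedding layer (freezing the weights is an assumption on my part; pretrained word vectors are commonly kept fixed in QANet-style setups):

import numpy as np
import torch
import torch.nn as nn

# load the saved matrix and wrap it as the embedding's weight
weights = torch.from_numpy(np.load('qanetglove_vt.npy')).float()
embedding = nn.Embedding.from_pretrained(weights, freeze=True)    # vocab_size x 300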

Summary

With all the data processed and loaded, the next step is to put together the QANet model.

Reposted from blog.csdn.net/qq_42388742/article/details/112177743