BERT 预训练PTM数据样本简单构建认识

参考:https://www.cnblogs.com/little-horse/p/14622047.html

BERT 预训练PTM任务

1.MLM(masked lm):带mask的语言模型

 2.NSP(next sentence prediction):是否为下一句话

在这里插入图片描述

BERT输入主要包括三部分:

1.字的token embeddings

2.句子的segment embedding  

3.句子位置的position embeddings

在这里插入图片描述

代码处理生成样本数据案例

1、词表生成(其实是字表)

import re
import math
import numpy as np
import random

text = (
    '随后,文章为中美关系未来发展提出了5点建议。\n'
    '第一,美国应恢复“和平队”等在华奖学金项目。\n'
    '文章称,这些项目在过去几十年帮助美国了解中国,却被特朗普政府因意图孤立中国而取消。\n'
    '第二,美国应停止污名化孔子学院。文章说,孔子学院只是文化中心和教育机构,性质类似于德国的歌德学院和英国文化协会。\n'
    '第三,美国应该允许此前被特朗普政府驱逐出境的中国记者回到美国。同时文章建议中国也允许美国记者入境。\n'
    '第四,美国应取消限制入境的做法。\n'
    '第五,美方应邀请中国重新开放中国驻休斯顿领事馆。\n'
    '文章称,如此一来,中国也将重新允许美国驻成都领事馆开放。\n'
    '文章最后表示,尽管这些都是微小的举动,但对建立中美互信很有意义,能够为解决更加棘手的问题铺设道路。'
)
sentences = re.sub("[.。,“”,!?\\-]", '', text.lower()).split('\n') # 过滤特殊符号
word_list = list("".join(sentences))  ## 包含的所有字的列表

# 以下是词典的构建
word2idx = {'[pad]':0, '[cls]':1, '[sep]':2, '[unk]':3, '[mask]':4}
for i, w in enumerate(word_list):
    word2idx[w] = i + 5

idx2word = {i : w for i, w in enumerate(word2idx)}
vocab_size = len(word2idx)

## 每句话用对应上词表索引
token_list = list()
for sentence in sentences:
    arr = [word2idx[s] for s in list(sentence)]
    token_list.append(arr)

在这里插入图片描述
2、输入样本数据构建(按照训练任务生成对应样本格式)

def build_data():
    batch = []
    positive = negative = 0
    while (positive != (batch_size/2)) or (negative != (batch_size/2)):
        # 随机选择句子的index,作为A,B句
        tokens_a_index, tokens_b_index = random.randrange(len(sentences)), random.randrange(len(sentences))
         
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
         
        # 拼接A句与B句,格式为:[cls] + A句 + [sep] + B句 + [sep]
        input_ids = [word2idx['[cls]']] + tokens_a + [word2idx['[sep]']] + tokens_b + [word2idx['[sep]']]
         
        # 这里是为了表示两个不同的句子,如A句用0表示,B句用1表示
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
     
        # mask lm,15%随机选择token
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15))) # 句子中的15%的token
         
        # 15%随机选择的token,去除特殊符号[cls]与[sep]
        cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word2idx['[cls]'] and token != word2idx['[sep]']] # 候选masked 位置
         
        random.shuffle(cand_maked_pos)
         
        # 存储被mask的词的位置与token
        masked_tokens, masked_pos = [], []
         
        # 对input_ids进行mask, 80%的时间用于mask替换,10%的时间随机替换,10%的时间不替换。
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos) # mask的位置
            masked_tokens.append(input_ids[pos]) # mask的token
            if random.random() < 0.8: # 80%的时间用mask替换
                input_ids[pos] = word2idx['[mask]']
            elif random.random() > 0.9: # 10%的时间随机替换
                index = random.randint(0, vocab_size - 1)
                while index < 5:
                    index = random.randint(0, vocab_size - 1) # 不包含几个特征符号
                input_ids[pos] = index
             
        # 进行padding,input_ids与segment_ids补齐到最大长度max_len
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
        #print('n_pad:{}'.format(n_pad))
         
        # 不同句子中的mask长度不同,所以需要进行相同长度补齐
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
            
            
        # 构建nsp正负样本,每一batch里正负样本数相等
        if ((tokens_a_index + 1) == tokens_b_index) and (positive < (batch_size / 2)):
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # isnext
            positive += 1
        elif ((tokens_a_index + 1) != tokens_b_index) and (negative < (batch_size / 2)):
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # notnext
            negative += 1
    return batch



'''
构建的数据里除了input_ids,segment_ids外,还有masked_tokens,masked_pos被mask掉的字和其位置(用于bert训练时用),isNext是否为下一句。
'''
batch = build_data()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_42357472/article/details/119038196
今日推荐