Data Preprocessing (Part 2): Converting the Data into the BERT Model's Input Format

Using the data produced in the previous section and stored under the match_data directory, we now convert it into the input format the BERT model expects: token ids, attention mask, and segment (token type) ids for each sentence pair.
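
Each line of the match_data files is a tab-separated triple of source sentence, target sentence, and label. Each pair is encoded as [CLS] source [SEP] target [SEP], with segment id 0 for the first sentence and 1 for the second, then padded to a fixed length. The custom Dataset below builds these fields by hand; as a quick sanity check, the same three fields (input_ids, token_type_ids, attention_mask) can also be produced directly with the Hugging Face tokenizer. A minimal sketch, using placeholder sentences (not taken from match_data) and assuming the bert-base-chinese checkpoint can be loaded:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
# Encode a sentence pair the way BERT expects it; the two strings are placeholders.
encoded = tokenizer(
    "今天天气不错",           # source sentence (placeholder)
    "今天天气很好",           # target sentence (placeholder)
    padding="max_length",     # pad to max_length, mirroring the manual padding below
    truncation=True,
    max_length=512,           # the Dataset below pads to pad_size * 2 = 512
)
print(len(encoded["input_ids"]))       # ids for [CLS] source [SEP] target [SEP] + padding
print(len(encoded["token_type_ids"]))  # segment ids: 0 for the source part, 1 for the target part
print(len(encoded["attention_mask"]))  # 1 for real tokens, 0 for padding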

  • Imports and the TextMatchDataset class
import torch
import os
import pickle as pkl
from tqdm import tqdm
from torch.utils.data import Dataset


class TextMatchDataset(Dataset):
    def __init__(self, config, path):
        self.config = config
        self.path = path
        self.inference = False
        self.max_len = self.config.pad_size
        self.contents = self.load_dataset_match(config)

    def load_dataset_match(self, config):
        if "test" in self.path:
            self.inference = True
        if self.config.token_type:
            pad, cls, sep = '[PAD]', '[CLS]', '[SEP]'
        else:
            pad, cls, sep = '<pad>', '<s>', '</s>'

        contents = []
        file_stream = open(self.path, 'r', encoding="utf-8")
        for line in tqdm(file_stream.readlines()):
            lin = line.strip()
            if not lin:
                continue
            if len(lin.split("\t")) != 3:
                print(line)
                continue
            source, target, label = lin.split('\t')
            token_id_full = []
            mask_full = []
            # Truncate over-long inputs at the character level: the source keeps
            # max_len - 2 characters (room for [CLS] and [SEP]) and the target keeps
            # max_len - 1 characters (room for the final [SEP])
            seq_source = config.tokenizer.tokenize(source[:(self.max_len - 2)])
            seq_target = config.tokenizer.tokenize(target[:(self.max_len - 1)])
            # Concatenate special tokens around the two sentences: [CLS] source [SEP] target [SEP]
            seq_token = [cls] + seq_source + [sep] + seq_target + [sep]
            # Segment (token type) ids: 0 for the first sentence (incl. [CLS]/[SEP]), 1 for the second
            seq_segment = [0] * (len(seq_source) + 2) + [1] * (len(seq_target) + 1)
            # Convert tokens to vocabulary ids
            seq_idx = self.config.tokenizer.convert_tokens_to_ids(seq_token)
            # Build the padding needed to reach max_len * 2 from the current length of seq_idx
            padding = [0] * ((self.max_len * 2) - len(seq_idx))
            # Attention mask: 1 for real tokens, 0 for padding
            seq_mask = [1] * len(seq_idx) + padding
            # Pad the token ids
            seq_idx = seq_idx + padding
            # Pad the segment ids
            seq_segment = seq_segment + padding
            
            assert len(seq_idx) == self.max_len * 2
            assert len(seq_mask) == self.max_len * 2
            assert len(seq_segment) == self.max_len * 2

            token_id_full.append(seq_idx)
            token_id_full.append(seq_mask)
            token_id_full.append(seq_segment)

            if self.inference:
                # Test set: keep the third field as a raw string instead of casting to int
                token_id_full.append(label)
            else:
                token_id_full.append(int(label))

            contents.append(token_id_full)

        file_stream.close()
        return contents

    def __getitem__(self, index):
        elements = self.contents[index]
        seq_idx = torch.LongTensor(elements[0])
        seq_mask = torch.LongTensor(elements[1])
        seq_segment = torch.LongTensor(elements[2])
        if not self.inference:
            label = torch.LongTensor([elements[3]])
        else:
            label = [elements[3]]
        return (seq_idx, seq_mask, seq_segment), label

    def __len__(self):
        return len(self.contents)


from param import Param

if __name__ == '__main__':
    param = Param(base_path="./match_data", model_name="SimBERT_A")
    # Quick sanity check on the validation split
    train_data = TextMatchDataset(param, param.dev_path)
    (token, mask, segment), label = train_data[0]
    print(train_data[4300])
    print(len(token))
    print(len(mask))
    print(len(segment))
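
To train on these samples, the dataset can be wrapped in a standard PyTorch DataLoader; with the default collate function, each batch yields the three input tensors of shape [batch_size, pad_size * 2] plus the labels. A minimal sketch using the param and train_data objects from the snippet above:

from torch.utils.data import DataLoader

# Batch the dataset; shuffle for training.
train_loader = DataLoader(train_data, batch_size=param.batch_size, shuffle=True)
for (token_ids, attention_mask, segment_ids), labels in train_loader:
    # token_ids / attention_mask / segment_ids: [batch_size, pad_size * 2]
    # labels: [batch_size, 1]
    print(token_ids.shape, attention_mask.shape, segment_ids.shape, labels.shape)
    break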
  • Param.py
import os.path as osp
# from util import mkdir_if_no_dir
import os
from transformers import BertTokenizer

def mkdir_if_no_dir(path):
    """创建不存在的文件夹"""
    if not os.path.exists(path):
        os.mkdir(path)

class Param:
    def __init__(self, base_path, model_name):
        if "A" in model_name:
            self.train_path = osp.join(base_path, 'train_A.txt')           # training set
            self.dev_path = osp.join(base_path, 'valid_A.txt')             # validation set
            self.test_path = osp.join(base_path, 'test_A.txt')             # test set
            self.result_path = osp.join(base_path, "predict_A.csv")
        else:
            self.train_path = osp.join(base_path, 'train_B.txt')           # training set
            self.dev_path = osp.join(base_path, 'valid_B.txt')             # validation set
            self.test_path = osp.join(base_path, 'test_B.txt')             # test set
            self.result_path = osp.join(base_path, "predict_B.csv")
        print([self.train_path, self.dev_path, self.test_path, self.result_path])
        mkdir_if_no_dir(osp.join(base_path, "saved_dict"))
        mkdir_if_no_dir(osp.join(base_path, "log"))
        self.save_path = osp.join(osp.join(base_path, 'saved_dict'), model_name + '.pt')  # where the trained model is saved
        self.log_path = osp.join(osp.join(base_path, "log"), model_name)                  # log directory
        self.vocab_path = osp.join(base_path, "vocab.pkl")
        self.class_path = osp.join(base_path, "class.txt")
        self.vocab = {}
        self.device = None
        self.token_type = True
        self.model_name = "BERT"
        self.warmup_steps = 1000
        self.t_total = 100000
        self.class_list = {}
        with open(self.class_path, "r", encoding="utf-8") as fr:
            idx = 0
            for line in fr:
                line = line.strip("\n")
                self.class_list[line] = idx
                idx += 1
        self.class_list_verse = {v: k for k, v in self.class_list.items()}
        self.num_epochs = 5                                        # number of epochs
        self.batch_size = 32                                       # mini-batch size
        self.pad_size = 256                                        # length each sentence is padded/truncated to
        self.learning_rate = 1e-5                                  # learning rate
        self.require_improvement = 10000000                        # stop training early if no improvement after this many batches
        self.multi_gpu = True
        self.device_ids = [0, 1]
        self.full_fine_tune = True
        self.use_adamW = True
        self.input_language = "multi"     # ["eng", "original", "multi"]
        self.MAX_VOCAB_SIZE = 20000
        self.min_vocab_freq = 1

        if "BERT" in model_name:
            print("Load BERT Tokenizer")
            self.bert_path = "bert-base-chinese"
            self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        else:
            # Fall back to the same Chinese BERT tokenizer for other model names
            print("Load BERT Tokenizer")
            self.bert_path = "bert-base-chinese"
            self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
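
Several fields in Param (learning_rate, warmup_steps, t_total, use_adamW, full_fine_tune) are only consumed by the training code, which is not part of this section. For orientation, a hedged sketch of how such fields are commonly wired to an optimizer and warmup scheduler with the transformers library; a plain BertModel stands in here for the matching model built later:

from torch.optim import AdamW
from transformers import BertModel, get_linear_schedule_with_warmup
from param import Param

# Sketch only: the actual matching model is defined in a later section.
param = Param(base_path="./match_data", model_name="SimBERT_A")
model = BertModel.from_pretrained(param.bert_path)
optimizer = AdamW(model.parameters(), lr=param.learning_rate)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=param.warmup_steps,   # 1000 in Param above
    num_training_steps=param.t_total,      # 100000 in Param above
)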

  • The class.txt file under match_data
0
1
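
Given this two-line class.txt, the loop in Param.__init__ above produces the following mappings (shown here for reference):

# class_list maps each label string to its row index in class.txt;
# class_list_verse is the inverse mapping.
class_list = {"0": 0, "1": 1}
class_list_verse = {0: "0", 1: "1"}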

Reprinted from blog.csdn.net/weixin_40605573/article/details/115922343