[Paper Reproduction] MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction



Code: https://github.com/iioSnail/MDCSpell_pytorch



Overview

This post is a PyTorch implementation of the paper MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction.

Paper: https://aclanthology.org/2022.findings-acl.98/

Year: 2022

Paper notes: https://blog.csdn.net/zhaohongfei_358/article/details/126973451

Paper in a nutshell: the authors design a multi-task network based on Transformer and BERT for the CSC (Chinese Spelling Check) task. The two tasks are detecting which characters are wrong and correcting the wrong characters.

Since the authors did not release their code, I tried to implement it myself. My results are shown in the table below:

Dataset    Model                  D-Prec  D-Rec  D-F1   C-Prec  C-Rec  C-F1
SIGHAN 13  MDCSpell (paper)       89.1    78.3   83.4   87.5    76.8   81.8
SIGHAN 13  MDCSpell (reproduced)  80.2    79.9   80.0   77.2    76.9   77.1
SIGHAN 14  MDCSpell (paper)       70.2    68.8   69.5   69.0    67.7   68.3
SIGHAN 14  MDCSpell (reproduced)  82.8    66.6   73.8   79.9    64.3   71.2
SIGHAN 15  MDCSpell (paper)       80.8    80.6   80.7   78.4    78.2   78.3
SIGHAN 15  MDCSpell (reproduced)  86.7    76.1   81.1   82.7    72.5   77.3

These are the results after training for 2 epochs, and they are not far from the authors' numbers. With more training, the results might match the paper.

Update: there is a problem here. The paper reports sentence-level metrics, while mine are character-level, so I have not actually reproduced the authors' results. I will revisit this when I have time.
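
To make the difference concrete, here is a small illustrative sketch (not the official SIGHAN scorer) contrasting character-level and sentence-level detection recall; the per-sentence error-position sets below are made up for the example.

# Illustrative sketch only (not the official SIGHAN scorer): character-level counts
# every wrong character individually, while sentence-level only counts a sentence
# when ALL of its errors are detected exactly. The toy error positions are made up.
def char_level_recall(gold_errors, pred_errors):
    hit = sum(len(g & p) for g, p in zip(gold_errors, pred_errors))
    total = sum(len(g) for g in gold_errors)
    return hit / max(total, 1)

def sentence_level_recall(gold_errors, pred_errors):
    hit = sum(1 for g, p in zip(gold_errors, pred_errors) if g and g == p)
    total = sum(1 for g in gold_errors if g)
    return hit / max(total, 1)

gold_errors = [{3}, {1, 5}, set()]   # gold error positions per sentence
pred_errors = [{3}, {1}, set()]      # predicted error positions per sentence
print(char_level_recall(gold_errors, pred_errors))      # 2/3 ≈ 0.67
print(sentence_level_recall(gold_errors, pred_errors))  # 1/2 = 0.50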

Environment Setup

try:
    import transformers
except ImportError:
    !pip install transformers
import os
import copy
import pickle

import torch
import transformers

from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
torch.__version__
'1.12.1+cu113'
transformers.__version__
'4.21.3'

Global Variables

# Sentence length. The paper does not specify it, so I pick one based on experience.
max_length = 128
# Batch size used by the authors
batch_size = 32
# Number of epochs. The paper does not specify it, so I pick one based on experience.
epochs = 10

# Print a log line every ${log_after_step} steps
log_after_step = 20

# Where the model checkpoint is stored
model_path = './drive/MyDrive/models/MDCSpell/'
os.makedirs(model_path, exist_ok=True)
model_path = model_path + 'MDCSpell-model.pt'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
Device: cuda

Model Construction

[Figure: MDCSpell model architecture (see the paper)]


The data flow of the Correction Network is as follows:

1. The token sequence [CLS] 遇 到 逆 竟 [SEP] is fed to the Word Embedding module, producing the vectors $\{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w, e_{SEP}^w\}$.

   In my view, the embeddings at this point are word embeddings only, without position embeddings or segment embeddings.

2. The vectors $\{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w, e_{SEP}^w\}$ are then fed into BERT, which adds the position embeddings and segment embeddings, giving $\{e_C, e_1, e_2, e_3, e_4, e_S\}$.

3. Inside BERT, these pass through multiple Transformer encoder layers, producing the output vectors $H^c = \{h_C^c, h_1^c, h_2^c, h_3^c, h_4^c, h_S^c\}$.

4. The BERT output $H^c$ is fused with the output $H^d$ of the neighboring Detection Network: $H = H^d + H^c$.

   The [CLS] and [SEP] positions are not fused.

5. $H$ is passed to a dense layer for the final prediction.


Correction Network model details

  1. BERT: the authors use BERT-base with 12 Transformer blocks.
  2. Dense Layer: its input dimension is the word-vector (hidden) size and its output dimension is the vocabulary size. For example, with a hidden size of 768 and a vocabulary of 20000, the dense layer is nn.Linear(768, 20000).
  3. Dense Layer initialization: the dense layer's weight is initialized from the Word Embedding parameters. The word embedding maps a token index to a word vector, so conceptually it is the transpose of the dense layer; in PyTorch the embedding table has shape (vocab_size, hidden_size), which is exactly the shape of nn.Linear(768, 20000).weight, so the embedding weight can be copied into the dense layer directly. The authors do this because it speeds up training and makes the model more stable (see the minimal sketch right after this list).
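
As referenced in item 3 above, here is a minimal sketch of that weight tying with hypothetical sizes (vocab_size=20000, hidden_size=768); it only demonstrates that the shapes line up, it is not the full model:

# Minimal weight-tying sketch with made-up sizes (vocab=20000, hidden=768)
vocab_size, hidden_size = 20000, 768
word_embedding = nn.Embedding(vocab_size, hidden_size)  # weight shape: (20000, 768)
dense_layer = nn.Linear(hidden_size, vocab_size)        # weight shape: (20000, 768) as well
# The shapes match, so the embedding table can initialize the prediction layer directly
dense_layer.weight.data = word_embedding.weight.data
print(dense_layer.weight.shape)  # torch.Size([20000, 768])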

The data flow of the Detection Network is as follows:

1. The input is the word embeddings obtained from BERT, $\{e_1^w, e_2^w, e_3^w, e_4^w\}$. Although the figure does not include the embeddings of [CLS] and [SEP], I don't think they need any special handling, since the final prediction does not use these two tokens.

2. Position embeddings are added to $\{e_1^w, e_2^w, e_3^w, e_4^w\}$, giving $\{e_1', e_2', e_3', e_4'\}$.

   The paper says the Detection Network uses the vectors $\{e_1, e_2, e_3, e_4\}$, i.e. word embedding + position embedding + segment embedding. This contradicts the figure; I follow the figure here.

3. The vectors $\{e_1', e_2', e_3', e_4'\}$ are fed into the Transformer block, producing the output vectors $H^d = \{h_1^d, h_2^d, h_3^d, h_4^d\}$.

4. The output $H^d$ is, on one hand, sent to the neighboring Correction Network for fusion; on the other hand, it is fed to a dense layer to decide which tokens are wrong.

Detection Network details:

  1. Transformer Block: a 2-layer Transformer encoder.
  2. Transformer Block initialization: its parameters are initialized with BERT's weights.
  3. Dense Layer: its input dimension is the word-vector (hidden) size and its output dimension is 1, with a Sigmoid giving the probability that the token is an error.
class CorrectionNetwork(nn.Module):

    def __init__(self):
        super(CorrectionNetwork, self).__init__()
        # BERT tokenizer. The paper does not say which Chinese BERT was used, so I pick a commonly used one.
        self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # BERT
        self.bert = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # BERT's word embedding, which is essentially an nn.Embedding
        self.word_embedding_table = self.bert.get_input_embeddings()
        # Prediction layer. hidden_size is the word-vector size, len(self.tokenizer) is the vocabulary size.
        self.dense_layer = nn.Linear(self.bert.config.hidden_size, len(self.tokenizer))

    def forward(self, inputs, word_embeddings, detect_hidden_states):
        """
        Forward pass of the Correction Network.
        :param inputs: output of the tokenizer on the Chinese text,
                       containing the token indices, attention_mask, etc.
        :param word_embeddings: embeddings of the tokens produced by BERT's word embedding
        :param detect_hidden_states: hidden states output by the Detection Network
        :return: the Correction Network's prediction for each token.
        """
        # 1. Forward pass through BERT
        bert_outputs = self.bert(token_type_ids=inputs['token_type_ids'],
                                 attention_mask=inputs['attention_mask'],
                                 inputs_embeds=word_embeddings)
        # 2. Fuse BERT's hidden states with the Detection Network's hidden states
        hidden_states = bert_outputs['last_hidden_state'] + detect_hidden_states
        # 3. Predict the tokens with the final dense layer
        return self.dense_layer(hidden_states)

    def get_inputs_and_word_embeddings(self, sequences, max_length=128):
        """
        Tokenize the Chinese sequences and compute their word embeddings.
        :param sequences: list of Chinese sentences, e.g. ["鸡你太美", "哎呦,你干嘛!"]
        :param max_length: maximum text length; shorter texts are padded, longer ones truncated.
        :return: the tokenizer outputs and the word embeddings.
        """
        inputs = self.tokenizer(sequences, padding='max_length', max_length=max_length, return_tensors='pt',
                                truncation=True).to(device)
        # Embed the tokens with BERT's word embeddings. These do NOT include position or segment embeddings.
        word_embeddings = self.word_embedding_table(inputs['input_ids'])
        return inputs, word_embeddings
class DetectionNetwork(nn.Module):

    def __init__(self, position_embeddings, transformer_blocks, hidden_size):
        """
        :param position_embeddings: BERT's position embeddings, essentially an nn.Embedding
        :param transformer_blocks: the first two Transformer blocks of BERT, a ModuleList
        :param hidden_size: BERT's word-vector (hidden) size
        """
        super(DetectionNetwork, self).__init__()
        self.position_embeddings = position_embeddings
        self.transformer_blocks = transformer_blocks

        # Final prediction layer: predicts which tokens are wrong
        self.dense_layer = nn.Sequential(
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, word_embeddings):
        # Length of the token sequence, 128 here
        sequence_length = word_embeddings.size(1)
        # Build the position embeddings
        position_embeddings = self.position_embeddings(torch.LongTensor(range(sequence_length)).to(device))
        # Fuse the word embeddings and position embeddings
        x = word_embeddings + position_embeddings
        # Pass x through the Transformer encoder layers one by one
        for transformer_layer in self.transformer_blocks:
            x = transformer_layer(x)[0]

        # Return the Detection Network's hidden states and its predictions
        hidden_states = x
        return hidden_states, self.dense_layer(hidden_states)
class MDCSpellModel(nn.Module):

    def __init__(self):
        super(MDCSpellModel, self).__init__()
        # Build the Correction Network
        self.correction_network = CorrectionNetwork()
        self._init_correction_dense_layer()

        # Build the Detection Network
        # Reuse BERT's position embeddings
        position_embeddings = self.correction_network.bert.embeddings.position_embeddings
        # As stated in the paper, the Detection Network's Transformer is initialized with BERT's weights,
        # so I simply clone the first two Transformer layers of BERT.
        transformer = copy.deepcopy(self.correction_network.bert.encoder.layer[:2])
        # BERT's word-vector (hidden) size
        hidden_size = self.correction_network.bert.config.hidden_size

        # Build the Detection Network
        self.detection_network = DetectionNetwork(position_embeddings, transformer, hidden_size)

    def forward(self, sequences, max_length=128):
        # Compute the word embeddings first; both the Correction Network and the Detection Network need them
        inputs, word_embeddings = self.correction_network.get_inputs_and_word_embeddings(sequences, max_length)
        # Forward pass of the Detection Network: get its hidden states and predictions
        hidden_states, detection_outputs = self.detection_network(word_embeddings)
        # Forward pass of the Correction Network: get its predictions
        correction_outputs = self.correction_network(inputs, word_embeddings, hidden_states)
        # Return the predictions of both networks.
        # The `[PAD]` tokens should not contribute to the loss, so zero out the `[PAD]` positions here.
        return correction_outputs, detection_outputs.squeeze(2) * inputs['attention_mask']

    def _init_correction_dense_layer(self):
        """
        As mentioned in the paper, initialize the Correction Network's dense layer
        with the Word Embedding weights.
        """
        self.correction_network.dense_layer.weight.data = self.correction_network.word_embedding_table.weight.data

With the model defined, let's quickly try it out:

model = MDCSpellModel().to(device)
correction_outputs, detection_outputs = model(["鸡你太美", "哎呦,你干嘛!"])
print("correction_outputs shape:", correction_outputs.size())
print("detection_outputs shape:", detection_outputs.size())
correction_outputs shape: torch.Size([2, 128, 21128])
detection_outputs shape: torch.Size([2, 128])

Loss Function

Both the Correction Network and the Detection Network use cross entropy as their loss. The two losses are then combined into a weighted sum:

$L = \lambda L^c + (1-\lambda) L^d$

where $\lambda \in [0, 1]$. The authors found experimentally that $\lambda = 0.85$ works best.

class MDCSpellLoss(nn.Module):

    def __init__(self, coefficient=0.85):
        super(MDCSpellLoss, self).__init__()
        # Loss function of the Correction Network
        self.correction_criterion = nn.CrossEntropyLoss(ignore_index=0)
        # Loss function of the Detection Network: binary cross entropy, since detection is a binary classification
        self.detection_criterion = nn.BCELoss()
        # Weighting coefficient
        self.coefficient = coefficient

    def forward(self, correction_outputs, correction_targets, detection_outputs, detection_targets):
        """
        :param correction_outputs: Correction Network output, shape (batch_size, sequence_length, vocab_size)
        :param correction_targets: Correction Network targets, shape (batch_size, sequence_length)
        :param detection_outputs: Detection Network output, shape (batch_size, sequence_length)
        :param detection_targets: Detection Network targets, shape (batch_size, sequence_length)
        :return: the combined loss
        """
        # Correction loss: the outputs are 3-dimensional, so batch_size and sequence_length must be merged first
        correction_loss = self.correction_criterion(correction_outputs.view(-1, correction_outputs.size(2)),
                                                    correction_targets.view(-1))
        # Detection loss
        detection_loss = self.detection_criterion(detection_outputs, detection_targets)
        # Weighted combination of the two losses
        return self.coefficient * correction_loss + (1 - self.coefficient) * detection_loss
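
As a quick sanity check, the loss can be exercised on dummy tensors with the shapes described in the docstring (the sizes below are made up for illustration):

# Quick shape check of MDCSpellLoss on dummy tensors (made-up sizes, illustration only)
batch_size_demo, seq_len_demo, vocab_size_demo = 2, 8, 21128
demo_criterion = MDCSpellLoss()
demo_correction_outputs = torch.randn(batch_size_demo, seq_len_demo, vocab_size_demo)
demo_correction_targets = torch.randint(1, vocab_size_demo, (batch_size_demo, seq_len_demo))
demo_detection_outputs = torch.rand(batch_size_demo, seq_len_demo)  # values in (0, 1), as after Sigmoid
demo_detection_targets = (torch.rand(batch_size_demo, seq_len_demo) > 0.9).float()
print(demo_criterion(demo_correction_outputs, demo_correction_targets,
                     demo_detection_outputs, demo_detection_targets))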

Model Training

The authors' training procedure:

  1. First, train on the Wang271K dataset (automatically constructed data) with batch size 32 and learning rate 2e-5.

  2. Then fine-tune on the SIGHAN training sets with batch size 32 and learning rate 1e-5.

The paper does not say which optimizer is used, but judging from these learning rates it is presumably Adam.

For the first step, the paper says nearly 3M samples were used, but Wang271K is the only dataset the authors mention. I suspect this is a slip: Wang271K contains about 0.3M samples, not 3M.

The authors first train the model on Wang271K and then fine-tune it on the SIGHAN training sets. I skip the fine-tuning stage here and just train directly. The data I use is the preprocessed dataset from the ReaLiSe paper, which is exactly Wang271K plus SIGHAN.

Baidu Netdisk link: https://pan.baidu.com/s/1x67LPiYAjLKhO1_2CI6aOA?pwd=skda

Just download and extract it.
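
Before building the Dataset, it can be worth a quick look at what the pickle contains (assuming, as the code below does, that it is a list of dicts with 'src' and 'tgt' fields):

# Quick peek at the downloaded data (assumes a list of dicts with 'src'/'tgt', as used below)
with open("data/trainall.times2.pkl", mode='br') as f:
    raw_train_data = pickle.load(f)
print(len(raw_train_data))       # number of training pairs
print(raw_train_data[0].keys())  # expected to include 'src' and 'tgt'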

Building the Dataset

class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)

        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        return len(self.train_data)
train_data = CSCDataset()
train_data.__getitem__(0)
('纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一美元以上。')

Building the DataLoader

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)

    src_tokens = tokenizer(src, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']
    tgt_tokens = tokenizer(tgt, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']

    correction_targets = tgt_tokens
    detection_targets = (src_tokens != tgt_tokens).float()
    return src, correction_targets, detection_targets, src_tokens  # src_tokens is needed later when computing the Correction precision
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
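
A quick look at one batch confirms the shapes produced by collate_fn (max_length is 128 here):

# Peek at one batch to confirm the shapes produced by collate_fn
demo_src, demo_correction_targets, demo_detection_targets, demo_src_tokens = next(iter(train_loader))
print(len(demo_src))                  # batch_size raw sentences
print(demo_correction_targets.shape)  # torch.Size([32, 128])
print(demo_detection_targets.shape)   # torch.Size([32, 128])
print(demo_src_tokens.shape)          # torch.Size([32, 128])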

Training

criterion = MDCSpellLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
start_epoch = 0  # which epoch to start from
total_step = 0  # total number of parameter updates so far
# Resume a previous training run if a checkpoint exists
if os.path.exists(model_path):
    if not torch.cuda.is_available():
        checkpoint = torch.load(model_path, map_location='cpu')
    else:
        checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    total_step = checkpoint['total_step']
    print("恢复训练,epoch:", start_epoch)
恢复训练,epoch: 2
model = model.to(device)
model = model.train()

The training code below looks long, but most of it is just bookkeeping for recall and precision. During training, the Detection recall and precision are computed from the Detection Network's own predictions.
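
To make that bookkeeping easier to follow, here is a toy example of the detection counters on hand-made tensors (the values are made up):

# Toy example of the detection precision/recall counters used in the loop below (made-up values)
toy_targets = torch.tensor([[0., 1., 1., 0.]])              # positions 1 and 2 are truly wrong
toy_predicts = torch.tensor([[False, True, False, True]])   # the model flags positions 1 and 3
recall_numerator = toy_predicts[toy_targets == 1].sum().item()   # 1: only position 1 was found
recall_denominator = (toy_targets == 1).sum().item()             # 2 truly wrong tokens
precision_denominator = toy_predicts.sum().item()                # 2 tokens flagged as wrong
precision_numerator = toy_targets[toy_predicts].sum().item()     # 1 flagged token is truly wrong
print(recall_numerator / recall_denominator, precision_numerator / precision_denominator)  # 0.5 0.5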

total_loss = 0.  # accumulated loss

d_recall_numerator = 0  # numerator of the Detection recall
d_recall_denominator = 0  # denominator of the Detection recall
d_precision_numerator = 0  # numerator of the Detection precision
d_precision_denominator = 0  # denominator of the Detection precision
c_recall_numerator = 0  # numerator of the Correction recall
c_recall_denominator = 0  # denominator of the Correction recall
c_precision_numerator = 0  # numerator of the Correction precision
c_precision_denominator = 0  # denominator of the Correction precision

for epoch in range(start_epoch, epochs):

    step = 0

    for sequences, correction_targets, detection_targets, correction_inputs in train_loader:
        correction_targets, detection_targets = correction_targets.to(device), detection_targets.to(device)
        correction_inputs = correction_inputs.to(device)
        correction_outputs, detection_outputs = model(sequences)
        loss = criterion(correction_outputs, correction_targets, detection_outputs, detection_targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1
        total_step += 1

        total_loss += loss.detach().item()

        # Detection recall and precision
        # A probability >= 0.5 means the token is predicted as wrong, otherwise as correct
        d_predicts = detection_outputs >= 0.5
        # Number of wrong tokens that the network correctly detected
        d_recall_numerator += d_predicts[detection_targets == 1].sum().item()
        # Total number of wrong tokens
        d_recall_denominator += (detection_targets == 1).sum().item()
        # Number of tokens the network predicted as wrong
        d_precision_denominator += d_predicts.sum().item()
        # Of the tokens predicted as wrong, how many are actually wrong
        d_precision_numerator += (detection_targets[d_predicts == 1]).sum().item()

        # Correction recall and precision
        # Map the outputs to token indices, turning correction_outputs from shape (32, 128, 21128) into (32, 128)
        correction_outputs = correction_outputs.argmax(2)
        # Padding, [CLS] and [SEP] tokens are not checked
        correction_outputs[(correction_targets == 0) | (correction_targets == 101) | (correction_targets == 102)] = 0
        # The [CLS] and [SEP] in correction_targets must also be set to 0
        correction_targets[(correction_targets == 101) | (correction_targets == 102)] = 0
        # Correction results: True means predicted correctly, False means predicted wrongly or nothing to predict
        c_predicts = correction_outputs == correction_targets
        # Number of wrong tokens that the network corrected successfully
        c_recall_numerator += c_predicts[detection_targets == 1].sum().item()
        # Total number of wrong tokens
        c_recall_denominator += (detection_targets == 1).sum().item()
        # Number of tokens the network changed
        correction_inputs[(correction_inputs == 101) | (correction_inputs == 102)] = 0
        c_precision_denominator += (correction_outputs != correction_inputs).sum().item()
        # Of the tokens the network changed, how many were corrected properly
        c_precision_numerator += c_predicts[correction_outputs != correction_inputs].sum().item()

        if total_step % log_after_step == 0:
            loss = total_loss / log_after_step
            d_recall = d_recall_numerator / (d_recall_denominator + 1e-9)
            d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)
            c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)
            c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)

            print("Epoch {}, "
                  "Step {}/{}, "
                  "Total Step {}, "
                  "loss {:.5f}, "
                  "detection recall {:.4f}, "
                  "detection precision {:.4f}, "
                  "correction recall {:.4f}, "
                  "correction precision {:.4f}".format(epoch, step, len(train_loader), total_step,
                                                       loss,
                                                       d_recall,
                                                       d_precision,
                                                       c_recall,
                                                       c_precision))

            total_loss = 0.
            total_correct = 0
            total_num = 0
            d_recall_numerator = 0
            d_recall_denominator = 0
            d_precision_numerator = 0
            d_precision_denominator = 0
            c_recall_numerator = 0
            c_recall_denominator = 0
            c_precision_numerator = 0
            c_precision_denominator = 0

    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch + 1,
        'total_step': total_step,
    }, model_path)
...
Epoch 2, Step 15/8882, Total Step 8900, loss 0.02403, detection recall 0.4118, detection precision 0.8247, correction recall 0.8192, correction precision 0.9485
Epoch 2, Step 35/8882, Total Step 8920, loss 0.03479, detection recall 0.3658, detection precision 0.8055, correction recall 0.8029, correction precision 0.9125
...

Model Evaluation

The model is evaluated on the SIGHAN 2013, 2014 and 2015 test sets. Unlike in the training stage, the Detection precision and recall here are computed from the Correction Network's output, because the Detection Network only exists to help train the Correction Network; its own predictions are not meant to be used at inference time.

model = model.eval()
def evaluation(test_data):
    d_recall_numerator = 0  # numerator of the Detection recall
    d_recall_denominator = 0  # denominator of the Detection recall
    d_precision_numerator = 0  # numerator of the Detection precision
    d_precision_denominator = 0  # denominator of the Detection precision
    c_recall_numerator = 0  # numerator of the Correction recall
    c_recall_denominator = 0  # denominator of the Correction recall
    c_precision_numerator = 0  # numerator of the Correction precision
    c_precision_denominator = 0  # denominator of the Correction precision

    progress = tqdm(range(len(test_data)))
    for i in progress:
        src, tgt = test_data[i]['src'], test_data[i]['tgt']

        src_tokens = tokenizer(src, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]
        tgt_tokens = tokenizer(tgt, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]

        # Normally src and tgt should have the same length
        if len(src_tokens) != len(tgt_tokens):
            print("Sample %d is malformed" % i)
            continue

        correction_outputs, _ = model(src)
        predict_tokens = correction_outputs[0][1:len(src_tokens) + 1].argmax(1).detach().cpu()

        # Total number of wrong tokens
        d_recall_denominator += (src_tokens != tgt_tokens).sum().item()
        # Of these wrong tokens, how many the network also considers wrong
        d_recall_numerator += (predict_tokens != src_tokens)[src_tokens != tgt_tokens].sum().item()
        # Number of tokens the network flags as wrong
        d_precision_denominator += (predict_tokens != src_tokens).sum().item()
        # Of the tokens the network flags as wrong, how many are actually wrong
        d_precision_numerator += (src_tokens != tgt_tokens)[predict_tokens != src_tokens].sum().item()
        # Detection recall, precision and F1
        d_recall = d_recall_numerator / (d_recall_denominator + 1e-9)
        d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)
        d_f1_score = 2 * (d_recall * d_precision) / (d_recall + d_precision + 1e-9)

        # Total number of wrong tokens
        c_recall_denominator += (src_tokens != tgt_tokens).sum().item()
        # Of these wrong tokens, how many the network corrected properly
        c_recall_numerator += (predict_tokens == tgt_tokens)[src_tokens != tgt_tokens].sum().item()
        # Number of tokens the network flags as wrong
        c_precision_denominator += (predict_tokens != src_tokens).sum().item()
        # Of the tokens the network flags as wrong, how many were corrected properly
        c_precision_numerator += (predict_tokens == tgt_tokens)[predict_tokens != src_tokens].sum().item()

        # Correction recall, precision and F1
        c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)
        c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)
        c_f1_score = 2 * (c_recall * c_precision) / (c_recall + c_precision + 1e-9)

        progress.set_postfix({
            'd_recall': d_recall,
            'd_precision': d_precision,
            'd_f1_score': d_f1_score,
            'c_recall': c_recall,
            'c_precision': c_precision,
            'c_f1_score': c_f1_score,
        })
with open("data/test.sighan13.pkl", mode='br') as f:
    sighan13 = pickle.load(f)
evaluation(sighan13)
100%|██████████| 1000/1000 [00:11<00:00, 90.12it/s, d_recall=0.799, d_precision=0.802, d_f1_score=0.8, c_recall=0.769, c_precision=0.772, c_f1_score=0.771]  
with open("data/test.sighan14.pkl", mode='br') as f:
    sighan14 = pickle.load(f)
evaluation(sighan14)
100%|██████████| 1062/1062 [00:12<00:00, 85.48it/s, d_recall=0.666, d_precision=0.828, d_f1_score=0.738, c_recall=0.643, c_precision=0.799, c_f1_score=0.712]
with open("data/test.sighan15.pkl", mode='br') as f:
    sighan15 = pickle.load(f)
evaluation(sighan15)
100%|██████████| 1100/1100 [00:11<00:00, 92.04it/s, d_recall=0.761, d_precision=0.867, d_f1_score=0.811, c_recall=0.725, c_precision=0.827, c_f1_score=0.773] 

Using the Model

Finally, let's actually use the model and see how it does:

def predict(text):
    sequences = [text]
    correction_outputs, _ = model(sequences)
    tokens = correction_outputs[0][1:len(text) + 1].argmax(1)
    return ''.join(tokenizer.convert_ids_to_tokens(tokens))
predict("今天早上我吃了以个火聋果")
'今天早上我吃了一个火聋果'
predict("我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳RAP蓝球")
'我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳ra##p蓝球[SEP]'

Although the model's numbers look decent, it is still not good enough in real usage scenarios. Chinese spelling correction really is a hard task T_T!





References

MDCSpell paper: https://aclanthology.org/2022.findings-acl.98/

MDCSpell paper notes: https://blog.csdn.net/zhaohongfei_358/article/details/126973451

Reposted from blog.csdn.net/zhaohongfei_358/article/details/127035600