【论文阅读】Spelling Error Correction with Soft-Masked BERT

论文内容

发表时间:2020年05月

论文地址:https://arxiv.org/abs/2005.07421

代码地址(非作者实现): https://github.com/quantum00549/SoftMaskedBert

摘要(Abstract)

使用Soft-Masked BERT完成中文拼写纠错(Chinses Spell Checking, CSC)任务,并且该方法也适用于其他语言。

1. 介绍(Introduction)

Soft-Masked BERT = 双向GRU(Bi-GRU) + BERT

其中Bi-GRU负责预测哪个地方有错误,BERT负责对错误进行修正。

2. 方法(Our Approach)

2.1 问题和思路(Problem and Motivation)

作者说原始的BERT是mask了 “15%” 的词训练的,这不足以让BERT学会如何找出句子中的错误,所以要使用新的方法。

2.2 模型(Model)

在这里插入图片描述

该模型分为三部分:

  1. Detection Network:负责预测句子中每个字错误的概率
  2. Correct Network:负责将错字纠正成正确的字。
  3. Soft Masking:Detection Network和Correction Network之间的桥梁,负责根据Detection Network的输出对原始句子embedding进行mask。

2.3 Detection Network

在这里插入图片描述

输入: embedding后的characters序列。embedding方式和BERT一样,包括word embedding,position embedding和segment embedding.

网络架构:Bi-GRU -> 全连接层(Linear) -> Sigmoid

输出:每个character为错字的概率,越接近1表示越有可能是错的。

2.3 Soft Masking

在这里插入图片描述

Soft Masking模块就是对Input进行mask,方式就是加权,公式为:

e i ′ = p i ⋅ e m a s k + ( 1 − p i ) ⋅ e i e_i^{\prime}=p_i \cdot e_{m a s k}+\left(1-p_i\right) \cdot e_i ei=piemask+(1pi)ei

  • e i ′ e'_i ei :第 i i i个character进行mask后的结果。

  • p i p_i pi :第 i i i个character为错字的概率, p i ∈ [ 0 , 1 ] p_i \in [0,1] pi[0,1]

  • e m a s k e_{mask} emask :mask embeding。具体是什么原文中并没有说明。github上quantum00549的论文复现使用的是在这里插入图片描述

  • e i e_i ei :第 i i i个character的词向量。

2.4 Correction Network

在这里插入图片描述

输入:soft-masking后的input。

网络架构:BERT+全连接层(Linear)+Softmax

输出:将词修正后的结果。

注意:在BERT和Linear之间,有一个残差连接,即将input和bert的output进行相加。用公式表示则为:

h i ′ = h i c + e i h_i^{\prime}=h_i^c+e_i hi=hic+ei

  • h i ′ h'_i hi :Linear的输入的第 i i i个character。
  • h i c h^c_i hic:Bert的输出的第 i i i个character的隐状态。
  • e i e_i ei:第 i i i个character的词向量。

2.5 损失函数(Learning)

Detection Network和Correction Network损失函数使用的都是CrossEntropy,用公式表示为:

L d = − ∑ i = 1 n log ⁡ P d ( g i ∣ X ) L c = − ∑ i = 1 n log ⁡ P c ( y i ∣ X ) \begin{aligned} \mathcal{L}_d &=-\sum_{i=1}^n \log P_d\left(g_i \mid X\right) \\ \mathcal{L}_c &=-\sum_{i=1}^n \log P_c\left(y_i \mid X\right) \end{aligned} LdLc=i=1nlogPd(giX)=i=1nlogPc(yiX)

  • L d \mathcal{L}_d Ld:Detection Network的损失
  • L c \mathcal{L}_c Lc:Correction Network的损失

联合起来为:

L = λ ⋅ L c + ( 1 − λ ) ⋅ L d \mathcal{L}=\lambda \cdot \mathcal{L}_c+(1-\lambda) \cdot \mathcal{L}_d L=λLc+(1λ)Ld

其中 λ \lambda λ [ 0 , 1 ] [0,1] [0,1] 的超参数。

3. 实验结果(Experimental Result)

3.1 数据集(Datasets)

benchmark: SIGHAN

训练集:自己造的,使用confusion table的方式。具体为将一个句子中15%的字替换成与其相同发音的其他常见字。在所有样本中,有80%的句子按上述方式处理,剩下20%则是直接随机替换成任意文字。

3.2 Baselines(略)

3.3 实验设置(Experiment Setting)

优化器(optimizer):Adam

学习策略(Learning Scheduler): 无

学习率(Learning Rate):2e-5

The size of hidden unit in Bi-GRU is 256

batch size: 320

作者还使用了500w个训练样本和SIGHAN中训练样本对BERT进行了fine-tune.

3.4 实验结果(Main Result)

在这里插入图片描述

  • Acc:Accuracy,准确率
  • Pre:Precision,精准率
  • Rec:Recall,召回率
  • F1:F1 Score

猜你喜欢

转载自blog.csdn.net/zhaohongfei_358/article/details/126675655