ZEN: an N-gram Enhanced Chinese Text Encoder (from paper to source code)

Background

Deep text encoders are pre-trained on large-scale unlabeled data, but whether these encoders make full use of all the information in the corpus is unproven. Pre-training models such as BERT take the smallest unit of text as their tokens; for Chinese that smallest unit is the character, not the word, yet Chinese semantics depend heavily on N-grams (multi-character combinations).
Shortcomings of the current model

1. With word masking, the encoder can only learn information about words and sentences that already exist in the corpus.
2. Mask-based methods have a mismatch between pre-training and fine-tuning: the masks exist during pre-training but not during fine-tuning.
3. Incorrect word segmentation or entity recognition hurts the general ability of the encoder.

Therefore, the paper proposes ZEN, an N-gram enhanced Chinese encoder.
ZEN has the following characteristics:

1. It introduces N-gram encoding, which helps the model recognize likely character combinations.
2. Although N-grams are introduced, the encoder still produces character-by-character output just like BERT, so downstream tasks are not affected.

ZEN is pre-trained on Chinese Wikipedia and then fine-tuned on various downstream Chinese tasks.

Let's look at ZEN in detail.

ZEN

N-Gram

N-gram extraction

N-gram extraction has two steps. The first step builds an N-gram lexicon from N-gram frequencies (a small sketch of this step is given below); note that the extracted N-grams may contain one another, for example the corpus may contain both 粤港澳 and 港澳. The second step generates the N-gram matrix of the training data according to the lexicon, as shown in the figure below.
[Figure: N-gram matrix extracted from an example sentence]
The N-gram matrix is a k_c × k_n matrix, where k_c is the number of characters in the sentence and k_n is the number of N-grams extracted from it; m_ij = 1 if the i-th character belongs to the j-th N-gram, and 0 otherwise.
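
The first step (building the lexicon from N-gram frequencies) is not part of the snippet below; a minimal sketch of how such a frequency-based lexicon could be built is shown here. The function name and threshold are illustrative assumptions, not the repository's actual code.

from collections import Counter

def build_ngram_lexicon(sentences, min_n=2, max_n=7, min_freq=10):
    """Count all character N-grams in the corpus and keep the frequent ones."""
    counter = Counter()
    for sent in sentences:
        chars = list(sent)
        for n in range(min_n, max_n + 1):
            for start in range(len(chars) - n + 1):
                counter[tuple(chars[start:start + n])] += 1
    # N-grams below the frequency threshold are dropped; nested/overlapping N-grams are all kept
    frequent = [gram for gram, count in counter.items() if count >= min_freq]
    return {gram: idx for idx, gram in enumerate(frequent)}

# build_ngram_lexicon(corpus_sentences) plays the role of ngram_dict.ngram_to_id_dict used below
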
Generating the N-gram matrix is straightforward. The code is in examples.utils_sequence_level_tasks, in the function convert_examples_to_features. This function mainly tokenizes the batch, converts it into input ids, processes the labels, and at the same time encodes the N-grams. We will skip the other steps and look only at the main logic of the N-gram matrix part.

# ----------- code for ngram BEGIN-----------
import math
from random import shuffle

ngram_matches = []
#  Scan every character span of length 2 to 7 and check whether it is in the N-gram lexicon
for p in range(2, 8):
    for q in range(0, len(tokens) - p + 1):
        character_segment = tokens[q:q + p]
        # q is the starting position of the span
        # p is the length of the current span
        character_segment = tuple(character_segment)
        if character_segment in ngram_dict.ngram_to_id_dict:
            ngram_index = ngram_dict.ngram_to_id_dict[character_segment]
            ngram_matches.append([ngram_index, q, p, character_segment])

shuffle(ngram_matches)
# max_word_in_seq_proportion = max_word_in_seq
max_word_in_seq_proportion = math.ceil((len(tokens) / max_seq_length) * ngram_dict.max_ngram_in_seq)
if len(ngram_matches) > max_word_in_seq_proportion:
    ngram_matches = ngram_matches[:max_word_in_seq_proportion]
ngram_ids = [ngram[0] for ngram in ngram_matches]
ngram_positions = [ngram[1] for ngram in ngram_matches]
ngram_lengths = [ngram[2] for ngram in ngram_matches]
ngram_tuples = [ngram[3] for ngram in ngram_matches]
ngram_seg_ids = [0 if position < (len(tokens_a) + 2) else 1 for position in ngram_positions]

import numpy as np
ngram_mask_array = np.zeros(ngram_dict.max_ngram_in_seq, dtype=bool)
ngram_mask_array[:len(ngram_ids)] = 1

# build the N-gram position matrix: mark which token positions each N-gram covers
ngram_positions_matrix = np.zeros(shape=(max_seq_length, ngram_dict.max_ngram_in_seq), dtype=np.int32)
for i in range(len(ngram_ids)):
    ngram_positions_matrix[ngram_positions[i]:ngram_positions[i] + ngram_lengths[i], i] = 1.0

# Zero-pad up to the max word in seq length.
padding = [0] * (ngram_dict.max_ngram_in_seq - len(ngram_ids))
ngram_ids += padding
ngram_lengths += padding
ngram_seg_ids += padding
# ----------- code for ngram END-----------

Note that ngram_dict is generated in advance. For each sentence we first traverse every character span, collect every span that appears in the lexicon, and record its length and starting position. ngram_positions_matrix is the N-gram matrix we need: it is a max_seq_length × max_ngram_in_seq matrix, where max_seq_length is the length of the input sequence and max_ngram_in_seq is the maximum number of N-grams allowed in one sentence (128 by default); the loop above then fills it in. Note that if a character is masked, its N-grams are no longer taken into account.
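
To make the layout concrete, here is a tiny hand-written example (not from the ZEN repository) that builds the same kind of position matrix for a 6-character sentence in which two N-grams were matched:

import numpy as np

max_seq_length = 6
max_ngram_in_seq = 4       # small cap just for illustration; the real default is 128
ngram_positions = [1, 2]   # N-gram 0 starts at character 1, N-gram 1 starts at character 2
ngram_lengths = [3, 2]     # N-gram 0 covers 3 characters, N-gram 1 covers 2 characters

ngram_positions_matrix = np.zeros((max_seq_length, max_ngram_in_seq), dtype=np.int32)
for i in range(len(ngram_positions)):
    ngram_positions_matrix[ngram_positions[i]:ngram_positions[i] + ngram_lengths[i], i] = 1

print(ngram_positions_matrix)
# [[0 0 0 0]
#  [1 0 0 0]
#  [1 1 0 0]
#  [1 1 0 0]
#  [0 0 0 0]
#  [0 0 0 0]]

Each column is one N-gram and each row is one character position, so reading down a column shows exactly which characters that N-gram covers.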

N-gram encoding

The structure of the N-gram encoder is shown below. The paper uses a multi-layer transformer to encode the N-grams; since there is no order among the N-grams, no position encoding is needed. The N-gram encoder has a large impact on model performance. Why? Because it can learn important phrases within a sentence, which improves the model. The N-gram embedding fed into it can be understood as a word embedding.
[Figure: structure of the N-gram encoder]
In the code, the N-gram embedding is generated in much the same way as the word embedding. Below are the classes ZEN uses to produce the word embedding and the N-gram embedding, followed by a small usage sketch.

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
        
class BertWordEmbeddings(nn.Module):
    """Construct the embeddings from ngram, position and token_type embeddings.
    """

    def __init__(self, config):
        super(BertWordEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.word_size, config.hidden_size, padding_idx=0)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
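
To see that the two embeddings line up for the later addition, here is a minimal usage sketch. The config values are illustrative assumptions (roughly BERT-Base sizes); it reuses the two classes above and still needs BertLayerNorm from the ZEN repository.

import torch
from types import SimpleNamespace

# Illustrative config; the real values come from the ZEN/BERT config file
config = SimpleNamespace(
    vocab_size=21128, word_size=60000, hidden_size=768,
    max_position_embeddings=512, type_vocab_size=2,
    layer_norm_eps=1e-12, hidden_dropout_prob=0.1,
)

char_embedding = BertEmbeddings(config)        # character embeddings (with position embeddings)
ngram_embedding = BertWordEmbeddings(config)   # N-gram embeddings (no position embeddings)

input_ids = torch.randint(1, config.vocab_size, (2, 128))  # (batch, max_seq_length) character ids
ngram_ids = torch.randint(1, config.word_size, (2, 128))   # (batch, max_ngram_in_seq) N-gram ids

print(char_embedding(input_ids).shape)    # torch.Size([2, 128, 768])
print(ngram_embedding(ngram_ids).shape)   # torch.Size([2, 128, 768])

Both outputs share the same hidden size, which is what allows the N-gram representations to be added directly onto the character representations later on.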

N-gram for pretraining

The overall model structure is shown below.
[Figure: ZEN model architecture]
The ZEN model encodes both the characters and their associated N-grams. How are the two combined? Simply by matrix addition.
[Figure: combining character representations with N-gram representations]

  • v_li is the hidden output of the character encoder for the i-th character at layer l

  • u_lik is the representation of the k-th N-gram associated with the i-th character at layer l. Note that one character can be contained in several N-grams; for example, in 粤港澳大湾区 the characters 港 and 澳 are covered by both the N-gram 粤港澳大湾区 and the N-gram 港澳.
    For the l-th encoder layer, this enhancement can then be written as

    v*_li = v_li + Σ_k u_lik,   or in matrix form   V*_l = V_l + M · U_l

    (a shape-level sketch of this addition is given right after this list)

  • V_l is the embedding matrix (the character hidden states) of this layer

  • U_l is the matrix of N-gram representations associated with the characters

  • M is the N-gram matrix.
    Note that if a character is masked, the N-grams covering it are not added in.
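
As a concrete shape-level illustration of the addition above (a toy sketch, not the repository's code): with M of shape (batch, k_c, k_n), character states V of shape (batch, k_c, hidden) and N-gram states U of shape (batch, k_n, hidden), the enhancement is a single batched matrix multiplication.

import torch

batch, k_c, k_n, hidden = 2, 10, 4, 8   # toy sizes for illustration
V = torch.randn(batch, k_c, hidden)     # character hidden states from the character encoder
U = torch.randn(batch, k_n, hidden)     # N-gram hidden states from the N-gram encoder
M = torch.zeros(batch, k_c, k_n)        # N-gram position matrix, m_ij = 1 if character i is in N-gram j
M[:, 1:4, 0] = 1.0                      # e.g. N-gram 0 covers characters 1 to 3

# each character receives the sum of the representations of the N-grams that contain it
V_star = V + torch.bmm(M, U)
print(V_star.shape)                     # torch.Size([2, 10, 8]), same shape as V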

The code of the ZenEncoder is shown below; hidden_states has the attention-processed N-gram states added onto it.

class ZenEncoder(nn.Module):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(ZenEncoder, self).__init__()
        self.output_attentions = output_attentions
        layer = BertLayer(config, output_attentions=output_attentions,
                          keep_multihead_output=keep_multihead_output)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
        self.word_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_word_layers)])
        self.num_hidden_word_layers = config.num_hidden_word_layers

    def forward(self, hidden_states, ngram_hidden_states, ngram_position_matrix, attention_mask,
                ngram_attention_mask,
                output_all_encoded_layers=True, head_mask=None):
        # Need to check what is the attention masking doing here
        all_encoder_layers = []
        all_attentions = []
        num_hidden_ngram_layers = self.num_hidden_word_layers
        for i, layer_module in enumerate(self.layer):
            hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])
            if i < num_hidden_ngram_layers:
                ngram_hidden_states = self.word_layers[i](ngram_hidden_states, ngram_attention_mask, head_mask[i])
                if self.output_attentions:
                    ngram_attentions, ngram_hidden_states = ngram_hidden_states
            if self.output_attentions:
                attentions, hidden_states = hidden_states
                all_attentions.append(attentions)
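            # add to each character the hidden states of the N-grams that cover it (M · U)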
            hidden_states += torch.bmm(ngram_position_matrix.float(), ngram_hidden_states.float())
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        if self.output_attentions:
            return all_attentions, all_encoder_layers
        return all_encoder_layers

Experimental Results

Experimental setup

The paper uses Chinese Wikipedia as the pre-training corpus. For data cleaning, punctuation is removed, traditional characters are converted to simplified ones, and English letters are lowercased.

The N-gram lexicon is built from the training corpus: N-grams are sorted by frequency and a threshold is set, and N-grams whose frequency falls below the threshold are removed. Depending on the threshold, the final lexicon contains between 64,000 and 179,000 N-grams. The N-gram embeddings are randomly initialized. The model structure is the same as BERT: 12 layers with 12-head multi-head attention and a hidden size of 768. Pre-training also follows BERT, using the MLM and NSP tasks.

Results

The experimental results are shown below, where R means the model parameters are randomly initialized, P means the parameters are initialized from Google's BERT model, B means BERT Base is used, and L means BERT Large. ZEN achieves comparatively strong results across multiple tasks.
[Table: main results on downstream Chinese NLP tasks]

Further analysis

The paper also provides several analyses.

Pre-training on a small corpus

Current pre-training models are mostly evaluated on large datasets, but in some domains it is hard to collect a large corpus. The paper therefore extracts 1/10 of the Wikipedia corpus for pre-training, with randomly initialized parameters. On this small dataset ZEN performs slightly better than BERT, presumably because of the N-gram embedding enhancement, which suggests that ZEN handles small-corpus scenarios better than BERT.
[Table: results when pre-training on 1/10 of the corpus]

Convergence speed

The following figure shows ZEN's performance at different training epochs on the CWS (Chinese word segmentation) and SA (sentiment analysis) tasks. At the same number of epochs ZEN performs better than BERT, and ZEN also converges faster than BERT.
[Figure: performance at different training epochs on CWS and SA]

N-gram Threshold

The paper analyzes the frequency threshold used for N-gram extraction and finds that thresholds between 10 and 20 work best. It also analyzes the maximum number of N-grams used and finds that the model keeps improving somewhat as the number of N-grams increases.
[Figure: effect of the N-gram frequency threshold and of the number of N-grams]

Heatmap analysis

The paper also shows heatmaps of the N-gram encoder: for two sample sentences, the weight of each N-gram across layers 1 to 7, shown below. "Meaningful" N-grams receive higher weights than "meaningless" ones; for example, N-grams such as 提高 (improve) and 波士顿 (Boston) are weighted more heavily than meaningless fragments. This shows that ZEN focuses on semantically meaningful N-grams and picks out appropriate phrases. It is also found that longer phrases receive relatively large weights at the higher layers, which indicates that these long phrases have a more important impact on the model's understanding of the sentence.
[Figure: heatmaps of N-gram weights across layers for two sample sentences]

Related resources

  1. The ZEN paper
  2. The ZEN implementation (source code)