Bert模型实现

1.BertEmbeddings 类

Bert的输入由三部分组成,分别是words_embedding,sposition_embeddings , token_embeddings

Bert中的embedding模块做的工作就是:生成三个embedding数组并相加,而后进行LayerNorm和dropout操作后返回

BertEmbeddings这个类继承了nn.Module这个类,一共有五个成员变量,分别是words_embedding,position_embeddings , token_embeddings,LayerNorm和dropout

输入的每个词在词表中都有一个位置input_ids,然后在words_embedding词表中找到每个词是什么,然后找到position_embeddings , token_embeddings,分别生成三个embedding数组,然后,将三个数组对应位置的元素相加。LayerNorm进行了这个层的归一化,dropout加入了过拟合等操作。

最后embeddings = words_embeddings + position_embeddings + token_type_embeddings,进行LayerNorm和dropout操作后返回

源码如下:

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

2.BertEncoder 类

Bert的编码部分是由一个一个的Encoder堆叠而成,每一个encoder都叫做一个bertlayer,可以说BertEncoder类是由BertLayer类组成的。 而BertLayer类由三部分组成:BertAttention类,BertIntermediate类,BertOut类

在BertEncoder 类中,forward函数先通过BertAttention得到注意力输出,之后通过处理输出,将一部分通过intermediate和output,再与另一部分求和得到最终的输出。

__init__ 方法:将要堆叠的 layer使用一个for循环放到一个列表ModuleList中。

forward 这个方法主要具体完成前向计算的一个堆叠工作,第一层的返回结果将会作为下一层的参数输入进去,最后把结果放入一个all_encoder_layers的列表中进行返回。

BertAttention类:它是用于计算注意力得分的类

BertIntermediate类:作用是选择一系列的激活函数,进行非线性变换

BertOut类:使用了linear,layernorm和dropout。在最后计算时,做了linear的映射,输出的值进行dropout,最后接着是Add & Norm层,Add残差的方式对不同的输出相加,防止梯度消失。

源码如下:

class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

3.BertPooler 类

BertPooler 类是 Bert 的输出模块 , 利用了一个 l i near 线形层加一个 Tanh() 的激 活函数, 用来池化 BertEncoder 的输出。

重点解释这句代码:  first_token_tensor = hidden_states[:, 0]

上面Bertlayer层最终输出的是 all_encoder_layers,而BertPooler模块的全连接层只简单选择了最上面一层的输出结果

__init__ 方法:确定全连接层网络的隐含层节点个数,并且定义了网络的激活函数(Tanh() 函数)

forward 对输入进行了一个全连接,然后进行非线性激活

源码如下: 

class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

4. BertModel 类

BertModel类调用了BertEmbeddings、BertEncoder、BertPooler三个类,分别处理词嵌入、网络定义和池化。

forward:首先将input_ids等输入进来的数据进行embedding表示,再将embedding表示输入encoder进行处理,然后encoder网络得到输出encoder_outputs,最后将encoder_outputs通过self.pooler池化,返回encoded_layers和pooled_output。

源码如下:

class BertModel(PreTrainedBertModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output

猜你喜欢

转载自blog.csdn.net/Minor0218/article/details/126051749