1.BertEmbeddings 类
Bert的输入由三部分组成,分别是words_embedding,sposition_embeddings , token_embeddings
Bert中的embedding模块做的工作就是:生成三个embedding数组并相加,而后进行LayerNorm和dropout操作后返回
BertEmbeddings这个类继承了nn.Module这个类,一共有五个成员变量,分别是words_embedding,position_embeddings , token_embeddings,LayerNorm和dropout
输入的每个词在词表中都有一个位置input_ids,然后在words_embedding词表中找到每个词是什么,然后找到position_embeddings , token_embeddings,分别生成三个embedding数组,然后,将三个数组对应位置的元素相加。LayerNorm进行了这个层的归一化,dropout加入了过拟合等操作。
最后embeddings = words_embeddings + position_embeddings + token_type_embeddings,进行LayerNorm和dropout操作后返回
源码如下:
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
2.BertEncoder 类
Bert的编码部分是由一个一个的Encoder堆叠而成,每一个encoder都叫做一个bertlayer,可以说BertEncoder类是由BertLayer类组成的。 而BertLayer类由三部分组成:BertAttention类,BertIntermediate类,BertOut类
在BertEncoder 类中,forward函数先通过BertAttention得到注意力输出,之后通过处理输出,将一部分通过intermediate和output,再与另一部分求和得到最终的输出。
__init__ 方法:将要堆叠的 layer使用一个for循环放到一个列表ModuleList中。
forward :这个方法主要具体完成前向计算的一个堆叠工作,第一层的返回结果将会作为下一层的参数输入进去,最后把结果放入一个all_encoder_layers的列表中进行返回。
BertAttention类:它是用于计算注意力得分的类
BertIntermediate类:作用是选择一系列的激活函数,进行非线性变换
BertOut类:使用了linear,layernorm和dropout。在最后计算时,做了linear的映射,输出的值进行dropout,最后接着是Add & Norm层,Add残差的方式对不同的输出相加,防止梯度消失。
源码如下:
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
3.BertPooler 类
BertPooler 类是 Bert 的输出模块 , 利用了一个 l i near 线形层加一个 Tanh() 的激 活函数, 用来池化 BertEncoder 的输出。
重点解释这句代码: first_token_tensor = hidden_states[:, 0]
上面Bertlayer层最终输出的是 all_encoder_layers,而BertPooler模块的全连接层只简单选择了最上面一层的输出结果
__init__ 方法:确定全连接层网络的隐含层节点个数,并且定义了网络的激活函数(Tanh() 函数)
forward :对输入进行了一个全连接,然后进行非线性激活
源码如下:
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
4. BertModel 类
BertModel类调用了BertEmbeddings、BertEncoder、BertPooler三个类,分别处理词嵌入、网络定义和池化。
forward:首先将input_ids等输入进来的数据进行embedding表示,再将embedding表示输入encoder进行处理,然后encoder网络得到输出encoder_outputs,最后将encoder_outputs通过self.pooler池化,返回encoded_layers和pooled_output。
源码如下:
class BertModel(PreTrainedBertModel):
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output