Highly Reusable BERT Text Classification Code (Part 2): The Model


Highly Reusable BERT Text Classification Code

Previous post: Highly Reusable BERT Text Classification Code (Part 1): Data Loading

Source Code Walkthrough

In the source code, the models live in their own model folder. Let's first look at module.py, which contains a simple fully connected network that serves as the classifier.

The classifier's network structure is very simple, consisting of just two layers:

  • a Dropout layer
  • a Linear layer
# module.py
import torch.nn as nn

# Classifier: dropout followed by a linear projection to the label space
class IntentClassifier(nn.Module):
    def __init__(self, input_dim, num_labels, dropout_rate=0.):
        super(IntentClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)
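The classifier can be tried on its own. Below is a minimal sketch (not from the original repo); the sizes 768 and 10 are placeholders standing in for bert-base's hidden size and an example label count.

import torch

# hypothetical sizes: 768 = bert-base hidden size, 10 = example number of labels
clf = IntentClassifier(input_dim=768, num_labels=10, dropout_rate=0.1)
pooled = torch.randn(4, 768)   # a fake batch of 4 pooled [CLS] vectors
logits = clf(pooled)           # -> shape (4, 10)
print(logits.shape)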

Next, let's focus on the BERT model code.

import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertConfig
from torchcrf import CRF  # not used in this classification model
from .module import IntentClassifier


class ClsBERT(BertPreTrainedModel):
    def __init__(self, config, args, label_lst):
        super(ClsBERT, self).__init__(config)
        self.args = args
        self.num_labels = len(label_lst)
        self.bert = BertModel(config=config)  # Load pretrained bert

        self.classifier = IntentClassifier(config.hidden_size, self.num_labels, args.dropout_rate)


    def forward(self, input_ids, attention_mask, token_type_ids, label_ids):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        logits = self.classifier(pooled_output)

        outputs = ((logits),) + outputs[2:]  # add hidden states and attention if they are here

        # 1. Intent Softmax
        if label_ids is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), label_ids.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), label_ids.view(-1))

            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
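For reference, here is a minimal sketch of how ClsBERT might be instantiated and called during training. Everything in it (the checkpoint name, the label list, and the args namespace, which only needs a dropout_rate field here) is an illustrative assumption rather than part of the original project.

import torch
from argparse import Namespace
from transformers import BertConfig, BertTokenizer

# hypothetical setup
label_lst = ["negative", "neutral", "positive"]
args = Namespace(dropout_rate=0.1)

config = BertConfig.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = ClsBERT.from_pretrained("bert-base-uncased", config=config,
                                args=args, label_lst=label_lst)

batch = tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="pt")
label_ids = torch.tensor([2, 0])

outputs = model(batch["input_ids"], batch["attention_mask"],
                batch["token_type_ids"], label_ids)
loss, logits = outputs[0], outputs[1]  # loss is present because label_ids was passed
loss.backward()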

The key part is the forward method:

        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0] # sequence_output = outputs.last_hidden_state
        pooled_output = outputs[1]   # [CLS]  /  pooled_output = outputs.pooler_output 

When fine-tuning BERT with transformers, bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) returns two main values. One is sequence_output, the last-layer hidden state of every token, with shape (batch_size, sequence_length, hidden_size). The other is pooled_output, which is the last-layer hidden state of the first token ([CLS]) passed through an extra linear layer and a Tanh activation, with shape (batch_size, hidden_size).

  • sequence_output can be obtained via outputs[0] or outputs.last_hidden_state.
  • pooled_output can be obtained via outputs[1] or outputs.pooler_output.

For classification tasks, a common approach is to pool BERT's last-layer output and feed it into a linear layer. In this code you can either pass outputs.pooler_output directly to the linear layer, or use outputs.last_hidden_state.mean(dim=1) (average pooling over the sequence dimension); in my tests the latter worked slightly better. Both options are sketched below.
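As a rough sketch of those two pooling options (the helper name masked_mean is made up for illustration, not part of the original code): note that a plain .mean(dim=1) also averages over padding tokens, which the attention mask can avoid.

import torch

def masked_mean(last_hidden_state, attention_mask):
    """Average token vectors over the sequence, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).float()        # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)     # (batch, hidden_size)
    counts = mask.sum(dim=1).clamp(min=1e-9)           # number of real tokens per example
    return summed / counts

# Option 1: the [CLS]-based pooled output
# pooled_output = outputs.pooler_output
# Option 2: (masked) mean pooling over the last hidden layer
# pooled_output = masked_mean(outputs.last_hidden_state, attention_mask)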

Improving the BERT Output

We know that the BERT (base) model consists of 12 transformer layers. What if we want to extract the output vector of a particular layer, or concatenate the vectors of several layers?

If we look at the official documentation of BertModel(BertPreTrainedModel), the return value outputs is described as follows:

Outputs: Tuple comprising various elements depending on the configuration (config) and inputs:

last_hidden_state: torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)
Sequence of hidden-states at the output of the last layer of the model.

pooler_output: torch.FloatTensor of shape (batch_size, hidden_size)
Last layer hidden-state of the first token of the sequence (classification token), further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during Bert pretraining. This output is usually not a good summary of the semantic content of the input, you're often better with averaging or pooling the sequence of hidden-states for the whole input sequence.

hidden_states: (optional, returned when config.output_hidden_states=True), list of torch.FloatTensor (one for the output of each layer + the output of the embeddings) of shape (batch_size, sequence_length, hidden_size):
Hidden-states of the model at the output of each layer plus the initial embedding outputs.

attentions: (optional, returned when config.output_attentions=True), list of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length): Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

From the docs we can see that hidden_states returns the output of every layer (plus the embedding output), each tensor of shape (batch_size, sequence_length, hidden_size). To get this list you must enable it, either with config.output_hidden_states=True when initializing BERT or by passing output_hidden_states=True in the forward call (as the code below does); only then will hidden_states be returned. A quick check of its contents:
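The standalone sketch below uses the standard Hugging Face bert-base-uncased checkpoint purely for illustration.

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# 13 tensors for bert-base: the embedding output plus one per transformer layer
print(len(outputs.hidden_states))        # 13
print(outputs.hidden_states[-1].shape)   # (1, seq_len, 768), same as last_hidden_state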

Now let's modify the code and take the output of the third-to-last transformer layer:

class ClsBERT(BertPreTrainedModel):
    def __init__(self, config, args, label_lst):
        super(ClsBERT, self).__init__(config)
        self.args = args
        self.num_labels = len(label_lst)
        self.bert = BertModel(config=config)  # Load pretrained bert

        self.classifier = IntentClassifier(config.hidden_size, self.num_labels, args.dropout_rate)


    def forward(self, input_ids, attention_mask, token_type_ids, label_ids):
        """添加 output_hidden_states = True
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask,token_type_ids=token_type_ids,output_hidden_states = True) 
        """修改pooled_outputs
        hidden_states[-3] 表示倒数第三层输出
        mean(dim=1) 表示 平均池化输出向量 得到(batch_size,hidden_layers)
        """
        pooled_output = outputs.hidden_states[-3].mean(dim=1)
			
        logits = self.classifier(pooled_output)

        outputs = ((logits),) + outputs[2:]  # add hidden states and attention if they are here

        # 1. Intent Softmax
        if label_ids is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), label_ids.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), label_ids.view(-1))

            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

Since we can extract the vector of any layer, we can also concatenate the vectors of several layers. The modified code is below:
first create an empty tensor with torch.empty(0).to(self.device), then use a loop and torch.cat to append the average-pooled output of each chosen layer.
Note: remember to change the input size of the corresponding linear layer so that it matches the hidden size of the concatenated pooled_output (number of concatenated layers × config.hidden_size).

class ClsBERT(BertPreTrainedModel):
    def __init__(self, config, args, label_lst):
        super(ClsBERT, self).__init__(config)
        self.args = args
        self.num_labels = len(label_lst)
        self.bert = BertModel(config=config)  # Load pretrained bert

        # index into hidden_states from which to start concatenating;
        # -3 (an illustrative value; it could also come from args) keeps the last three layers
        self.concatnum = -3
        # the linear layer's input size must match the concatenated pooled_output
        self.classifier = IntentClassifier(config.hidden_size * abs(self.concatnum),
                                           self.num_labels, args.dropout_rate)


    def forward(self, input_ids, attention_mask, token_type_ids, label_ids):
        # pass output_hidden_states=True so that all layer outputs are returned
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, output_hidden_states=True)
        # modified pooled_output:
        # start from an empty tensor and concatenate the average-pooled output of
        # each of the last abs(self.concatnum) layers along the feature dimension,
        # giving a tensor of shape (batch_size, abs(self.concatnum) * hidden_size)
        pooled_output = torch.empty(0).to(self.device)
        for layer in outputs.hidden_states[self.concatnum:]:
            pooled_output = torch.cat((pooled_output, layer.mean(dim=1)), dim=1)
        
        logits = self.classifier(pooled_output)

        outputs = ((logits),) + outputs[2:]  # add hidden states and attention if they are here

        # 1. Intent Softmax
        if label_ids is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), label_ids.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), label_ids.view(-1))

            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

That concludes the walkthrough of the model code, along with the discussion of, and improvements to, how BERT's output is used.
The optimizer, learning rate, and loss function will be covered in the next post on the training code.
I'm still new to NLP, so if anything here is wrong or incomplete, corrections and feedback are very welcome!


Reposted from juejin.im/post/7028045671798145038