FastBERT for Text Classification

Model Structure

FastBERT's innovation is easy to grasp: after every Transformer layer it tries to predict the sample's label, and if the prediction is confident enough, the remaining layers are skipped. The paper calls this the sample-wise adaptive mechanism, i.e. the amount of computation is adapted to each sample: easy samples can be resolved after one or two layers, while hard samples still go through the full stack.
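As a rough, self-contained sketch of this mechanism (toy linear layers stand in for the Transformer layers and per-layer classifiers; the author's actual implementation is the FastBertClassifier class in the script below):

import torch
import torch.nn as nn

def normalized_entropy(probs):
    # per-sample entropy, normalized by the maximum entropy log(N)
    n = probs.size(-1)
    return -(probs * probs.log()).sum(-1) / torch.log(torch.tensor(float(n)))

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])   # stand-ins for Transformer layers
branches = nn.ModuleList([nn.Linear(16, 3) for _ in range(4)])  # stand-ins for the per-layer classifiers
threshold = 0.3                                                 # the "Speed" threshold discussed later

with torch.no_grad():
    hidden = torch.randn(8, 16)                  # a batch of 8 toy "samples"
    final_logits = torch.zeros(8, 3)
    active = torch.arange(8)                     # indices of samples still in flight
    for layer, branch in zip(layers, branches):
        hidden = torch.tanh(layer(hidden))       # run one encoder layer
        logits = branch(hidden)                  # cheap per-layer prediction
        final_logits[active] = logits            # remember the latest prediction
        keep = normalized_entropy(logits.softmax(-1)) > threshold  # still too uncertain?
        active, hidden = active[keep], hidden[keep]
        if active.numel() == 0:                  # every sample has exited early
            break
print(final_logits.argmax(-1))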

So what should predict these intermediate-layer results? The author's solution is to attach a classifier after every layer; after all, a classifier costs far less than a Transformer layer:

Note: FLOPs (floating point operations) is a count of floating-point computation; TensorFlow provides such statistics, and the script below uses the thop package to measure the same thing for the PyTorch model.
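As a small illustration of how thop reports FLOPs (the MLP and input shape here are made-up stand-ins; the script below calls profile() on the full FastBERT model in the same way):

import torch
import torch.nn as nn
from thop import profile

# a BERT-base-sized feed-forward block as a stand-in module
mlp = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
x = torch.randn(1, 128, 768)                     # one sample, 128 tokens
flops, params = profile(mlp, inputs=(x,), verbose=False)
print("FLOPs: {:.3e}, params: {:.3e}".format(flops, params))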

With that, the overall structure of the model follows naturally:

The author calls the original BERT model the backbone and each added classifier a branch.

Note that the branch classifiers are distilled from the final-layer classifier, which the author calls self-distillation: only the backbone parameters are updated during pre-training and fine-tuning; after fine-tuning, the backbone is frozen and the branch classifiers (the students in the figure) are trained to match the probability distribution of the backbone classifier (the teacher).

It is called self-distillation because earlier distillation setups use two separate models, one learning the other's knowledge, whereas FastBERT distills its own knowledge from itself (the backbone) into itself (the branches). Note that the backbone must be frozen during distillation so that the knowledge learned in pre-training and fine-tuning is left intact; only the branches are updated to fit the teacher's distribution as closely as possible.
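A minimal sketch of this distillation loss (teacher_logits and student_logits_list are assumed placeholders; in the script below this logic lives in FastBertClassifier.forward when label is None):

import torch
import torch.nn as nn

kl = nn.KLDivLoss(reduction='batchmean')

def self_distillation_loss(student_logits_list, teacher_logits):
    # the teacher (last-layer classifier on the frozen backbone) provides soft targets
    with torch.no_grad():
        teacher_probs = teacher_logits.softmax(dim=-1)
    loss = 0.0
    for student_logits in student_logits_list:   # one branch classifier per layer
        loss = loss + kl(student_logits.log_softmax(dim=-1), teacher_probs)
    return loss

# toy usage: 4 samples, 3 labels, 2 branch classifiers with random logits
teacher = torch.randn(4, 3)
students = [torch.randn(4, 3, requires_grad=True) for _ in range(2)]
print(self_distillation_loss(students, teacher))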

Then why not simply train the branch classifiers on the labeled data directly? Because that just doesn't work as well. Below are the author's ablation results:

As the results show, the non-distilled variant is worse than the distilled one. This seems reasonable to me, because the two setups optimize different objectives during fine-tuning. Without self-distillation, all classifiers are trained together at fine-tuning time, which changes the objective and forces the earlier encoder layers to extract more task-specific features. But BERT's strength is closely tied to its depth, so making the call too early is not necessarily accurate, and quality drops.

Self-distillation also brings another important benefit: it no longer depends on labeled data, so the distillation can keep improving with a steady stream of unlabeled data.

Model Training and Inference

Once the model structure is understood, training and inference follow naturally; compared with a plain BERT model there is only the extra self-distillation step:

1. Pre-training: identical to any BERT-family model, so any of the many open-source checkpoints can be reused.
2. Fine-tuning for Backbone: backbone fine-tuning, i.e. add a classifier on top of BERT's last layer and train it on the task data. The branch classifiers are not involved here, so the backbone can be optimized freely.
3. Self-distillation for branch: branch self-distillation, which only needs unlabeled task data. The probability distribution predicted by the backbone classifier is distilled into the branch classifiers. KL divergence measures the distance between distributions, and the loss is the sum of the KL divergences between every branch classifier and the backbone classifier.
4. Adaptive inference: adaptive inference, i.e. samples are filtered layer by layer according to the branch classifiers' results; easy samples get their answer immediately, hard ones keep going. Here the author defines a new uncertainty measure based on the entropy of the predicted distribution, where larger entropy means greater uncertainty:

\text{Uncertainty} = \frac{\sum_{i=1}^{N} p_s(i)\,\log p_s(i)}{\log \frac{1}{N}}

where p_s(i) is the predicted probability of label i and N is the number of labels (this matches the normal_shannon_entropy helper in the code below).
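For intuition (toy numbers of my own, mirroring that helper): a uniform binary prediction of [0.5, 0.5] gives the maximum uncertainty of 1.0, while a confident [0.95, 0.05] comes out around 0.29.

import torch

def uncertainty(probs):
    n = probs.size(-1)
    entropy = torch.distributions.Categorical(probs=probs).entropy()
    return entropy / torch.log(torch.tensor(float(n)))

print(uncertainty(torch.tensor([0.5, 0.5])))    # -> 1.0, maximally uncertain
print(uncertainty(torch.tensor([0.95, 0.05])))  # -> ~0.29, confident
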
Results

For the per-layer classification results, the author uses "Speed" to denote the uncertainty threshold, which is positively related to inference speed: the smaller the threshold, the lower the uncertainty required to exit, so fewer samples are filtered out early and inference gets slower.

The model's final results on 12 datasets (6 Chinese and 6 English) are quite good:

At Speed=0.2, inference is 1-10x faster while the accuracy drop stays within 0.11 points, and on some tasks there is even a slight improvement. By comparison, HuggingFace's DistilBERT fluctuates much more: the 6-layer model is only about 2x faster, yet the accuracy drop can reach up to 7 points.

Code

Model and code download link: https://pan.baidu.com/s/1uzAm-M6dRaR2X-jFQbknbg
Extraction code: go67

# -*- encoding:utf-8 -*-
"""
  This script provides an example of the fine-tuning and self-distillation
  process of FastBERT.
"""
import os, sys
import torch
import json
import random
import argparse
import collections
import torch.nn as nn
from uer.utils.vocab import Vocab
from uer.utils.constants import *
from uer.utils.tokenizer import * 
from uer.model_builder import build_model
from uer.utils.optimizers import *
from uer.utils.config import load_hyperparam
from uer.utils.seed import set_seed
from uer.model_saver import save_model
from uer.model_loader import load_model
from uer.layers.multi_headed_attn import MultiHeadedAttention
import numpy as np
import time
from thop import profile


torch.set_num_threads(1)


def normal_shannon_entropy(p, labels_num):
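    # Shannon entropy normalized by the maximum possible entropy log(labels_num),
    # so the result lies in [0, 1]; 1.0 means a uniform (maximally uncertain) prediction.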
    entropy = torch.distributions.Categorical(probs=p).entropy()
    normal = -np.log(1.0/labels_num)
    return entropy / normal


class Classifier(nn.Module):

    def __init__(self, args, input_size, labels_num):
        super(Classifier, self).__init__()
        self.input_size = input_size
        self.cla_hidden_size = 128
        self.cla_heads_num = 2
        self.labels_num = labels_num
        self.pooling = args.pooling
        self.output_layer_0 = nn.Linear(input_size, self.cla_hidden_size)
        self.self_atten = MultiHeadedAttention(self.cla_hidden_size, self.cla_heads_num, args.dropout)
        self.output_layer_1 = nn.Linear(self.cla_hidden_size, self.cla_hidden_size)
        self.output_layer_2 = nn.Linear(self.cla_hidden_size, labels_num)
    
    def forward(self, hidden, mask):

        hidden = torch.tanh(self.output_layer_0(hidden))
        hidden = self.self_atten(hidden, hidden, hidden, mask)
        
        if self.pooling == "mean":
            hidden = torch.mean(hidden, dim=1)  # average over the sequence dimension
        elif self.pooling == "max":
            hidden = torch.max(hidden, dim=1)[0]
        elif self.pooling == "last":
            hidden = hidden[:, -1, :]
        else:
            hidden = hidden[:, 0, :]

        output_1 = torch.tanh(self.output_layer_1(hidden))
        logits = self.output_layer_2(output_1)
        return logits


class FastBertClassifier(nn.Module):
    def __init__(self, args, model):
        super(FastBertClassifier, self).__init__()

        self.embedding = model.embedding
        self.encoder = model.encoder
        self.labels_num = args.labels_num
        self.classifiers = nn.ModuleList([
                Classifier(args, args.hidden_size, self.labels_num) \
                for i in range(self.encoder.layers_num)
             ])
        self.softmax = nn.LogSoftmax(dim=-1)
        self.criterion = nn.NLLLoss()
        self.soft_criterion = nn.KLDivLoss(reduction='batchmean')
        self.threshold = args.speed

    def forward(self, src, label, mask, fast=True):
        """
        Args:
            src: [batch_size x seq_length]
            label: [batch_size]
            mask: [batch_size x seq_length]
        """
        # Embedding.
        emb = self.embedding(src, mask)

        # Encoder.
        seq_length = emb.size(1)
        mask = (mask > 0). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1)
        mask = mask.float()
        mask = (1.0 - mask) * -10000.0
         
        if self.training:

            if label is not None:

                # train the backbone (main part) of the model
                hidden = emb
                for i in range(self.encoder.layers_num):
                    hidden = self.encoder.transformer[i](hidden, mask)
                logits = self.classifiers[-1](hidden, mask)
                loss = self.criterion(self.softmax(logits.view(-1, self.labels_num)), label.view(-1))
                return loss, logits
            else:
                # distill the branch (student) classifiers
                loss, hidden, hidden_list = 0, emb, []
                with torch.no_grad():
                    for i in range(self.encoder.layers_num):
                        hidden = self.encoder.transformer[i](hidden, mask)
                        hidden_list.append(hidden)
                    teacher_logits = self.classifiers[-1](hidden_list[-1], mask).view(-1, self.labels_num) 
                teacher_probs = nn.functional.softmax(teacher_logits, dim=1)
                loss = 0
                for i in range(self.encoder.layers_num - 1):
                    student_logits = self.classifiers[i](hidden_list[i], mask).view(-1, self.labels_num)
                    loss += self.soft_criterion(self.softmax(student_logits), teacher_probs) 
                return loss, teacher_logits

        else:
            # inference 
            if fast:
                # fast mode 
                hidden = emb  # (batch_size, seq_len, emb_size)
                batch_size = hidden.size(0)
                logits = torch.zeros(batch_size, self.labels_num, dtype=hidden.dtype, device=hidden.device)
                abs_diff_idxs = torch.arange(0, batch_size, dtype=torch.long, device=hidden.device)
                for i in range(self.encoder.layers_num):
                    
                    hidden = self.encoder.transformer[i](hidden, mask)

                    logits_this_layer = self.classifiers[i](hidden, mask)  # (batch_size, labels_num)
                    logits[abs_diff_idxs] = logits_this_layer

                    # filter easy sample
                    abs_diff_idxs, rel_diff_idxs = self._difficult_samples_idxs(abs_diff_idxs, logits_this_layer) 
                    hidden = hidden[rel_diff_idxs, :, :]
                    mask = mask[rel_diff_idxs, :, :]
                    
                    if len(abs_diff_idxs) == 0:
                        break

                return None, logits
            else:
                # normal mode
                hidden = emb
                for i in range(self.encoder.layers_num):
                    hidden = self.encoder.transformer[i](hidden, mask)
                logits = self.classifiers[-1](hidden, mask)
                return None, logits
                    
    def _difficult_samples_idxs(self, idxs, logits):
        # logits: (batch_size, labels_num)
        probs = nn.Softmax(dim=1)(logits)
        entropys = normal_shannon_entropy(probs, self.labels_num)
        # torch.nonzero() is very time-consuming on GPU 
        # Please see https://github.com/pytorch/pytorch/issues/14848
        # If anyone can optimize this operation, please contact me, thank you!
        rel_diff_idxs = (entropys > self.threshold).nonzero().view(-1)
        abs_diff_idxs = torch.tensor([idxs[i] for i in rel_diff_idxs], device=logits.device)
        return abs_diff_idxs, rel_diff_idxs
        
        
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default="./models/Chinese_base_model.bin", type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path", default="./models/fastbert.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path", type=str, required=False,default="./models/google_zh_vocab.txt",
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=False, default="./datasets/douban_book_review/train.tsv",
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=False,default="./datasets/douban_book_review/dev.tsv",
                        help="Path of the devset.") 
    parser.add_argument("--test_path", type=str,default="./datasets/douban_book_review/test.tsv",
                        help="Path of the testset.")
    parser.add_argument("--config_path", default="./models/bert_base_config.json", type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size.")
    parser.add_argument("--seq_length", type=int, default=128,
                        help="Sequence length.")
    parser.add_argument("--embedding", choices=["bert", "word"], default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")
    parser.add_argument("--pooling", choices=["mean", "max", "first", "last"], default="first",
                        help="Pooling type.")

    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"], default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.")

    # Tokenizer options.
    parser.add_argument("--tokenizer", choices=["bert", "char", "space"], default="bert",
                        help="Specify the tokenizer." 
                             "Original Google BERT uses bert tokenizer on Chinese corpus."
                             "Char tokenizer segments sentences into characters."
                             "Space tokenizer segments sentences into words according to space."
                             )

    # Optimizer options.
    parser.add_argument("--learning_rate", type=float, default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup", type=float, default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=3,
                        help="Number of epochs.")
    parser.add_argument("--distill_epochs_num", type=int, default=5,
                        help="Number of distillation epochs.")
    parser.add_argument("--report_steps", type=int, default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7,
                        help="Random seed.")

    # Evaluation options.
    parser.add_argument("--mean_reciprocal_rank", action="store_true", help="Evaluation metrics for DBQA dataset.")
    parser.add_argument("--fast_mode", dest='fast_mode', action='store_true', help="Whether turn on fast mode")
    parser.add_argument("--speed", type=float, default=0.5, help="Threshold of Uncertainty, i.e., the Speed in paper.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Count the number of labels.
    labels_set = set()
    columns = {}
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            try:
                line = line.strip().split("\t")
                if line_id == 0:
                    for i, column_name in enumerate(line):
                        columns[column_name] = i
                    continue
                label = int(line[columns["label"]])
                labels_set.add(label)
            except Exception:
                pass  # skip malformed lines
    args.labels_num = len(labels_set) 

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)  
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)
    
    # Build classification model.
    model = FastBertClassifier(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)
    
    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
            label_ids_batch = label_ids[i*batch_size: (i+1)*batch_size]
            mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
            label_ids_batch = label_ids[instances_num//batch_size*batch_size:]
            mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch

    # Build tokenizer.
    tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    # Read dataset.
    def read_dataset(path):
        dataset = []
        with open(path, mode="r", encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                try:
                    line = line.strip().split('\t')
                    if len(line) == 2:
                        label = int(line[columns["label"]])
                        text = line[columns["text_a"]]
                        tokens = [vocab.get(t) for t in tokenizer.tokenize(text)]
                        tokens = [CLS_ID] + tokens
                        mask = [1] * len(tokens)
                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask))
                    elif len(line) == 3: # For sentence pair input.
                        label = int(line[columns["label"]])
                        text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]

                        tokens_a = [vocab.get(t) for t in tokenizer.tokenize(text_a)]
                        tokens_a = [CLS_ID] + tokens_a + [SEP_ID]
                        tokens_b = [vocab.get(t) for t in tokenizer.tokenize(text_b)]
                        tokens_b = tokens_b + [SEP_ID]

                        tokens = tokens_a + tokens_b
                        mask = [1] * len(tokens_a) + [2] * len(tokens_b)
                        
                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask))
                    elif len(line) == 4: # For dbqa input.
                        qid=int(line[columns["qid"]])
                        label = int(line[columns["label"]])
                        text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]

                        tokens_a = [vocab.get(t) for t in tokenizer.tokenize(text_a)]
                        tokens_a = [CLS_ID] + tokens_a + [SEP_ID]
                        tokens_b = [vocab.get(t) for t in tokenizer.tokenize(text_b)]
                        tokens_b = tokens_b + [SEP_ID]

                        tokens = tokens_a + tokens_b
                        mask = [1] * len(tokens_a) + [2] * len(tokens_b)

                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask, qid))
                    else:
                        pass
                        
                except Exception:
                    pass  # skip malformed lines
        return dataset

    # Evaluation function.
    def evaluate(args, is_test, fast_mode=False):
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])

        batch_size = 1
        instances_num = input_ids.size()[0]

        print("The number of evaluation instances: ", instances_num)
        print("Fast mode: ", fast_mode)

        correct = 0
        # Confusion matrix.
        confusion = torch.zeros(args.labels_num, args.labels_num, dtype=torch.long)

        model.eval()
        
        if not args.mean_reciprocal_rank:
            total_flops, model_params_num = 0, 0
            for i, (input_ids_batch, label_ids_batch,  mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                with torch.no_grad():

                    # Get FLOPs at this batch
                    inputs = (input_ids_batch, label_ids_batch, mask_ids_batch, fast_mode)
                    flops, params = profile(model, inputs, verbose=False)
                    total_flops += flops
                    model_params_num = params
                    
                    # inference
                    loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch, fast=fast_mode)

                logits = nn.Softmax(dim=1)(logits)
                pred = torch.argmax(logits, dim=1)
                gold = label_ids_batch
                for j in range(pred.size()[0]):
                    confusion[pred[j], gold[j]] += 1
                correct += torch.sum(pred == gold).item()

            print("Number of model parameters: {}".format(model_params_num))
            print("FLOPs per sample in average: {}".format(total_flops / float(instances_num)))
        
            if is_test:
                print("Confusion matrix:")
                print(confusion)
                print("Report precision, recall, and f1:")
            for i in range(confusion.size()[0]):
                # p = confusion[i,i].item()/confusion[i,:].sum().item()
                r = confusion[i,i].item()/confusion[:,i].sum().item()
                # f1 = 2*p*r / (p+r)
                if is_test:
                    print("Label {}: {:.3f}".format(i,r))
                    # print("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i,p,r,f1))
            print("Acc. (Correct/Total): {:.4f} ({}/{}) ".format(correct/len(dataset), correct, len(dataset)))
            return correct/len(dataset)
        else:
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch)
                logits = nn.Softmax(dim=1)(logits)
                if i == 0:
                    logits_all=logits
                if i >= 1:
                    logits_all=torch.cat((logits_all,logits),0)
        
            # DBQA evaluation: group candidate answers by qid, record the positions of the
            # gold answers (label == 1), then compute the Mean Reciprocal Rank from the scores.
            order = -1
            gold = []
            for i in range(len(dataset)):
                qid = dataset[i][3]
                label = dataset[i][1]
                if qid == order:
                    j += 1
                    if label == 1:
                        gold.append((qid,j))
                else:
                    order = qid
                    j = 0
                    if label == 1:
                        gold.append((qid,j))


            label_order = []
            order = -1
            for i in range(len(gold)):
                if gold[i][0] == order:
                    templist.append(gold[i][1])
                elif gold[i][0] != order:
                    order=gold[i][0]
                    if i > 0:
                        label_order.append(templist)
                    templist = []
                    templist.append(gold[i][1])
            label_order.append(templist)

            order = -1
            score_list = []
            for i in range(len(logits_all)):
                score = float(logits_all[i][1])
                qid=int(dataset[i][3])
                if qid == order:
                    templist.append(score)
                else:
                    order = qid
                    if i > 0:
                        score_list.append(templist)
                    templist = []
                    templist.append(score)
            score_list.append(templist)

            rank = []
            pred = []
            for i in range(len(score_list)):
                if len(label_order[i])==1:
                    if label_order[i][0] < len(score_list[i]):
                        true_score = score_list[i][label_order[i][0]]
                        score_list[i].sort(reverse=True)
                        for j in range(len(score_list[i])):
                            if score_list[i][j] == true_score:
                                rank.append(1 / (j + 1))
                    else:
                        rank.append(0)

                else:
                    true_rank = len(score_list[i])
                    for k in range(len(label_order[i])):
                        if label_order[i][k] < len(score_list[i]):
                            true_score = score_list[i][label_order[i][k]]
                            temp = sorted(score_list[i],reverse=True)
                            for j in range(len(temp)):
                                if temp[j] == true_score:
                                    if j < true_rank:
                                        true_rank = j
                    if true_rank < len(score_list[i]):
                        rank.append(1 / (true_rank + 1))
                    else:
                        rank.append(0)
            MRR = sum(rank) / len(rank)
            print("Mean Reciprocal Rank: {:.4f}".format(MRR))
            return MRR

    # Training phase.
    print("Start training.")
    trainset = read_dataset(args.train_path)
    random.shuffle(trainset)
    instances_num = len(trainset)
    batch_size = args.batch_size

    input_ids = torch.LongTensor([example[0] for example in trainset])
    label_ids = torch.LongTensor([example[1] for example in trainset])
    mask_ids = torch.LongTensor([example[2] for example in trainset])

    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=train_steps*args.warmup, t_total=train_steps)
   
    # Fine-tune the backbone (main part) of the model.
    print("Start fine-tuning the backbone of the model.")
    total_loss = 0.
    result = 0.0
    best_result = 0.0 
    for epoch in range(1, args.epochs_num+1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)

            loss, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch)  # training
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, backbone fine-tuning steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()
            scheduler.step()
        result = evaluate(args, False, False)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)
        else:
            continue

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation after bakbone fine-tuning.")
        model = load_model(model, args.output_model_path)
        print("Test on normal model")
        evaluate(args, True, False)
        if args.fast_mode:
            print("Test on Fast mode")
            evaluate(args, True, args.fast_mode)

    # Distill the branch (student) classifiers.
    print("Start self-distillation for student-classifiers.")
    
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate*10, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=train_steps*args.warmup, t_total=train_steps)

    model = load_model(model, args.output_model_path)
    total_loss = 0.
    result = 0.0
    best_result = 0.0 
    for epoch in range(1, args.distill_epochs_num+1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)

            loss, _ = model(input_ids_batch, None, mask_ids_batch)  # distillation
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, self-distillation steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()
            scheduler.step()
        result = evaluate(args, False, args.fast_mode)
        save_model(model, args.output_model_path) 

        # Evaluation phase.
        if args.test_path is not None:
            print("Test set evaluation after self-distillation.")
            model = load_model(model, args.output_model_path)
            evaluate(args, True, args.fast_mode)


if __name__ == "__main__":
    main()


Reposted from blog.csdn.net/qq236237606/article/details/107079455