Choosing a Loss Function for NLP Classification Tasks

NLP classification tasks fall into single-label and multi-label settings. How should we choose a loss function for each?

I. Single-label classification tasks

Single-label classification is straightforward: each sample has exactly one label. Depending on the number of classes, it further splits into binary and multi-class classification.

1. Binary classification

There are only two classes: a sample is either A or B. In this case, any of the following setups can be used:

a. sigmoid activation + BCELoss

The training code:

# output: [B, C] raw logits; labels: [B, C] floats in {0, 1}
output = torch.sigmoid(output)
loss = F.binary_cross_entropy(output, labels.float())

Of course, BCEWithLogitsLoss is equivalent to sigmoid followed by BCELoss, so the two steps can be fused into one call:

# output: [B, C] raw logits; labels: [B, C]
loss = F.binary_cross_entropy_with_logits(output, labels.float())
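
As a quick sanity check (a minimal sketch with made-up shapes, not from the original post), the two formulations produce the same loss value:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 1)                       # [B, 1] raw scores
labels = torch.tensor([[1.], [0.], [1.], [0.]])  # [B, 1] float targets

loss_a = F.binary_cross_entropy(torch.sigmoid(logits), labels)
loss_b = F.binary_cross_entropy_with_logits(logits, labels)
print(torch.allclose(loss_a, loss_b))            # True, up to float tolerance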

b. softmax + cross entropy

That name is slightly imprecise; strictly speaking, log_softmax + NLLLoss == CrossEntropyLoss (NLLLoss expects log-probabilities, not probabilities).

Binary classification is really just a special case of multi-class classification. As for the pros and cons of this approach versus the sigmoid activation + BCELoss approach above, I have not dug into it here (the main practical difference is that the softmax route uses two output units instead of one; mathematically the two parameterizations are equivalent).

Training code:

# output: [B, C] raw logits; labels: [B] class indices
output = torch.log_softmax(output, dim=1)  # NLLLoss expects log-probabilities
loss = F.nll_loss(output, labels)

Or equivalently:

# output: [B, C] raw logits; labels: [B] class indices
loss = F.cross_entropy(output, labels)
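
The claimed equivalence is easy to verify numerically; here is a minimal sketch (shapes and values made up for illustration):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)            # [B, C] raw scores
labels = torch.tensor([0, 2, 1, 2])   # [B] class indices

loss_a = F.nll_loss(torch.log_softmax(logits, dim=1), labels)
loss_b = F.cross_entropy(logits, labels)
print(torch.allclose(loss_a, loss_b))  # True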

c. softmax + BCE

Hardly anyone actually uses this combination, but I think it works too: run the model output through softmax so the values land in [0, 1], then apply BCELoss. This also trains the classifier successfully:

# output: [B, C] raw logits; labels: [B, C] one-hot floats
output = torch.softmax(output, dim=1)
loss = F.binary_cross_entropy(output, labels.float())
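
Note that BCELoss expects float targets with the same shape as the probabilities, so index labels have to be converted to one-hot vectors first; a minimal sketch (assuming [B] index labels):

import torch
import torch.nn.functional as F

num_classes = 3
logits = torch.randn(4, num_classes)  # [B, C] raw scores
labels = torch.tensor([0, 2, 1, 2])   # [B] class indices

one_hot = F.one_hot(labels, num_classes).float()  # [B, C] one-hot floats
probs = torch.softmax(logits, dim=1)
loss = F.binary_cross_entropy(probs, one_hot)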

2. Multi-class classification

Each sample has exactly one label, picked from among all the classes. Such single-label multi-class tasks all use the softmax + cross entropy recipe. The code is the same as above:

# output: [B, C] raw logits; labels: [B] class indices
output = torch.log_softmax(output, dim=1)  # NLLLoss expects log-probabilities
loss = F.nll_loss(output, labels)

Or equivalently:

# output: [B, C] raw logits; labels: [B] class indices
loss = F.cross_entropy(output, labels)
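
At inference time the prediction is simply the argmax; since softmax is monotonic, applying it first changes nothing. A minimal sketch:

import torch

logits = torch.randn(4, 3)   # [B, C] model outputs
pred = logits.argmax(dim=1)  # [B] predicted class indices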

II. Multi-label classification

Here each sample may carry multiple labels: anywhere from zero up to some maximum number. The plain approach is to decompose the problem into multiple binary classification tasks: with n labels, run n binary classifiers. In a deep learning implementation this is again the sigmoid activation + BCELoss recipe: for each label, a fully connected layer produces a score, sigmoid maps it into [0, 1], and BCELoss drives the optimization. At inference time, set a threshold per label and treat every class scoring above its threshold as a target class (a short inference sketch follows the training snippets below).

A label vector then looks like [0,0,0,1,0,1,0,0,1], [0,0,0,0,0,0,0,0,1], or [1,0,0,0,0,1,0,0,1], where 0 marks a non-target label and 1 a target label. Implementation:

# output: [B, C] raw logits; labels: [B, C] 0/1
output = torch.sigmoid(output)
loss = F.binary_cross_entropy(output, labels.float())

Again, BCEWithLogitsLoss is equivalent to sigmoid + BCE, so it can be written directly as:

# output: [B, C] raw logits; labels: [B, C] 0/1
loss = F.binary_cross_entropy_with_logits(output, labels.float())
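
And for the inference step described above, a minimal thresholding sketch (the 0.5 threshold is illustrative; per-class thresholds can be tuned on dev data):

import torch

logits = torch.randn(4, 9)   # [B, C] model outputs
probs = torch.sigmoid(logits)
pred = (probs > 0.5).long()  # [B, C] 0/1 predictions, one column per label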

This implementation has one drawback: class imbalance hurts model performance. Su Jianlin's article 将“softmax+交叉熵”推广到多标签分类问题 (generalizing "softmax + cross entropy" to multi-label classification) proposes a loss that mitigates the problem. The final formula is:

loss = log(1 + Σ_{i∈Ω_neg} e^{s_i}) + log(1 + Σ_{j∈Ω_pos} e^{-s_j})

where s_i is the raw score for class i, and Ω_pos / Ω_neg are the sets of target and non-target classes.

Following the formula and code in that article, here is a PyTorch implementation of the loss function:

def multilabel_crossentropy(output, label):
    """
    Multi-label cross entropy (Su Jianlin's generalization of softmax + cross entropy).
    Notes: label has the same shape as output, and its elements are 0 or 1:
           1 marks a target class, 0 a non-target class.
    Warning: output must range over all real numbers; in general do NOT apply
             an activation to it, and in particular never sigmoid or softmax!
             At prediction time, output the classes whose scores are > 0.
    :param output: [B, C]
    :param label:  [B, C]
    :return: loss summed over the batch
    """
    # Negate the scores of the target classes: leaves s_i for negatives, -s_j for positives
    output = (1 - 2 * label) * output

    # Mask out the irrelevant entries by pushing their scores down to -1e12
    output_neg = output - label * 1e12          # keeps only non-target scores
    output_pos = output - (1 - label) * 1e12    # keeps only negated target scores

    # Append a zero column; it contributes the "+1" term inside each log
    zeros = torch.zeros_like(output[:, :1])
    output_neg = torch.cat([output_neg, zeros], dim=1)  # [B, C + 1]
    output_pos = torch.cat([output_pos, zeros], dim=1)  # [B, C + 1]

    loss_neg = torch.logsumexp(output_neg, dim=1)
    loss_pos = torch.logsumexp(output_pos, dim=1)
    loss = (loss_neg + loss_pos).sum()

    return loss
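
A quick usage sketch (random inputs, just to show the calling convention): feed raw logits and 0/1 targets, and at inference predict every class whose score exceeds 0:

import torch

output = torch.randn(2, 9)  # [B, C] raw logits, no activation applied
label = torch.tensor([[0, 0, 0, 1, 0, 1, 0, 0, 1],
                      [0, 0, 0, 0, 0, 0, 0, 0, 1]]).float()  # [B, C] 0/1 targets

loss = multilabel_crossentropy(output, label)
pred = (output > 0).long()  # decision rule at inference time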

III. Multi-label classification in practice

This section runs an actual multi-label experiment: I had never done a real multi-label task before, and I also wanted to verify the effectiveness and efficiency of Su Jianlin's multilabel_crossentropy loss.

The dataset is the preliminary-round training set of the CCKS2021 运营商知识图谱推理问答 competition, split into train and dev at a 4:1 ratio. Each sample carries two label columns, answer type and attribute name.

Merging the answer-type and attribute-name columns into a single label set yields multi-label targets, so the task can be framed as multi-label classification. The model definition is simple: BERT + mean pooling + linear produces the model output, from which the loss is computed. The training code:

import torch
import argparse
from data_reader.dataReader import DataReader
from model.mutilLabel_classification import MutilLabelClassification
from torch.utils.data import DataLoader

import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from transformers import BertTokenizer,BertConfig
import os
from tools.log import Logger
from tools.progressbar import ProgressBar
from datetime import datetime

logger = Logger('mutil_label_logger',log_level=10)
os.environ['CUDA_VISIBLE_DEVICES'] = "0"


def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--max_len",type=int,default=64)
    parser.add_argument("--train_file", type=str,default='./data/train.xlsx', help="train text file")
    parser.add_argument("--val_file", type=str, default='./data/dev.xlsx',help="val text file")
    parser.add_argument("--pretrained", type=str, default="./pretrain_models/chinese-bert-wwm-ext", help="huggingface pretrained model")
    parser.add_argument("--model_out", type=str, default="./output", help="model output path")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--epochs", type=int, default=20, help="epochs")
    parser.add_argument("--lr", type=int, default=1e-5, help="epochs")
    parser.add_argument("--loss_function_type",type=str,default='MLCE')
    args = parser.parse_args()
    return args


def multilabel_crossentropy(output, label):
    """
    Multi-label cross entropy (Su Jianlin's generalization of softmax + cross entropy).
    Notes: label has the same shape as output, and its elements are 0 or 1:
           1 marks a target class, 0 a non-target class.
    Warning: output must range over all real numbers; in general do NOT apply
             an activation to it, and in particular never sigmoid or softmax!
             At prediction time, output the classes whose scores are > 0.
    :param output: [B, C]
    :param label:  [B, C]
    :return: loss summed over the batch
    """
    # Negate the scores of the target classes: leaves s_i for negatives, -s_j for positives
    output = (1 - 2 * label) * output

    # Mask out the irrelevant entries by pushing their scores down to -1e12
    output_neg = output - label * 1e12          # keeps only non-target scores
    output_pos = output - (1 - label) * 1e12    # keeps only negated target scores

    # Append a zero column; it contributes the "+1" term inside each log
    zeros = torch.zeros_like(output[:, :1])
    output_neg = torch.cat([output_neg, zeros], dim=1)  # [B, C + 1]
    output_pos = torch.cat([output_pos, zeros], dim=1)  # [B, C + 1]

    loss_neg = torch.logsumexp(output_neg, dim=1)
    loss_pos = torch.logsumexp(output_pos, dim=1)
    loss = (loss_neg + loss_pos).sum()

    return loss


def train(args):
    logger.info(args)
    tokenizer = BertTokenizer.from_pretrained(args.pretrained)
    config = BertConfig.from_pretrained(args.pretrained)

    with open('data/labels.txt','r',encoding='utf-8') as f:
        lines = f.readlines()
    config.num_labels = len(lines)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "cpu"
    model = MutilLabelClassification.from_pretrained(config=config, pretrained_model_name_or_path=args.pretrained,
                                         max_len=args.max_len)
    model.to(device)


    train_dataset = DataReader(tokenizer=tokenizer,filepath=args.train_file,max_len=args.max_len)
    train_dataloader = DataLoader(train_dataset,batch_size=args.batch_size,shuffle=True)

    val_dataset = DataReader(tokenizer=tokenizer,filepath=args.val_file,max_len=args.max_len)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)

    optimizer = AdamW(model.parameters(),lr=args.lr)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,mode='max',factor=0.5, patience=2)

    model.train()
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d"% len(train_dataloader))
    logger.info("  Num Epochs = %d"%args.epochs)
    best_acc = 0.0
    for epoch in range(args.epochs):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step,batch in enumerate(train_dataloader):
            batch = [t.to(device) for t in batch]
            inputs = {'input_ids':batch[0],'attention_mask':batch[1],'token_type_ids':batch[2]}
            labels = batch[3]
            output = model(inputs)
            if args.loss_function_type == "BCE":
                # binary_cross_entropy_with_logits needs float labels with the same shape as output
                loss = F.binary_cross_entropy_with_logits(output, labels.float())
            else:
                # Su Jianlin's multi-label cross entropy
                loss = multilabel_crossentropy(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar(step, {'loss':loss.item()})


        time_str = datetime.now().strftime('%Y-%m-%d')

        train_acc = validation(model, train_dataloader, device, args)
        val_acc = validation(model, val_dataloader, device, args)
        scheduler.step(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            save_path = os.path.join(args.model_out, args.loss_function_type, "BertMultiLabelClassification" + time_str)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            logger.info("save model")
            model.save_pretrained(save_path)
            tokenizer.save_vocabulary(save_path)
        logger.info(args.loss_function_type + " train_acc:%.4f val_acc:%.4f------best_acc:%.4f" % (train_acc, val_acc, best_acc))


def validation(model, val_dataloader, device, args):
    total = 0
    total_correct = 0
    model.eval()
    with torch.no_grad():
        pbar = ProgressBar(n_total=len(val_dataloader), desc='evaluation')
        for step, batch in enumerate(val_dataloader):
            batch = [t.to(device) for t in batch]
            inputs = {'input_ids': batch[0], 'attention_mask': batch[1], 'token_type_ids': batch[2]}
            labels = batch[3]
            output = model(inputs)

            # Note the decision rules: sigmoid + 0.5 threshold for BCE,
            # raw score > 0 for the multi-label cross entropy
            if args.loss_function_type == "BCE":
                output = torch.sigmoid(output)
                pred = torch.where(output > 0.5, 1, 0)
            else:
                pred = torch.where(output > 0, 1, 0)
            # Exact match: a sample counts as correct only if all C labels match
            correct = 0
            for i in range(labels.size(0)):
                if torch.equal(pred[i], labels[i]):
                    correct += 1
            total_correct += correct
            total += labels.size(0)
            pbar(step, {})
    model.train()  # restore training mode for the next epoch
    acc = total_correct / total
    return acc


def main():
    args = parse_args()
    train(args)


if __name__ == '__main__':
    main()
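
One caveat before the results: exact-match accuracy is a strict, all-or-nothing metric for multi-label tasks. A hedged sketch of micro/macro-averaged F1 (assuming [N, C] 0/1 arrays; not part of the original script) that could complement it:

import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([[0, 1, 1], [1, 0, 0], [0, 0, 1]])
y_pred = np.array([[0, 1, 0], [1, 0, 0], [0, 1, 1]])

print(f1_score(y_true, y_pred, average='micro'))  # pools TP/FP/FN over all classes
print(f1_score(y_true, y_pred, average='macro'))  # unweighted mean of per-class F1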

As noted, the evaluation here simply uses exact-match accuracy rather than macro- or micro-averaged metrics for a more detailed picture of model performance. The final results with sigmoid and BCELoss:

BCE results
tokenization: 4000it [00:00, 5730.70it/s]
tokenization: 1000it [00:00, 5646.53it/s]
2021-09-13 21:40:02,572 log.py [line:49] INFO ***** Running training *****
2021-09-13 21:40:02,572 log.py [line:49] INFO   Num examples = 125
2021-09-13 21:40:02,572 log.py [line:49] INFO   Num Epochs = 20
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:40:15,125 log.py [line:49] INFO save model
2021-09-13 21:40:20,092 log.py [line:49] INFO BCE train_acc:0.4713 val_acc:0.4380------best_acc:0.4380
[evaldation] 32/32 [==============================] 20.0ms/step2021-09-13 21:40:32,027 log.py [line:49] INFO save model
2021-09-13 21:40:37,070 log.py [line:49] INFO BCE train_acc:0.8085 val_acc:0.7710------best_acc:0.7710
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:40:49,029 log.py [line:49] INFO save model
2021-09-13 21:40:55,869 log.py [line:49] INFO BCE train_acc:0.9260 val_acc:0.8890------best_acc:0.8890
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:41:07,848 log.py [line:49] INFO save model
2021-09-13 21:41:11,973 log.py [line:49] INFO BCE train_acc:0.9527 val_acc:0.9100------best_acc:0.9100
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:41:23,953 log.py [line:49] INFO save model
2021-09-13 21:41:29,065 log.py [line:49] INFO BCE train_acc:0.9830 val_acc:0.9350------best_acc:0.9350
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:41:41,056 log.py [line:49] INFO save model
2021-09-13 21:41:47,433 log.py [line:49] INFO BCE train_acc:0.9888 val_acc:0.9540------best_acc:0.9540
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:41:59,427 log.py [line:49] INFO save model
2021-09-13 21:42:04,007 log.py [line:49] INFO BCE train_acc:0.9872 val_acc:0.9560------best_acc:0.9560
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:42:16,013 log.py [line:49] INFO save model
2021-09-13 21:42:20,534 log.py [line:49] INFO BCE train_acc:0.9932 val_acc:0.9590------best_acc:0.9590
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:42:32,541 log.py [line:49] INFO BCE train_acc:0.9930 val_acc:0.9560------best_acc:0.9590
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:42:44,537 log.py [line:49] INFO save model
2021-09-13 21:42:50,451 log.py [line:49] INFO BCE train_acc:0.9965 val_acc:0.9640------best_acc:0.9640
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:43:02,462 log.py [line:49] INFO save model
2021-09-13 21:43:08,752 log.py [line:49] INFO BCE train_acc:0.9972 val_acc:0.9650------best_acc:0.9650
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:43:20,762 log.py [line:49] INFO save model
2021-09-13 21:43:24,950 log.py [line:49] INFO BCE train_acc:0.9982 val_acc:0.9660------best_acc:0.9660
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:43:36,944 log.py [line:49] INFO BCE train_acc:0.9985 val_acc:0.9660------best_acc:0.9660
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:43:48,954 log.py [line:49] INFO save model
2021-09-13 21:43:53,782 log.py [line:49] INFO BCE train_acc:0.9995 val_acc:0.9670------best_acc:0.9670
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:44:05,799 log.py [line:49] INFO BCE train_acc:0.9990 val_acc:0.9650------best_acc:0.9670
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:44:17,806 log.py [line:49] INFO save model
2021-09-13 21:44:22,550 log.py [line:49] INFO BCE train_acc:0.9995 val_acc:0.9720------best_acc:0.9720
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:44:34,569 log.py [line:49] INFO BCE train_acc:0.9998 val_acc:0.9690------best_acc:0.9720
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:44:46,577 log.py [line:49] INFO BCE train_acc:1.0000 val_acc:0.9710------best_acc:0.9720
[evaldation] 32/32 [==============================] 20.1ms/step2021-09-13 21:44:58,581 log.py [line:49] INFO BCE train_acc:0.9995 val_acc:0.9710------best_acc:0.9720
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:45:10,598 log.py [line:49] INFO BCE train_acc:1.0000 val_acc:0.9720------best_acc:0.9720

The results with multilabel_crossentropy:

MLCE (multi-label cross entropy) results
tokenization: 4000it [00:00, 5808.90it/s]
tokenization: 1000it [00:00, 5721.61it/s]
2021-09-13 21:48:09,971 log.py [line:49] INFO ***** Running training *****
2021-09-13 21:48:09,972 log.py [line:49] INFO   Num examples = 125
2021-09-13 21:48:09,972 log.py [line:49] INFO   Num Epochs = 20
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:48:22,566 log.py [line:49] INFO save model
2021-09-13 21:48:26,703 log.py [line:49] INFO MLCE train_acc:0.7645 val_acc:0.7550------best_acc:0.7550
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:48:38,664 log.py [line:49] INFO save model
2021-09-13 21:48:45,229 log.py [line:49] INFO MLCE train_acc:0.9570 val_acc:0.9320------best_acc:0.9320
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:48:57,208 log.py [line:49] INFO save model
2021-09-13 21:49:03,307 log.py [line:49] INFO MLCE train_acc:0.9878 val_acc:0.9600------best_acc:0.9600
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:49:15,290 log.py [line:49] INFO save model
2021-09-13 21:49:24,350 log.py [line:49] INFO MLCE train_acc:0.9958 val_acc:0.9680------best_acc:0.9680
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:49:36,328 log.py [line:49] INFO save model
2021-09-13 21:49:41,456 log.py [line:49] INFO MLCE train_acc:0.9980 val_acc:0.9700------best_acc:0.9700
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:49:53,465 log.py [line:49] INFO save model
2021-09-13 21:50:02,092 log.py [line:49] INFO MLCE train_acc:0.9988 val_acc:0.9740------best_acc:0.9740
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:50:14,090 log.py [line:49] INFO save model
2021-09-13 21:50:20,566 log.py [line:49] INFO MLCE train_acc:0.9998 val_acc:0.9760------best_acc:0.9760
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:50:32,564 log.py [line:49] INFO MLCE train_acc:0.9992 val_acc:0.9720------best_acc:0.9760
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:50:44,563 log.py [line:49] INFO MLCE train_acc:0.9988 val_acc:0.9720------best_acc:0.9760
[evaldation] 32/32 [==============================] 20.2ms/step2021-09-13 21:50:56,576 log.py [line:49] INFO MLCE train_acc:0.9988 val_acc:0.9750------best_acc:0.9760
[evaldation] 32/32 [==============================] 20.4ms/step2021-09-13 21:51:08,585 log.py [line:49] INFO save model
2021-09-13 21:51:12,692 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9790------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:51:24,729 log.py [line:49] INFO MLCE train_acc:0.9998 val_acc:0.9780------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:51:36,768 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9790------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:51:48,789 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9780------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:52:00,816 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9770------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:52:12,831 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9790------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:52:24,863 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9790------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:52:36,892 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9790------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:52:48,917 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9790------best_acc:0.9790
[evaldation] 32/32 [==============================] 20.3ms/step2021-09-13 21:53:00,951 log.py [line:49] INFO MLCE train_acc:1.0000 val_acc:0.9790------best_acc:0.9790

Final dev-set accuracy: 0.9790 for MLCE versus 0.9720 for BCE. MLCE also converges noticeably faster: from the training logs above, the MLCE model is already in good shape by roughly the second epoch (0.9320 vs 0.7710 for BCE). The accuracy gain is not huge, but this loss delivers a solid improvement with no extra machinery, and it is well worth learning and keeping in your toolbox.

Full code

A PyTorch implementation of mutil_label text classification

References

将“softmax+交叉熵”推广到多标签分类问题 — Su Jianlin (generalizing "softmax + cross entropy" to multi-label classification)

PyTorch official documentation: Loss Functions

Source: blog.csdn.net/HUSTHY/article/details/120262127