NLP文本分类pytorch框架-支持Bert等预训练语言模型

基本信息

基于Pytorch的文本分类框架。

同时支持中英文的数据集的文本分类任务。

项目地址:https://github.com/wzzzd/text_classifier_pytorch

支持的模型

  • 非预训练类模型:
    • FastText
    • TextCNN
    • TextRNN
    • TextRCNN
    • Transformer
  • 预训练类模型
    • Bert
    • Albert
    • Roberta
    • Distilbert
    • Electra
    • XLNet

支持的训练方式

  • 支持中英文语料训练
    • 支持中英文的文本分类任务。
  • 支持多种模型使用
    • 配置文件Config.py中的变量model_name表示模型名称,可以更改成你想要加载的模型名称。
    • 若是预训练类的模型,如Bert等,需要同步修改变量initial_pretrain_modelinitial_pretrain_tokenizer,修改为你想要加载的预训练参数。
  • 混合精度训练
    • 用于提升训练过程效率,缩短训练时间。
    • 配置文件Config.py中的变量fp16值改为True
  • GPU多卡训练
    • 用于分布式训练,支持单机单卡、多卡训练。
    • 配置文件Config.py中的变量cuda_visible_devices用于设置可见的GPU卡号,多卡情况下用,间隔开。
  • 对抗训练
    • 在模型embedding层增加扰动,使模型学习对抗扰动,提升表现,需要额外增加训练时间。
    • 配置文件Config.py中的变量adv_option用于设置可见的对抗模式,目前支持FGM/PGD。
  • 对比学习
    • 用于增强模型语义特征提取能力,借鉴Rdrop和SimCSE的思想,目前支持KL loss和InfoNCE两种损失。
    • 配置文件Config.py中的变量cl_option设置为True则表示开启对比学习模式,cl_method用于设置计算对比损失的方法。

数据集

  • THUCNews

    • 来自:https://github.com/649453932/Chinese-Text-Classification-Pytorch
    • 关于THUCNews的的数据。
    • 数据分为10个类标签类别,分别为:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐
  • 加入自己的数据集

    • 可使用本项目的处理方式,将数据集切分为3部分:train/valid/test,其中token和label之间用制表符\t分割。
    • 在 ./dataset 目录下新建一个文件夹,并把3个数据文件放置新建文件夹下。
  • 数据集示例

    • 以数据集THUCNews为栗子,文本和标签使用空格隔开,采用以下形式存储:
        午评沪指涨0.78%逼近2800 汽车家电农业领涨	2
        卡佩罗:告诉你德国脚生猛的原因 不希望英德战踢点球	7
    

实验

说明:预训练模型基于transformers框架,如若想要替换成其他预训练参数,可以查看transformers官方网站

模型名称 MicroF1 LearningRate 预训练参数
FastText 0.8926 1e-3 -
TextCNN 0.9009 1e-3 -
TextRNN 0.9080 1e-3 -
TextRCNN 0.9142 1e-3 -
Tramsformer(2 layer) 0.8849 1e-3 -
Albert 0.9124 2e-5 voidful/albert_chinese_tiny
Distilbert 0.9209 2e-5 Geotrend/distilbert-base-zh-cased
Bert 0.9401 2e-5 bert-base-chinese
Roberta 0.9448 2e-5 hfl/chinese-roberta-wwm-ext
Electra 0.9377 2e-5 hfl/chinese-electra-base-discriminator
XLNet 0.9051 2e-5 无参数初始化

环境配置

Python使用的是3.6.X版本,其他依赖模块如下:

    numpy==1.19.2
    pandas==1.1.5
    scikit_learn==1.0.2
    torch==1.8.0
    tqdm==4.62.3
    transformers==4.15.0
    apex==0.1

除了apex需要额外安装(参考官网:https://github.com/NVIDIA/apex
),其他模块可通过以下命令安装依赖包

    pip install -r requirement.txt

如何使用项目代码

1. 训练

准备好训练数据后,终端可运行命令

    python3 main.py

2 测试评估

加载已训练好的模型,并使用valid set作模型测试,输出文件到 ./dataset/${your_dataset}/output/output.txt 目录下。

需要修改Config文件中的变量值mode = 'test',并保存。

终端可运行命令

    python3 main.py

参考

[Github:transformers] https://github.com/huggingface/transformers

[Paper:Bert] https://arxiv.org/abs/1810.04805

[Paper:RDrop] https://arxiv.org/abs/2106.14448

[Paper:SimCSE] https://arxiv.org/abs/2104.08821

猜你喜欢

转载自blog.csdn.net/Neleuska/article/details/122492386