bert 文本多分类

参考:https://github.com/xmxoxo/BERT-train2deploy

1. 准备

  1. 下载bert源码
    https://github.com/google-research/bert.git

  2. 下载bert预训练模型,本文使用中文预训练模型
    https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

  • bert_model.ckpt:负责模型变量载入
  • vocab.txt:字典
  • bert_config.json:bert训练时的可调参数
    在这里插入图片描述

2. 语料

建议采用 \t 进行分割,第一列为标签,第二列为文本,训练集保存为train.tsv,测试集保存为test.tsv,测试集可以仅有1列,无标签。如下图语料有3个标签。将 train.tsv、val.tsv、test.tsv放入同一目录下。

0	光度是光作用于人眼所引起的明亮程度的感觉。
1	黄樟素属于易制毒化学品。
2	在佛教中,对僧人的称呼一般有

3. 建立bert多分类模型

bert多分类仅需修改 run_classifier.py 文件中的代码

  1. 在main()中添加多分类任务 classifytask,该任务类 ClassifyProcessor 继承 DataProcessor,与其它的cola、mnli等任务类似
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
    
    
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "classifytask": ClassifyProcessor,     # 自定义的多分类任务
  }
  1. 编写ClassifyProcessor类
class ClassifyProcessor(DataProcessor):
    # 与训练集中定义的标签名一致
    def __init__(self):
        self.labels = [0, 1, 2]
    # 加载训练集
    def get_train_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
    # 加载验证集
    def get_dev_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "val.tsv")), "val")
    # 加载测试集
    def get_test_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
    # 获取标签
    def get_labels(self):
        return self.labels
    # 读取数据
    def _create_examples(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            if set_type == "test":
                # 测试集没有标签,因此可以设置一个默认值
	            text_a = tokenization.convert_to_unicode(line[0])
	            label = 0
	        else:
	            text_a = tokenization.convert_to_unicode(line[1])
	            label = tokenization.convert_to_unicode(line[0])
	        examples.append(
	            InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
	
	        return examples

4. 模型训练

新建脚本 start.sh,输入下列内容。

export DATA_DIR=           # 语料路径
export BERT_BASE_DIR=      # 预训练模型路径

python run_classifier.py \
 --task_name=classifytask \
 --do_train=true \         # 是否进行fine tune
 --do_eval=true \          # 是否进行evaluation
 --data_dir=$DATA_DIR/ \
 --vocab_file=$BERT_BASE_DIR/vocab.txt \
 --bert_config_file=$BERT_BASE_DIR/bert_config.json \
 --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
 --max_seq_length=128 \    # 句子的最长长度
 --train_batch_size=32 \
 --learning_rate=2e-5 \
 --num_train_epochs=3.0 \
 --output_dir=$BERT_BASE_DIR/output      # 输出目录

运行sh命令,bert分类任务开始执行,建议采用gpu进行训练,否则训练过程可能在Saving checkpoints for 0 … 处卡住。

sh start.sh

在这里插入图片描述
在这里插入图片描述

5. 预测

训练完成后得到的模型文件如图,后续的预测任务即可调用该模型进行,通常会采用ckpt模型生成pb模型文件进行调用。
使用下列命令进行预测,将init_checkpoint定义为训练生成的模型即可。

export BERT_BASE_DIR=chinese_L-12_H-768_A-12
export NER_DIR=dat
export OUTPUT=output
python run_mobile.py \
          --task_name=setiment \
          --do_predict=true \
          --data_dir=$NER_DIR/ \
          --vocab_file=$BERT_BASE_DIR/vocab.txt \
          --bert_config_file=$BERT_BASE_DIR/bert_config.json \
          --init_checkpoint=$OUTPUT/model.ckpt-455 \
          --max_seq_length=128 \
          --output_dir=$OUTPUT/

在这里插入图片描述
输出的预测文件如图,一行为一条记录对应每个类别的概率,将概率进行转换即可得到分类标签。本次训练的准确率约为85%左右。
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/shlhhy/article/details/107382079