Pytorch使用Google BERT模型进行中文文本分类

版权声明:王家林大咖2018年新书《SPARK大数据商业实战三部曲》清华大学出版,微信公众号:从零起步学习人工智能 https://blog.csdn.net/duan_zhihua/article/details/85770837

Pytorch使用Google BERT模型进行中文文本分类

在前一篇博客中https://blog.csdn.net/duan_zhihua/article/details/85121310,我们已经实现了Tensorflow使用CNN卷积神经网络以及RNN(Lstm、Gru)循环神经网络进行中文文本分类。

BERT(Bidirectional Encoder Representations from Transformers)是Google AI团队最新开源的NLP模型,正如家林大咖所言:这是2018年人工智能领域最重要的事件!对于技术人员而言,这是整个人工智能领域接下来五年最重要的机遇!

因此,本文将使用Bert模型应用于订单类型识别案例,体验一下Bert模型的强大 !

参考的链接:

Google官方BERT代码(Tensorflow):https://github.com/google-research/bert
huggingface/pytorch-pretrained-BERT:https://github.com/huggingface/pytorch-pretrained-BERT
                                                                https://github.com/real-brilliant/bert_chinese_pytorch

https://zhuanlan.zhihu.com/p/48203943?utm_source=wechat_timeline&utm_medium=social&utm_oi=977137150768259072&wechatShare=1&from=timeline&isappinstalled=0

本文参考的基线代码:https://github.com/duanzhihua/pytorch-pretrained-BERT

                                   https://github.com/duanzhihua/bert_chinese_pytorch

一:读取数据

在参数中设置任务名称:缺省值为MyPro。
            parser.add_argument("--task_name",
                        default = 'MyPro',
                        type = str,
                        #required = True,
MyPro类自定义数据读取方法,MyPro子类继承父类DataProcessor,重载复写了get_train_examples、get_dev_examples、get_test_examples、get_labels方法,_read_csv为父类DataProcessor的文件读取方法,将读取的文本信息、分类放到字典中。子类_create_examples方法遍历字典数据,将数据封装成InputExample(guid=guid, text_a=text_a, label=label))
InputExample类中的text_b字段可选(适用于问答成对的任务),label字段可选,适用于训练集、验证集,测试集不需要label。

读入的训练集数据放入到train_examples,train_examples是一个列表,每一个元素是已经封装的InputExample,如下:

二:模型准备,加载Bert已经训练好的模型,模型配置参数如下:

01/03/2019 10:12:24 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}
  • param_optimizer是Bert模型的优化参数列表,大小为200,列表中每一个元素是一个元组,其第一个元素是参数名,如

'bert.embeddings.token_type_embeddings.weight';第二个元素为参数值。

  • optimizer_grouped_parameters模型参数进行分组,参数名中包括['bias', 'gamma', 'beta']的分一组,其他的参数分成一组。
  •   设置Bert模型优化器。
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=t_total)

三,训练集文本数据特征提取,采用convert_examples_to_features方法提取文本特征:

 examples  : [List] 输入样本,包括question, label, index
 label_list    : [List] 所有可能的类别,可以是int、str等,如['单订单', '多订单']
 max_seq_length: [int] 文本最大长度
 tokenizer     : [Method] 分词方法

函数返回结果:

input_ids  : [ListOf] token的id,在chinese模式中就是每个分词的id,对应一个word vector
input_mask : [ListOfInt]  句子中的字符对应1,补全的字符对应0
segment_ids: [ListOfInt] 句子标识符,第一句全为0,第二句全为1
label_id   : [ListOfInt] 将Label_list转化为相应的id表示

在本案例的应用:

segment_ids是句子标识符,本文是单句文本分类,因此全为0,segment_ids:<class 'list'>: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

句子单词的向量化:每个句子截取序列最大长度-2的单词,前面补充一个单词[CLS],末尾补充一个单词[SEP]:
如: ['[CLS]', '人', '工', '受', '理', '收', '费', '模', '板', '大', '备', '注', '甩', '单', '|', '|', '拆', '机', '|', '|', '拆', '[SEP]']
将每个单词从词汇表中查询到相应的编号放入input_ids列表,如:
input_ids : [101, 782, 2339, 1358, 4415, 3119, 6589, 3563, 3352, 1920, 1906, 3800, 4501, 1296, 170, 170, 2858, 3322, 170, 170, 2858, 102]                        
input_mask:单词掩码,在句子里面的单词掩码为1,如给最大句子长度凑数而填充的单词掩码为0
label_id:标签的ID。
show_exp=True 可以打开调试开关,打印前5条数据的信息。

四,模型训练之Pytorch数据集加载。
1,将input_ids , input_mask,segment_ids,label_id封装成Pytorch的torch.tensor。
2,将input_ids , input_mask,segment_ids,label_id已封装的torch.tensor元组,封装成TensorDataset类的实例train_data,TensorDataset类继承至Dataset类。
3,train_sampler将train_data进行随机打乱,将train_data数据集、随机采样器、批处理大小封装到train_dataloader训练集。

五,模型训练:BertForSequenceClassification 

1,设置模型为训练模式。 
2,循环遍历train_dataloader的数据,将每一批次的数据input_ids, segment_ids, input_mask, label_ids传入Bert模型,
在__call__方法调用pytorch_pretrained_bert中modeling模型的forward方法前向推导计算loss值。
loss = model(input_ids, segment_ids, input_mask, label_ids)
3,loss.backward()反向推导,optimizer.step()更新权重。 

六:模型验证:pred = logits.max(1)[1]  # 取预测最大值的分类索引。将预测分类与实际分类计算metrics.f1_score得分。

        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)         
            pred = logits.max(1)[1]
            predict = np.hstack((predict, pred.cpu().numpy()))
            gt = np.hstack((gt, label_ids.cpu().numpy()))

  模型测试:test(model, processor, args, label_list, tokenizer, device)
 

七:模型预测:写一个预测代码,用于测试集预测以后,将预测结果与甩单号进行关联。

结果如下:

测试时发现,在测试数据集小的情况下,Pytorch Bert optimization.py代码还需进行一个小改动, 加上 group['t_total'] !=0的判断:

               if group['t_total'] != -1 and group['t_total'] !=0:				
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
      

Pytorch Bert 下载的模型约373M,如图:

其中的部分单词表:

[PAD]
[unused1]
[unused2]
[unused3]
[unused4]
[unused5]
[unused6]
[unused7]
[unused8]
[unused9]
[unused10]
[unused11]
[unused12]
[unused13]
[unused14]
[unused15]
[unused16]
[unused17]
[unused18]
[unused19]
[unused20]
[unused21]
[unused22]
[unused23]
[unused24]
[unused25]
[unused26]
[unused27]
[unused28]
[unused29]
[unused30]
[unused31]
[unused32]
[unused33]
[unused34]
[unused35]
[unused36]
[unused37]
[unused38]
[unused39]
[unused40]
[unused41]
[unused42]
[unused43]
[unused44]
[unused45]
[unused46]
[unused47]
[unused48]
[unused49]
[unused50]
[unused51]
[unused52]
[unused53]
[unused54]
[unused55]
[unused56]
[unused57]
[unused58]
[unused59]
[unused60]
[unused61]
[unused62]
[unused63]
[unused64]
[unused65]
[unused66]
[unused67]
[unused68]
[unused69]
[unused70]
[unused71]
[unused72]
[unused73]
[unused74]
[unused75]
[unused76]
[unused77]
[unused78]
[unused79]
[unused80]
[unused81]
[unused82]
[unused83]
[unused84]
[unused85]
[unused86]
[unused87]
[unused88]
[unused89]
[unused90]
[unused91]
[unused92]
[unused93]
[unused94]
[unused95]
[unused96]
[unused97]
[unused98]
[unused99]
[UNK]
[CLS]
[SEP]
[MASK]
<S>
<T>
!
"
#
$
%
&
'
(
)
*
+
,
-
.......

本文在Windows 环境进行了小量数据的测试。Bert模型沿用网友Geek Fly的测试结果,如下:

猜你喜欢

转载自blog.csdn.net/duan_zhihua/article/details/85770837