bert+textcnn做意图识别和文本分类
参考:https://github.com/wangle1218/KBQA-for-Diagnosis/blob/main/nlu/bert_intent_recognition/bert_model.py
https://www.bilibili.com/video/BV1d64y1v7Yq ##视频讲解
#! -*- coding: utf-8 -*-
from bert4keras.backend import keras,set_gelu
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
# Switch bert4keras to its 'tanh' GELU variant.
set_gelu('tanh')
def textcnn(inputs, kernel_initializer, kernel_sizes=(3, 4, 5), filters=256,
            dropout_rate=0.2):
    """Build a TextCNN feature extractor over a sequence of token features.

    For every kernel size a `Conv1D` (stride 1, 'same' padding, ReLU) is
    applied to `inputs` and reduced with `GlobalMaxPooling1D`. The pooled
    branch outputs are concatenated on the last axis and passed through
    dropout.

    Args:
        inputs: Keras tensor of per-token features, expected to be
            [batch_size, seq_len, hidden] (Conv1D requires a 3-D input).
        kernel_initializer: Initializer used for every convolution kernel.
        kernel_sizes: Convolution window widths, one branch per entry.
        filters: Number of output channels of each convolution.
        dropout_rate: Dropout rate applied to the concatenated features.

    Returns:
        Keras tensor of shape [batch_size, filters * len(kernel_sizes)].

    Raises:
        ValueError: If `kernel_sizes` is empty.
    """
    kernel_sizes = tuple(kernel_sizes)
    if not kernel_sizes:
        raise ValueError('kernel_sizes must contain at least one size')

    pooled_branches = []
    for kernel_size in kernel_sizes:
        conv = keras.layers.Conv1D(
            filters,
            kernel_size,
            strides=1,
            padding='same',
            # Every branch uses ReLU. The 5-wide branch used to omit the
            # activation and was silently linear, unlike its siblings.
            activation='relu',
            kernel_initializer=kernel_initializer,
        )(inputs)  # shape=[batch_size, seq_len, filters]
        # Max over the time axis -> shape=[batch_size, filters]
        pooled_branches.append(keras.layers.GlobalMaxPooling1D()(conv))

    if len(pooled_branches) == 1:
        # Keras concatenation needs at least two inputs.
        output = pooled_branches[0]
    else:
        output = keras.layers.concatenate(pooled_branches, axis=-1)
    output = keras.layers.Dropout(dropout_rate)(output)
    return output
def build_bert_model(config_path, checkpoint_path, class_nums):
    """Assemble a BERT + TextCNN softmax classifier.

    The transformer output is split into the first-position vector and the
    remaining positions (first and last dropped). The latter go through
    `textcnn`; both feature sets are concatenated and fed to a 512-unit
    ReLU layer followed by a `class_nums`-way softmax layer.

    Args:
        config_path: Path to the BERT config file.
        checkpoint_path: Path to the pretrained BERT checkpoint.
        class_nums: Number of output classes.

    Returns:
        A `keras.models.Model` mapping the transformer inputs to class
        probabilities.
    """
    transformer = build_transformer_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        model='bert',
        return_keras_model=False,
    )
    sequence_output = transformer.model.output

    # Vector at position 0 -> shape=[batch_size, hidden]
    cls_vector = keras.layers.Lambda(
        lambda x: x[:, 0], name='cls-token'
    )(sequence_output)

    # Everything between the first and last positions
    # -> shape=[batch_size, maxlen-2, hidden]
    # NOTE(review): the last position is only the end marker when the
    # sequence is not padded -- confirm against the data pipeline.
    inner_tokens = keras.layers.Lambda(
        lambda x: x[:, 1:-1], name='all-token'
    )(sequence_output)

    # shape=[batch_size, cnn_output_dim]
    conv_vector = textcnn(inner_tokens, transformer.initializer)

    merged = keras.layers.Concatenate(axis=-1)([cls_vector, conv_vector])

    hidden = keras.layers.Dense(
        units=512,
        activation='relu',
        kernel_initializer=transformer.initializer,
    )(merged)

    probabilities = keras.layers.Dense(
        units=class_nums,
        activation='softmax',
        kernel_initializer=transformer.initializer,
    )(hidden)

    return keras.models.Model(transformer.model.input, probabilities)
if __name__ == '__main__':
    # Smoke run: build the classifier once. The paths below are placeholders
    # and must point at a real BERT-wwm config and checkpoint.
    config_path = '***bert_wwm/bert_config.json'
    checkpoint_path = '****/bert_wwm/bert_model.ckpt'
    class_nums = 13
    build_bert_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        class_nums=class_nums,
    )
bert4keras 加载及训练各种下游模型
参考:https://github.com/bojone/bert4keras/tree/master/examples
cls_output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output) ##获取最后cls向量
all_token_output = Lambda(lambda x: x[:, 1:-1], name='ALL-token')(bert.model.output) ##获取最后除cls和最后终止符的其他向量
***拿cls句向量直接接全连接进行多分类任务:
import numpy as np
from bert4keras.backend import keras, set_gelu, K
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
set_gelu('tanh')  # switch the GELU variant

maxlen = 128
batch_size = 64
# NOTE(review): `num_classes` was used below without ever being defined.
# 13 is a placeholder -- set it to the number of labels in your dataset.
num_classes = 13
config_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/vocab.txt'

# Load the pretrained model. The checkpoint above is a BERT-wwm checkpoint,
# so the architecture must be 'bert' (it was previously 'albert').
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='bert',
    return_keras_model=False,
)

# Take the vector at position 0 (the CLS token) of the final layer.
# Layers are reached through `keras.layers` because `Lambda` and `Dense`
# are not imported by name in this script.
cls_output = keras.layers.Lambda(
    lambda x: x[:, 0], name='CLS-token'
)(bert.model.output)
# Alternative: all vectors except the first and last positions.
# all_token_output = keras.layers.Lambda(
#     lambda x: x[:, 1:-1], name='ALL-token'
# )(bert.model.output)

# Softmax classification head directly on top of the CLS vector.
output = keras.layers.Dense(
    units=num_classes,
    activation='softmax',
    kernel_initializer=bert.initializer,
)(cls_output)

model = keras.models.Model(bert.model.input, output)