textcnn加载bert向量做下游分类等任务

bert+textcnn做意图识别和文本分类

参考:https://github.com/wangle1218/KBQA-for-Diagnosis/blob/main/nlu/bert_intent_recognition/bert_model.py

视频讲解：https://www.bilibili.com/video/BV1d64y1v7Yq

#! -*- coding: utf-8 -*-
from bert4keras.backend import keras,set_gelu
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam

set_gelu('tanh')  # select bert4keras's 'tanh' GELU version; done at import time, before any model below is built

def textcnn(inputs, kernel_initializer, kernel_sizes=(3, 4, 5), filters=256, dropout_rate=0.2):
	"""Apply a multi-width TextCNN feature extractor to a token sequence.

	For every width in ``kernel_sizes`` a ``Conv1D`` layer is built with
	``filters`` output channels, stride 1, ``'same'`` padding and ReLU.
	It is applied to ``inputs`` and reduced with ``GlobalMaxPooling1D``.
	The pooled vectors are concatenated on the last axis and passed
	through ``Dropout(dropout_rate)``.

	Args:
		inputs: Keras tensor of token embeddings; presumably
			(batch, seq_len, hidden) -- TODO confirm against the caller.
		kernel_initializer: initializer used for every Conv1D kernel.
		kernel_sizes: convolution window widths, one branch per width.
			Defaults to the classic (3, 4, 5).
		filters: number of output channels of each branch.
		dropout_rate: dropout rate applied to the concatenated features.

	Returns:
		Keras tensor of shape (batch, filters * len(kernel_sizes)).

	Raises:
		ValueError: if ``kernel_sizes`` is empty.
	"""
	kernel_sizes = tuple(kernel_sizes)
	if not kernel_sizes:
		raise ValueError('kernel_sizes must contain at least one width')

	pooled = []
	for kernel_size in kernel_sizes:
		conv = keras.layers.Conv1D(
			filters,
			kernel_size,
			strides=1,
			padding='same',  # output keeps the input sequence length
			# ReLU on every branch; previously the width-5 branch had no
			# activation, which made it linear unlike its siblings.
			activation='relu',
			kernel_initializer=kernel_initializer
		)(inputs)
		pooled.append(keras.layers.GlobalMaxPooling1D()(conv))  # shape=[batch_size,filters]

	# Keras' Concatenate layer rejects fewer than two inputs, so a single
	# branch is used directly.
	if len(pooled) == 1:
		output = pooled[0]
	else:
		output = keras.layers.concatenate(pooled, axis=-1)
	output = keras.layers.Dropout(dropout_rate)(output)
	return output

def build_bert_model(config_path, checkpoint_path, class_nums):
	"""Build a BERT + TextCNN softmax classifier.

	The last BERT layer is split two ways:
	- its first position (the [CLS] vector);
	- every position except the first and the last, which goes through
	  ``textcnn``.
	Both feature sets are concatenated and fed to a 512-unit ReLU Dense
	layer, then to a ``class_nums``-way softmax Dense layer.

	Args:
		config_path: path to the BERT ``bert_config.json``.
		checkpoint_path: path to the BERT checkpoint.
		class_nums: number of output classes.

	Returns:
		keras.models.Model mapping the BERT inputs to class probabilities.
	"""
	transformer = build_transformer_model(
		config_path=config_path,
		checkpoint_path=checkpoint_path,
		model='bert',
		return_keras_model=False,
	)
	sequence_output = transformer.model.output  # presumably (batch, seq_len, hidden)
	init = transformer.initializer

	# First position: the [CLS] sentence vector.
	take_cls = keras.layers.Lambda(lambda t: t[:, 0], name='cls-token')
	# Drop the first and last positions. NOTE(review): if inputs are
	# right-padded, position -1 is [SEP] only for the longest sequence in
	# the batch; shorter ones keep their [SEP] and padding here -- confirm.
	take_tokens = keras.layers.Lambda(lambda t: t[:, 1:-1], name='all-token')

	cls_vector = take_cls(sequence_output)
	token_vectors = take_tokens(sequence_output)

	features = keras.layers.concatenate(
		[cls_vector, textcnn(token_vectors, init)],
		axis=-1,
	)

	hidden = keras.layers.Dense(
		units=512,
		activation='relu',
		kernel_initializer=init,
	)(features)
	probabilities = keras.layers.Dense(
		units=class_nums,
		activation='softmax',
		kernel_initializer=init,
	)(hidden)

	return keras.models.Model(transformer.model.input, probabilities)

if __name__ == '__main__':
	# Placeholder paths: point these at a local BERT-wwm checkpoint before running.
	config_path='***bert_wwm/bert_config.json'
	checkpoint_path='****/bert_wwm/bert_model.ckpt'
	class_nums=13
	model = build_bert_model(config_path, checkpoint_path, class_nums)
	# Previously the built model was discarded; print its architecture so
	# the smoke run actually shows what was constructed.
	model.summary()

bert4keras 加载及训练各种下游模型

参考:https://github.com/bojone/bert4keras/tree/master/examples

cls_output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)  # 获取最后一层的 [CLS] 向量

all_token_output = Lambda(lambda x: x[:, 1:-1], name='ALL-token')(bert.model.output)  # 获取最后一层去掉第一个位置（[CLS]）和最后一个位置后的其余 token 向量（无 padding 时最后一个位置即 [SEP]）
（原文此处为插图，图片未保留）

示例：拿 [CLS] 句向量直接接全连接层做多分类任务：

import numpy as np
from bert4keras.backend import keras, set_gelu, K
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam

set_gelu('tanh')  # switch bert4keras to the 'tanh' GELU version

maxlen = 128
batch_size = 64
num_classes = 13  # number of target classes -- previously undefined (NameError); adjust to your dataset
config_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/vocab.txt'


# Load the pre-trained model. chinese_wwm_L-12_H-768_A-12 is a BERT
# (whole-word-masking) checkpoint, so it must be built as 'bert';
# 'albert' expects a different architecture and variable layout.
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='bert',
    return_keras_model=False,
)

# [CLS] vector of the last layer. Lambda/Dense were never imported here,
# so use them through keras.layers (as the TextCNN example does).
cls_output = keras.layers.Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)

# Alternative: every position except the first and the last of the last layer.
# all_token_output = keras.layers.Lambda(lambda x: x[:, 1:-1], name='ALL-token')(bert.model.output)

# Softmax classifier directly on top of the [CLS] vector.
output = keras.layers.Dense(
    units=num_classes,
    activation='softmax',
    kernel_initializer=bert.initializer
)(cls_output)

model = keras.models.Model(bert.model.input, output)

Guess you like

Origin blog.csdn.net/weixin_42357472/article/details/120892238