Wrapping BERT encoding as a Keras layer

BERT has swept the major NLP leaderboards. To make it easier for everyone to test BERT's power on their own datasets, I'm sharing the Keras version of the BERT encoding layer here. Straight to the code.
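A note on dependencies before the code: BertModel, BertConfig, and get_assignment_map_from_checkpoint come from modeling.py in Google's open-source BERT repository, and model_dir should point at a pretrained checkpoint directory such as chinese_L-12_H-768_A-12. The exact import path for modeling.py in the snippet below is an assumption; adjust it to match your project layout.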

import os

import tensorflow as tf
import keras  # standalone Keras on a TensorFlow 1.x backend (the TF1 APIs below require it)

# BertModel, BertConfig, and get_assignment_map_from_checkpoint live in
# modeling.py of Google's open-source BERT repo; adjust the import path
# to wherever that file sits in your project.
from modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint


class b_embeding_layer_b(keras.layers.Layer):
    """Custom layer wrapping a pretrained BERT encoder."""

    def __init__(self, max_seq_len=50,
                 model_dir=r"F:\glove.6B\chinese_L-12_H-768_A-12",
                 mode="word", **kwargs):
        self.max_seq_len = max_seq_len
        self.model_dir = model_dir
        self.mode = mode
        super(b_embeding_layer_b, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights are created here; BERT's variables are built inside
        # call() and initialized from the pretrained checkpoint.
        pass

    def call(self, x, mask=None):
        bert_config = BertConfig.from_json_file(
            os.path.join(self.model_dir, 'bert_config.json'))
        init_checkpoint = os.path.join(self.model_dir, 'bert_model.ckpt')
        model = BertModel(
            config=bert_config,
            is_training=True,
            input_ids=tf.cast(x[0], tf.int32),
            input_mask=tf.cast(x[1], tf.int32),
            token_type_ids=tf.cast(x[2], tf.int32),
            use_one_hot_embeddings=False)
        # Map the graph's trainable variables onto the checkpoint and
        # initialize them from it.
        tvars = tf.trainable_variables()
        (assignment_map, _) = get_assignment_map_from_checkpoint(
            tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        # Pick the embedding to return: [batch_size, seq_length, embedding_size]
        embedding = None
        if self.mode == "sent":
            # Caveat: this check can go wrong when the model is rebuilt from
            # JSON, because the init-time arguments are not passed then and
            # only the defaults apply.
            embedding = model.get_pooled_output()
        if self.mode == "word":
            embedding = model.get_sequence_output()

        return embedding

    def compute_output_shape(self, input_shape):
        if self.mode == "word":
            return (input_shape[0][0], self.max_seq_len, 768)
        if self.mode == "sent":
            return (input_shape[0][0], 768)

    def get_config(self):
        config = {
            'mode': self.mode,
            'max_seq_len': self.max_seq_len,
        }
        base_config = super(b_embeding_layer_b, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
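
Because the layer overrides get_config, a saved model can in principle be reloaded by passing the class through custom_objects. This is a minimal, untested sketch (the file name is just an example); note the caveat in the code comment above, and that model_dir is not serialized in get_config, so a rebuilt layer falls back to the default path.

# Hypothetical file name; pass the custom class so Keras can rebuild the layer.
model = keras.models.load_model(
    "bert_ner_model.h5",
    custom_objects={"b_embeding_layer_b": b_embeding_layer_b})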

How to call this layer is covered in the NER model I shared earlier, so I won't repeat it here.
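
For readers who don't have that post handy, here is a minimal sketch of wiring the layer into a Keras model. The input names, the tag count, and the TimeDistributed softmax head are illustrative assumptions for an NER-style setup, not taken from the original article.

max_seq_len = 50

# BERT expects three aligned int32 inputs: token ids, attention mask, segment ids.
input_ids = keras.layers.Input(shape=(max_seq_len,), dtype="int32", name="input_ids")
input_mask = keras.layers.Input(shape=(max_seq_len,), dtype="int32", name="input_mask")
segment_ids = keras.layers.Input(shape=(max_seq_len,), dtype="int32", name="segment_ids")

# mode="word" returns per-token vectors of shape [batch, max_seq_len, 768].
embedding = b_embeding_layer_b(max_seq_len=max_seq_len, mode="word")(
    [input_ids, input_mask, segment_ids])

# Hypothetical head: per-token tag prediction over 10 labels, as in NER.
outputs = keras.layers.TimeDistributed(
    keras.layers.Dense(10, activation="softmax"))(embedding)

model = keras.models.Model(inputs=[input_ids, input_mask, segment_ids], outputs=outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy")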

Origin blog.csdn.net/cyinfi/article/details/90349981