Train a text classifier with BERT and export it in .pb (SavedModel) format

The full code is on GitHub:
https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Train.py

def serving_input_fn(max_seq_length=128, data_dir=None):
    """Build the serving input receiver used to export the model as a SavedModel.

    The exported model takes raw feature tensors as input, not serialized
    ``tf.Example`` protos. To accept ``tf.Example`` inputs instead, use
    ``tf.estimator.export.build_parsing_serving_input_receiver_fn``.

    Args:
        max_seq_length: Second dimension of ``input_ids``, ``input_mask`` and
            ``segment_ids``. Defaults to 128, the value previously hard-coded.
            It should match the sequence length the model was trained with.
        data_dir: Path to the tab-separated data file (no header; columns are
            label, then text). It is read only to count the distinct labels,
            which sizes ``label_ids``. Defaults to ``FLAGS.data_dir``.

    Returns:
        The result of calling the receiver fn built by
        ``tf.estimator.export.build_raw_serving_input_receiver_fn``
        (a ``ServingInputReceiver``).
    """
    if data_dir is None:
        data_dir = FLAGS.data_dir

    # Only the label column is needed to size label_ids, so skip parsing the text.
    labels = pd.read_csv(data_dir, delimiter="\t", names=['labels', 'text'],
                         header=None, usecols=['labels'])['labels']
    # dropna=False keeps this equal to len(labels.unique()), which counts NaN too.
    dense_units = labels.nunique(dropna=False)

    # NOTE(review): label_ids is part of the serving signature even though labels
    # are normally not available when serving. Presumably the model_fn reads
    # features['label_ids'] unconditionally, so clients must send a dummy int32
    # array of shape [batch, dense_units]. Confirm against model_fn before
    # removing it.
    label_ids = tf.placeholder(tf.int32, [None, dense_units], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='segment_ids')

    # NOTE(review): build_raw_serving_input_receiver_fn treats these tensors only
    # as shape/dtype templates and creates its own placeholders when the returned
    # fn is called. Because that call happens in the same graph, the exported
    # input nodes may get uniquified names (e.g. 'input_ids_1'). Check the
    # exported signature before relying on tensor names when freezing to .pb.
    receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })
    return receiver_fn()

You may also like

Source: blog.csdn.net/qq236237606/article/details/107078973