pretrain两个任务:
论文不使用传统的从左到右或从右到左的语言模型来预训练BERT。相反,使用两个新的无监督预测任务对BERT进行预训练。
1 预测词
双向 采用MLM(mask language model)[只预测masked words而不是重建整个句子]
为了达到真正的bidirectional的LM的效果,作者创新性的提出了Masked LM,但是缺点是如果常常把一些词mask起来,未来的fine tuning过程中模型有可能没见过这些词。这个量积累下来还是很大的。因为作者在他的实现中随机选择了句子中15%的WordPiece tokens作为要mask的词。
为了解决这个问题,作者在做mask的时候,
80%的时间真的用[MASK]取代被选中的词。比如 my dog is hairy -> my dog is [MASK]
10%的时间用一个随机词取代它:my dog is hairy -> my dog is apple
10%的时间保持不变: my dog is hairy -> my dog is hairy
为什么要以一定的概率保持不变呢? 这是因为刚才说了,如果100%的时间都用[MASK]来取代被选中的词,那么在fine tuning的时候模型会有一些没见过的词。那么为啥要以一定的概率使用随机词呢?这是因为Transformer要保持对每个输入token分布式的表征,否则Transformer很可能会记住这个[MASK]就是"hairy"。至于使用随机词带来的负面影响,文章中说了,所有其他的token(即非"hairy"的token)共享15%*10% = 1.5%的概率,其影响是可以忽略不计的
2 预测下一句
与从左到右的语言模型预训练不同,MLM 目标允许表征融合左右两侧的语境,从而预训练一个深度双向 Transformer。除了遮蔽语言模型之外,本文作者还引入了一个“下一句预测”(next sentence prediction)任务,可以和MLM共同预训练文本对的表示。
3 代码
谷歌开源的bert代码以及提供的训练好的模型https://github.com/google-research/bert
对bert进行封装,提供一个输入给出对应的向量https://github.com/lbda1/bert-as-service
基于bert的实体识别https://github.com/lbda1/BERT-NER
4 bert词向量输出https://blog.csdn.net/luoyexuge/article/details/84939755
import tensorflow as tf from bert import modeling import os import collections import six from gevent import monkey monkey.patch_all() from flask import Flask, request from gevent import pywsgi import numpy as np import json flags = tf.flags FLAGS = flags.FLAGS bert_path = 'bert_model/' //bert模型 flags.DEFINE_string( "bert_config_file", os.path.join(bert_path, 'bert_config.json'), "The config json file corresponding to the pre-trained BERT model." ) flags.DEFINE_string( "bert_vocab_file", os.path.join(bert_path, 'vocab.txt'), "The config vocab file" ) flags.DEFINE_string( "init_checkpoint", os.path.join(bert_path, 'bert_model.ckpt'), "Initial checkpoint (usually from a pre-trained BERT model)." ) app = Flask(__name__) def convert_to_unicode(text): """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" if six.PY3: if isinstance(text, str): return text elif isinstance(text, bytes): return text.decode("utf-8", "ignore") else: raise ValueError("Unsupported string type: %s" % (type(text))) elif six.PY2: if isinstance(text, str): return text.decode("utf-8", "ignore") elif isinstance(text, unicode): return text else: raise ValueError("Unsupported string type: %s" % (type(text))) else: raise ValueError("Not running on Python2 or Python 3?") def load_vocab(vocab_file): vocab = collections.OrderedDict() vocab.setdefault("blank",2) index = 0 with tf.gfile.GFile(vocab_file, "r") as reader: while True: token = convert_to_unicode(reader.readline()) if not token: break token = token.strip() vocab[token] = index index += 1 return vocab di=load_vocab(vocab_file=FLAGS.bert_vocab_file) init_checkpoint=FLAGS.init_checkpoint use_tpu=False sess=tf.Session() bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) print(init_checkpoint) is_training=False use_one_hot_embeddings=False def inputs(vectors,maxlen=10): length=len(vectors) if length>=maxlen: return vectors[0:maxlen],[1]*maxlen,[0]*maxlen else: input=vectors+[0]*(maxlen-length) mask=[1]*length+[0]*(maxlen-length) segment=[0]*maxlen return input,mask,segment input_ids_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="input_ids_p") input_mask_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="input_mask_p") segment_ids_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="segment_ids_p") model = modeling.BertModel( config=bert_config, is_training=is_training, input_ids=input_ids_p, input_mask=input_mask_p, token_type_ids=segment_ids_p, use_one_hot_embeddings=use_one_hot_embeddings ) restore_saver = tf.train.Saver() restore_saver.restore(sess, init_checkpoint) @app.route('/bertvectors') def response_request(): text = request.args.get('text') vectors = [di.get("[CLS]")] + [di.get(i) if i in di else di.get("[UNK]") for i in list(text)] + [di.get("[SEP]")] input, mask, segment = inputs(vectors) input_ids = np.reshape(np.array(input), [1, -1]) input_mask = np.reshape(np.array(mask), [1, -1]) segment_ids = np.reshape(np.array(segment), [1, -1]) embedding = tf.squeeze(model.get_sequence_output()) ret=sess.run(embedding,feed_dict={"input_ids_p:0":input_ids,"input_mask_p:0":input_mask,"segment_ids_p:0":segment_ids}) return json.dumps(ret.tolist(), ensure_ascii=False) if __name__ == "__main__": server = pywsgi.WSGIServer(('0.0.0.0', 19877), app) server.serve_forever()