BERT详解(4)---fine-tuning

目录

1. fine-tuning

BERT本质上是一个两段式的NLP模型。第一个阶段叫做:Pre-training,跟WordEmbedding类似,利用现有无标记的语料训练一个语言模型。第二个阶段叫做:Fine-tuning,利用预训练好的语言模型,完成具体的NLP下游任务。pre-training的训练成本很大,一般直接使用google训练好的模型,而fine-tuning成本相对较少,本文介绍如何进行fine-tuning,对应的程序为run_classifier.py,如下所示,从主函数开始

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  #   新增类,用于处理数据,加载训练数据
  processors = {                              # 【2】新增类
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "mypro": MyProcessor
  }

  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  # 读取task_name, task_name实际上就是选择或自定义的processor,若自定义需要加入processors字典中
  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))
  # 获取数据处理方法
  processor = processors[task_name]()
  # 获取标签
  label_list = processor.get_labels()

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
      ....
      ....

从程序中可以看出需要自定义一个类来处理原始数据,用于训练, 并将该类加入processors字典当中,数据处理类可以参考已经存在的类。此处我自定义一个MyProcessor类来加载原始数据, 代码如下

class MyProcessor(DataProcessor):
  """Processor for the MRPC data set (GLUE version)."""
  def __init__(self):
  # 【0】  设置语言
      self.language = "zh"                                            

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1", "2", "3", "4", "5", "6", "7", "8"]     # 【1】设置标签label

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])         # 【2】text_a/text_b根据实际语料结构更改
      text_b = None
      if set_type == "test":
        label = "-1"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

该类中主要定义几个函数分别进行获取训练(测试、验证)样本(get_train_examples),获取标签值(get_labels),同时也能发现,在data_dir中我们需要将数据处理成.tsv格式,训练集、开发集和测试集分别是train.tsv, dev.tsv, test.tsv,这里我们暂时只使用train.tsv和dev.tsv。另外,label在get_labels()设定,如果是二分类,则将label设定为[“0”,”1”],同时_create_examples()中,给定了如何获取guid以及如何给text_a, text_b和label赋值。

2. 总结

对于这个fine-tuning过程,我们要做的只是:

  • 准备好一个12G显存左右的GPU,没有也不用担心,可以使用谷歌免费的GPU
  • 准备好train.tsv, dev.tsv以及test.tsv
  • 新建一个跟自己task_name对应的processor,用于将train.tsv、dev.tsv以及test.tsv中的数据提取出来赋给text_a, text_b, label
  • 下载好Pre-training模型,设定好相关参数,run就完事了

.tsv文件格式如下:标签+table+句子

0 			i love china
1			a beautiful day
0 			i love shanghai

更详细介绍fine_tuning过程及运行代码见orangerfun github

3 参考

BERT源码

发布了33 篇原创文章 · 获赞 1 · 访问量 2596

猜你喜欢

转载自blog.csdn.net/orangerfun/article/details/104623610