BERT Explained (4) --- Fine-tuning

Table of Contents

1. Fine-tuning

BERT is essentially a two-stage NLP model. The first stage, called pre-training, is similar to word embedding: an unlabeled corpus is used to train a language model. The second stage, called fine-tuning, uses the pre-trained language model to accomplish a specific downstream NLP task. Pre-training is expensive, so people typically use the models Google has already trained, while fine-tuning is comparatively cheap. This article describes how to fine-tune; the corresponding program is run_classifier.py. We start from the main function:

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Newly added class for processing data and loading the training data
  processors = {                              # [2] newly added class
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "mypro": MyProcessor
  }

  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  # Read task_name; task_name is just the selected (or custom) processor. A custom processor must be added to the processors dict.
  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))
  # Get the data-processing class
  processor = processors[task_name]()
  # Get the labels
  label_list = processor.get_labels()

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
      ....
      ....

As the program shows, we need to define a custom class to process the raw training data and add it to the processors dictionary; the existing classes can serve as templates. Here I define a custom class, MyProcessor, to load the raw data, as in the following code:

class MyProcessor(DataProcessor):
  """Processor for the MRPC data set (GLUE version)."""
  def __init__(self):
    # [0] set the language
    self.language = "zh"

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1", "2", "3", "4", "5", "6", "7", "8"]     # 【1】设置标签label

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])         # [2] change text_a/text_b to match your corpus structure
      text_b = None
      if set_type == "test":
        label = "-1"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

This class defines several main functions: get_train_examples (and its dev/test counterparts) loads the training (validation, test) examples, and get_labels returns the label values. As the code shows, the data in data_dir must be prepared in .tsv format; the training, development, and test sets are train.tsv, dev.tsv, and test.tsv (here we only use train.tsv and dev.tsv for now). The labels are set in get_labels; for binary classification they would be ["0", "1"]. Finally, _create_examples specifies how the guid is built and how text_a, text_b, and label are assigned.
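To make the mapping concrete, here is a minimal standalone sketch of what _create_examples produces for one parsed line of train.tsv. InputExample is stubbed here so the snippet runs on its own; in run_classifier.py it is already defined with this interface:

class InputExample(object):
  # Stub of run_classifier.py's InputExample, for illustration only
  def __init__(self, guid, text_a, text_b=None, label=None):
    self.guid = guid      # unique id, e.g. "train-0"
    self.text_a = text_a  # the sentence to classify
    self.text_b = text_b  # None for single-sentence tasks
    self.label = label    # class label as a string

line = ["1", "a beautiful day"]   # [label, sentence], one parsed row of train.tsv
example = InputExample(guid="train-0", text_a=line[1], text_b=None, label=line[0])
print(example.guid, example.label, example.text_a)   # train-0 1 a beautiful day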

2. Summary

For this fine-tuning process, all we have to do is:

  • Prepare a GPU with roughly 12 GB of memory; if you don't have one, don't worry, you can use Google's free GPU
  • Prepare train.tsv, dev.tsv, and test.tsv
  • Create a new task_name and a corresponding processor that extracts the data from train.tsv, dev.tsv, and test.tsv and assigns it to text_a, text_b, and label
  • Download a pre-trained model, set the parameters, and run (see the example command below)
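As a sketch of that last step, a typical invocation modeled on the BERT repository's README looks like the following; the paths are placeholders and the hyperparameters are the README defaults, so adjust both to your setup:

export BERT_BASE_DIR=/path/to/chinese_L-12_H-768_A-12   # placeholder: downloaded pre-trained model

python run_classifier.py \
  --task_name=mypro \
  --do_train=true \
  --do_eval=true \
  --data_dir=./data \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=./output

Note that task_name=mypro matches the "mypro" key we registered in the processors dictionary.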

The .tsv file format is: label + tab + sentence, for example:

0 			i love china
1			a beautiful day
0 			i love shanghai
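If your raw data lives in Python, a minimal sketch like the following writes a file in this layout; the rows here are just the example above, and train.tsv lands in the current directory. The base DataProcessor._read_tsv reads tab-separated rows, so this matches what MyProcessor expects:

import csv

# Illustrative (label, sentence) pairs matching the layout above
rows = [("0", "i love china"),
        ("1", "a beautiful day"),
        ("0", "i love shanghai")]

with open("train.tsv", "w") as f:
  writer = csv.writer(f, delimiter="\t")   # tab between label and sentence
  for label, sentence in rows:
    writer.writerow([label, sentence])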

For more details on the fine-tuning process and the runnable code, see orangerfun's GitHub.

3. References

BERT source code: https://github.com/google-research/bert

Origin: blog.csdn.net/orangerfun/article/details/104623610