XLNet代码详解之预训练

XLNet源码地址https://github.com/zihangdai/xlnet

XLNET原理详解https://blog.csdn.net/weixin_37947156/article/details/93035607

XLNET中文预训练https://github.com/ymcui/Chinese-PreTrained-XLNet

原作地址http://fancyerii.github.io/2019/06/30/xlnet-codes/

目录

运行预处理数据

data_utils.py代码分析

第一部分:加载sentence-piece模型

第二部分:处理每一个文件的过程

第三部分:拼接前的预处理和拼接

第四部分:create_tfrecords函数

train_gpu.py代码

第一部分:train

第二部分:data_utils.get_input_fn

get_input_fn函数主要代码

get_dataset函数

_local_perm函数

perm_mask

第三部分:single_core_graph

get_model_fn

function_builder.get_loss

two_stream_loss

XLNetConfig

RunConfig

XLNetModel

模型简称 语料 Google下载 讯飞云下载
XLNet-mid, Chinese 中文维基+
通用数据[1]
TensorFlow 
PyTorch
TensorFlow(密码f5ux) 
PyTorch(密码vnnt)
XLNet-base, Chinese 中文维基+
通用数据[1]
暂未开放 暂未开放
  • XLNet代码结构

           

  • gpu_utils.py:ckpt模型加载、设置gpu、平均梯度和损失
  • prepro_utils.py:文本预处理、sentence-piece模型转换文本为WordPiece对应的ID
  • squard_utils.py:额外的统计数据,并绘制了额外PR曲线
  • model_utils.py:加载预训练模型、带L2权重衰减的Adam优化器
  • data_utils.py:预处理数据
  • classifier_utils.py:定义样本特征数据、序列对截断到最大长度、将单个“InputExample”转换为单个“InputFeatures”
  • modeling.py:transformer-xl、tow stream attention实现
  • xlnet.py:更高层的XLNetModel类,封装modeling的transformer
  • function_builder.py:各种用于pretrain和finetune的loss function
  • train_xxx.py:XLNet预训练
  • run_xxx.py:XLNet精调和评估

总体的依赖顺序就是:

  • Pretrain: train_xxx.py -> function_builder.py -> modeling.py -> xxx_utils.py
  • Finetune: run_xxx.py -> function_builder.py -> modeling.py -> xxx_utils.py

运行预处理数据

XLNET提供脚本来预处理数据,运行它:

python data_utils.py
--bsz_per_host=8
--num_core_per_host=1
--seq_len=128
--reuse_len=64
--input_glob=pretrain.txt
--save_dir=traindata
--num_passes=20
--bi_data=True
--sp_path=spiece.model
--mask_alpha=6
--mask_beta=1
--num_predict=21

解释这些参数表示的意义:

  • bsz_per_host: 每个host的batch大小,这里是8。因为它是多个TPU同时训练,所以可能有多个host,我们这里只有一个host。
  • num_core_per_host: 每个host的TPU的个数,我这里用CPU,只能是1。
  • seq_len: 序列长度,这里使用128
  • reuse_len:可作为内存重用的token数量,一般为seq_len的一半,这里是64
  • input_glob :输入的训练数据,可以用*这样的通配符
  • save_dir :输出目录
  • num_passes: 生成多少趟(因为随机排列,所以每次都不同)
  • bi_data :是否双向的batch,参考前面的理论部分
  • sp_path: sentencepiece的模型,模型下载后自带了一个
  • mask_alpha:一组有多少个token
  • mask_beta:每组中要屏蔽多少个token
  • num_predict : 预测多少个token

sp_path是sentencepiece的模型,如果是自己的数据,可以使用spm_train工具来训练自己的WordPiece模型。

具体参见:

1、sentencePiece 分词原理学习

2、自然语言处理之_SentencePiece分词

data_utils.py代码分析

我们首先来看怎么生成训练数据的。它的main函数会调用create_data()函数,这个函数会调用_create_data来创建Pretraining的数据。这个函数的核心代码为:

def _create_data(idx, input_paths):
  ## 第一部分:加载sentence-piece模型
  sp = spm.SentencePieceProcessor()
  sp.Load(FLAGS.sp_path)

  input_shards = []
  total_line_cnt = 0
  ## 第二部分:处理每一个文件的过程
  for input_path in input_paths:
    input_data, sent_ids = [], []
    sent_id, line_cnt = True, 0
    tf.logging.info("Processing %s", input_path)
    for line in tf.gfile.Open(input_path):
      if line_cnt % 100000 == 0:
        tf.logging.info("Loading line %d", line_cnt)
      line_cnt += 1

      if not line.strip():
        if FLAGS.use_eod:
          sent_id = not sent_id
          cur_sent = [EOD_ID]
        else:
          continue
      else:
        if FLAGS.from_raw_text:
          cur_sent = preprocess_text(line.strip(), lower=FLAGS.uncased)
          cur_sent = encode_ids(sp, cur_sent)
        else:
          cur_sent = list(map(int, line.strip().split()))

      input_data.extend(cur_sent)
      sent_ids.extend([sent_id] * len(cur_sent))
      sent_id = not sent_id

    tf.logging.info("Finish with line %d", line_cnt)
    if line_cnt == 0:
      continue

    input_data = np.array(input_data, dtype=np.int64)
    sent_ids = np.array(sent_ids, dtype=np.bool)

    total_line_cnt += line_cnt
    input_shards.append((input_data, sent_ids))

  tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)

  tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")

  filenames, num_batch = [], 0

  # 随机打乱输入碎片(使用固定但不同的随机种子)
  np.random.seed(100 * FLAGS.task + FLAGS.pass_id)

  perm_indices = np.random.permutation(len(input_shards))
  tf.logging.info("Using perm indices %s for pass %d",
                  perm_indices.tolist(), FLAGS.pass_id)

  input_data_list, sent_ids_list = [], []
  prev_sent_id = None
  ## 第三部分:拼接前的预处理和拼接
  # 把不同文件的数据拼成一个大的向量前的预处理
  # 主要是处理sent_ids
  for perm_idx in perm_indices:
    input_data, sent_ids = input_shards[perm_idx]
    # make sure the `send_ids[0] == not prev_sent_id`
    if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
      sent_ids = np.logical_not(sent_ids)

    # append to temporary list
    input_data_list.append(input_data)
    sent_ids_list.append(sent_ids)

    # update `prev_sent_id`
    prev_sent_id = sent_ids[-1]

  # 最终得到一个大的向量
  input_data = np.concatenate(input_data_list)
  sent_ids = np.concatenate(sent_ids_list)
  ## 第四部分:调用create_tfrecords函数
  file_name, cur_num_batch = create_tfrecords(
      save_dir=tfrecord_dir,
      basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
      data=[input_data, sent_ids],
      bsz_per_host=FLAGS.bsz_per_host,
      seq_len=FLAGS.seq_len,
      bi_data=FLAGS.bi_data,
      sp=sp,
  )

  filenames.append(file_name)
  num_batch += cur_num_batch

  record_info = {
      "filenames": filenames,
      "num_batch": num_batch
  }

  return record_info

分解为如下几个部分(见代码注释):

  • 第一部分:加载sentence-piece模型
  • 第二部分:处理每一个文件的过程

  • 第三部分:拼接前的预处理和拼接

  • 第四部分:调用create_tfrecords函数

第一部分:加载sentence-piece模型

前两行代码加载sentencepiece 。

第二部分:处理每一个文件的过程

这个过程读取每一个文件的每一行,然后使用sp切分成WordPiece,然后变成id,放到数组input_data里。另外还有一个sent_ids,用来表示句子。对于每一个文件(我们这里只有一个),最终是为了得到”input_data, sent_ids = [], []”两个list。

input_data里是放到这个文件的每一个WordPiece对应的ID,而sent_ids用于判断句子的边界。见下面的例子

input_data=[19  5372 19317  6014   111 17634  2611    19  1084   111   420   152
    25  6096    26  8888]
sent_ids=[ True  True  True  True  True  True  True False False False False False
 False False False False]

因为第一个句子是”我的空调漏水怎么办”,使用sp切分后变成”['▁', '我的', '空调', '漏', '水', '怎么', '办']”,最后变成ID得到[19, 5372, 19317, 6014, 111, 17634, 2611]。而sent_ids是[ True  True  True  True  True  True  True]。

因此input_data是每一个WordPiece对应的ID的数组,而sent_ids可以判断哪些ID是属于一个句子的,也就是sent_ids通过交替的True和False来告诉我们句子的边界,比如前面的sent_ids的前7个为True,因此我们可以知道前7个WordPiece属于第一个句子,而后面的9个连续False告诉我们第二个句子有9个WordPiece。那么如果第三个句子有6个WordPiece,则我们可以猜测后面应该出现连续6个True。

WordPiece的一种主要的实现方式叫做BPE(Byte-Pair Encoding)双字节编码。BPE的过程可以理解为把一个单词再拆分,使得我们的此表会变得精简,并且寓意更加清晰。比如"loved","loving","loves"这三个单词。其实本身的语义都是“爱”的意思,但是如果我们以单词为单位,那它们就算不一样的词,在英语中不同后缀的词非常的多,就会使得词表变的很大,训练速度变慢,训练的效果也不是太好。BPE算法通过训练,能够把上面的3个单词拆分成"lov","ed","ing","es"几部分,这样可以把词的本身的意思和时态分开,有效的减少了词表的数量。

WordPiece或者BPE这么好,我们是不是哪里都能这么用呢?其实在我们的中文中不是很适用。首先我们的中文不像英文或者其他欧洲的语言一样通过空格分开,我们是连续的。其次我们的中文一个字就是一个最小的单元,无法在拆分的更小了。在中文中一般的处理方式是两中,分词和分字。理论上分词要比分字好,因为分词更加细致,语义分的更加开。分字简单,效率高,词表也很小,常用字就3000左右。

如果一个WordPiece以”▁”开始,则表明它是一个词的开始,而不以”▁”开始的表明它是接着前面的。

此外上面的代码还有处理空行,用于表示一个新的Document的开始(取决于选项FLAGS.use_eod),则会加一个特殊的Token EOD_ID。而段落的结束是使用表示,下面是一些特殊的符号及其ID:

special_symbols = {
    "<unk>"  : 0,
    "<s>"    : 1,
    "</s>"   : 2,
    "<cls>"  : 3,
    "<sep>"  : 4,
    "<pad>"  : 5,
    "<mask>" : 6,
    "<eod>"  : 7,
    "<eop>"  : 8,
}

第三部分:拼接前的预处理和拼接

通过前面的代码,我们可以把每一个文件都变成一个(input_data, sent_ids)pair,放到input_shards这个list里。但是我们还需要把不同文件的(input_data, sent_ids)拼接成更大的一个(input_data, sent_ids)。input_data可以直接拼接,但是sent_ids不行,为什么呢?我们假设第一个文件有3个句子,因此它的sent_ids类似[True,True,False,False,True,True],而第二个文件是两个句子,[True,False],那么直接拼起来就变成[True,True,False,False,True,True,True,False],拼接后本来应该是5个句子,但是现在变成了4个!

因为第一个文件是True结尾,但是第二个是True开始,因此我们需要把第二个文件的True和False反过来,这就是预处理的代码,关键的代码都有注释:

for perm_idx in perm_indices:
    input_data, sent_ids = input_shards[perm_idx]
    # 如果上一个文件的最后的sent_id和这个文件的开始的sent_id相同
    # 那么就得把当前这个文件的sent_id反过来
    if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
      sent_ids = np.logical_not(sent_ids)
 
    # append到临时的list
    input_data_list.append(input_data)
    sent_ids_list.append(sent_ids)
 
    # 更新 `prev_sent_id`
    prev_sent_id = sent_ids[-1]

最后拼接成两个大的向量:

    input_data = np.concatenate(input_data_list)
    sent_ids = np.concatenate(sent_ids_list)

第四部分:create_tfrecords函数

def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
                     bi_data, sp):
  data, sent_ids = data[0], data[1]

  num_core = FLAGS.num_core_per_host
  bsz_per_core = bsz_per_host // num_core

  if bi_data:
    assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
    fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)

    fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
    fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)

    bwd_data = fwd_data[:, :, :, ::-1]
    bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]

    data = np.concatenate(
        [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
    sent_ids = np.concatenate(
        [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
  else:
    data, sent_ids = batchify(data, bsz_per_host, sent_ids)

  tf.logging.info("Raw data shape %s.", data.shape)

  file_name = format_filename(
      prefix=basename,
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="tfrecords",
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      reuse_len=FLAGS.reuse_len,
      uncased=FLAGS.uncased,
      fixed_num_predict=FLAGS.num_predict
  )
  save_path = os.path.join(save_dir, file_name)
  record_writer = tf.python_io.TFRecordWriter(save_path)
  tf.logging.info("Start writing %s.", save_path)

  num_batch = 0
  reuse_len = FLAGS.reuse_len

  # [sep] x 2 + [cls]
  assert reuse_len < seq_len - 3

  data_len = data.shape[1]
  sep_array = np.array([SEP_ID], dtype=np.int64)
  cls_array = np.array([CLS_ID], dtype=np.int64)

  i = 0
  while i + seq_len <= data_len:
    if num_batch % 500 == 0:
      tf.logging.info("Processing batch %d", num_batch)

    all_ok = True
    features = []
    for idx in range(bsz_per_host):
      inp = data[idx, i: i + reuse_len]
      tgt = data[idx, i + 1: i + reuse_len + 1]

      results = _split_a_and_b(
          data[idx],
          sent_ids[idx],
          begin_idx=i + reuse_len,
          tot_len=seq_len - reuse_len - 3,
          extend_target=True)
      if results is None:
        tf.logging.info("Break out with seq idx %d", i)
        all_ok = False
        break

      # unpack the results
      (a_data, b_data, label, _, a_target, b_target) = tuple(results)

      # sample ngram spans to predict
      reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
      if FLAGS.num_predict is None:
        num_predict_0 = num_predict_1 = None
      else:
        num_predict_1 = FLAGS.num_predict // 2
        num_predict_0 = FLAGS.num_predict - num_predict_1
      mask_0 = _sample_mask(sp, inp, reverse=reverse,
                            goal_num_predict=num_predict_0)
      mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
                                                sep_array, cls_array]),
                            reverse=reverse, goal_num_predict=num_predict_1)

      # concatenate data
      cat_data = np.concatenate([inp, a_data, sep_array, b_data,
                                 sep_array, cls_array])
      seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
                [1] * b_data.shape[0] + [1] + [2])
      assert cat_data.shape[0] == seq_len
      assert mask_0.shape[0] == seq_len // 2
      assert mask_1.shape[0] == seq_len // 2

      # the last two CLS's are not used, just for padding purposes
      tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
      assert tgt.shape[0] == seq_len

      is_masked = np.concatenate([mask_0, mask_1], 0)
      if FLAGS.num_predict is not None:
        assert np.sum(is_masked) == FLAGS.num_predict

      feature = {
          "input": _int64_feature(cat_data),
          "is_masked": _int64_feature(is_masked),
          "target": _int64_feature(tgt),
          "seg_id": _int64_feature(seg_id),
          "label": _int64_feature([label]),
      }
      features.append(feature)

    if all_ok:
      assert len(features) == bsz_per_host
      for feature in features:
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        record_writer.write(example.SerializeToString())
      num_batch += 1
    else:
      break

    i += reuse_len

  record_writer.close()
  tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)

  return save_path, num_batch

首先看前面部分的代码:

 data, sent_ids = data[0], data[1]
 
  num_core = FLAGS.num_core_per_host
  bsz_per_core = bsz_per_host // num_core
 
  if bi_data:
    assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
    fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)
 
    fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
    fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)
 
    bwd_data = fwd_data[:, :, :, ::-1]
    bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]
 
    data = np.concatenate(
        [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
    sent_ids = np.concatenate(
        [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
  else:
    data, sent_ids = batchify(data, bsz_per_host, sent_ids)
 
  tf.logging.info("Raw data shape %s.", data.shape)

在阅读这部分代码前我们先来了解它的作用。这个函数的前面部分的作用是整个语料库(一个很长的data和对应sent_ids)分成batch。比如假设data为:

1 2 3 4 .... 1001

并且batch为8,bi_data为True(两个方向),则上面的代码首先把1001个数据分成8/2=4个部分,不能整除的扔掉,因此变成:

1 2 ... 250
251 252 ... 500
501 502 ... 750
751 752 ... 1000

然后加上反过来的数据:


250 ... 2 1
500 ... 252 251
750 ... 502 501
1000 ... 752 751

最终变成:


1 2 ... 250
251 252 ... 500
501 502 ... 750
751 752 ... 1000
250 ... 2 1
500 ... 252 251
750 ... 502 501
1000 ... 752 751

它主要会用到batchify函数为:


def batchify(data, bsz_per_host, sent_ids=None):
  num_step = len(data) // bsz_per_host
  data = data[:bsz_per_host * num_step]
  data = data.reshape(bsz_per_host, num_step)
  if sent_ids is not None:
    sent_ids = sent_ids[:bsz_per_host * num_step]
    sent_ids = sent_ids.reshape(bsz_per_host, num_step)
 
  if sent_ids is not None:
    return data, sent_ids
  return data

我们假设输入data是[3239,],并且bsz_per_host为4,则每个batch得到3239//4=3236/4=809个steps。3239去掉不能整除的最后3个,就是3236个ID。然后把它resahpe成(4, 809),sent_ids也是类似的操作。

生成Pretraining的数据:

A和B有两种关系,第一种它们是连续的上下文;第二种B是随机在data中选择的句子。

接下来是一个大的for循环:

  while i + seq_len <= data_len:
    ....
    i += reuse_len

上面的大while循环就是每次移动64(reuse_len),首先固定64个作为cache。然后从i+reuse_len位置开始不断寻找句子,直到这些句子的Token数大于61(128-64-3)。比如:

64  65-75 76-90 91-128

上面的例子找到3个句子,这三个句子的Token数大于61了。然后以50%的概率选择如下两种方案生成A和B:

  • A和B是连续的,因此从3个句子里随机的选择前面一部分作为A,剩下的作为B。比如有可能前两个句子是A,后一个是B。

  • A和B不连续,因此这3个句子随机选一部分作为A,比如前两个句子,接着随机的从整个data里寻找一部分作为B。

当然上面只是大致的思路,细节很多:比如这三个句子的长度超过61了,那么需要从A或者B里删除一部分;比如随机的从data里选择B,很可能B是句子的中间,那么需要向前后两个方向”扩充”B(当然同时要从A的尾部删除相应的个数的Token)。这里就不介绍了,读者知道它的作用后阅读代码就会比较容易了。

接下来就是对这128个Token进行”Mask”了,这是通过_sample_mask函数实现的。它首先对前64个memory进行Mask,然后对后面64个也进行Mask。_sample_mask的代码比较细节,我这里只介绍它的大致思路。

  • 首先随机选择n-gram的n,n的范围是[1,5],这里假设n为2
  • 然后计算上下文 “ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta” 这里为2*6=12
  • 然后随机的ctx_size(12)切分成l_ctx和r_ctx,假设为5和7
  • 然后下标后移5(l_ctx),因为后移5之后可能不是一个词,因此持续后移找到n-gram开始的位置
  • 寻找n-gram开始的位置寻找n个词(n个词可能多于n个Token)
  • 然后从n-gram介绍的地方后移7(r_ctx)个位置,并且持续后移直到遇到词的开始(以”▁”开始的Token)

这样就找到了一个被Mask的n-gram以及它的左右(大致)l_ctx和r_ctx个Token。如果Mask的Token到达我们的预期(goal_num_predict)就退出,否则从结束的下标开始持续这个过程。最终我们需要得到的数据是feature,下面是一个feature的示例值:

input: [   52    27    18    89  3833     9    52    27    18   205  3833    21
    77    18   239    20    18 11636     9     8   245 11636     9     7
   245  2402  3091   193     9     7    52    27    18    89  3833     9
    52    27    18   205  3833    21    77    18   239    20    18 11636
     9     8   245 11636     9     7   245  2402  3091   193     9     7
    52    27    18    89  3833     9    52    27    18   205  3833    21
    77    18   239    20    18 11636     9     8   245 11636     9     7
   245  2402  3091   193     9     7    52    27    18    89  3833     9
    52    27    18   205  3833    21    77    18   239    20    18 11636
     9     8   245 11636     9     7   245  2402  3091   193     9     4
    52    27    18    89  3833     9     4     3]
 
tgt: [   27    18    89  3833     9    52    27    18   205  3833    21    77
    18   239    20    18 11636     9     8   245 11636     9     7   245
  2402  3091   193     9     7    52    27    18    89  3833     9    52
    27    18   205  3833    21    77    18   239    20    18 11636     9
     8   245 11636     9     7   245  2402  3091   193     9     7    52
    27    18    89  3833     9    52    27    18   205  3833    21    77
    18   239    20    18 11636     9     8   245 11636     9     7   245
  2402  3091   193     9     7    52    27    18    89  3833     9    52
    27    18   205  3833    21    77    18   239    20    18 11636     9
     8   245 11636     9     7   245  2402  3091   193     9     7    52
    27    18    89  3833     9    52     3     3]
 
is_masked: [False False False False False  True  True False False False False False
 False False False False False  True  True  True  True  True False  True
  True False False False False False False False False False False False
  True False False False False False False False False False False False
 False False  True False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False  True False False False False False False
 False False False False False False  True  True  True  True  True False
 False False False False False False False False False False  True  True
 False False False False False False False False False False False False
 False  True  True False False False False False]
 
seg_id: [0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2]
 
label: 1

这些变量的含义是:

  • input:长度为128的输入,前64个是mem,后面64个是A和B(加上2个SEP一个CLS)
  • tgt:长度128,除了最后两个是CLS,前面126是input对应的下一个目标值
  • label:1表示A和B是连续的句子
  • seg_id:表示输入input的segment,mem+A+SEP是0,B+SEP是1,最后一个CLS是2
  • is_masked:表示这128个里哪些位置是Mask的

最终这5个变量都会作为features放到一个tf.train.Example写到TFRecord文件里。

train_gpu.py代码

python train_gpu.py 
   --record_info_dir=tfrecords        #指向包含“record_info-lm.json”的本地目录的路径
   --train_batch_size=8 
   --seq_len=128 
   --reuse_len=64  
   --mem_len=96                       #缓冲的step数量                          
   --perm_size=32 
   --n_layer=6 
   --d_model=1024 
   --d_embed=1024 
   --n_head=16 
   --d_head=64 
   --d_inner=4096                     #Dimension of inner hidden size in positionwise feed-forward
   --untie_r=True                     #Untie r_w_bias and r_r_bias
   --mask_alpha=6 
   --mask_beta=1 
   --num_predict=21 
   --model_dir=mymodel
   --uncased=true 
   --num_core_per_host=1

第一部分:train

训练主要是调用函数train,它的主要代码为:

def train(ps_device):
  ##### 得到input function和model function

  train_input_fn, record_info_dict = data_utils.get_input_fn(
      tfrecord_dir=FLAGS.record_info_dir,
      split="train",
      bsz_per_host=FLAGS.train_batch_size,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=1,
      num_core_per_host=1, # set to one no matter how many GPUs
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict)

  # for key, info in record_info_dict.items():
  tf.logging.info("num of batches {}".format(record_info_dict["num_batch"]))

  ##### 创建输入tensors(placeholder)
  # 训练的batch_size平均分配到多个设备(GPU)上
  bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

  params = {
      "batch_size": FLAGS.train_batch_size # the whole batch
  }
  # 调用前面的input函数得到Dataset
  train_set = train_input_fn(params)
  # 得到一个example
  example = train_set.make_one_shot_iterator().get_next()

  if FLAGS.num_core_per_host > 1:
    examples = [{} for _ in range(FLAGS.num_core_per_host)]
    for key in example.keys():
      vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
      for device_id in range(FLAGS.num_core_per_host):
        examples[device_id][key] = vals[device_id]
  else:
    examples = [example]

  ##### 创建计算图的代码
  tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

  for i in range(FLAGS.num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      # The mems for each tower is a dictionary
      mems_i = {}
      if FLAGS.mem_len:
        # 创建mems,这是Transformer-XL里的思想,保留之前的96个Token的上下文
        mems_i["mems"] = create_mems_tf(bsz_per_core)
        
      # 这是创建计算图的核心代码,我们后面会详细分析
      # loss_i是loss;new_mems_i是个list,表示每一层的新的96个Token的隐状态
      # grads_and_vars_i是梯度和变量,因为需要多个GPU数据并行,
      # 所以需要自己来平均所有GPU的梯度然后用它更新参数(而不能用Optimizer)
      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          is_training=True,
          features=examples[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## 如果有多个GPU,那么需要求出平均的梯度,我们这里不分析 
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]

  ## 根据grads_and_vars得到训练的Opertaion,此外还会返回learning_rate和gnorm
  train_op, learning_rate, gnorm = model_utils.get_train_op(FLAGS, None,
      grads_and_vars=grads_and_vars)
  global_step = tf.train.get_global_step()

  ##### 训练的循环
  # 首先初始化mems
  tower_mems_np = []
  for i in range(FLAGS.num_core_per_host):
    mems_i_np = {}
    for key in tower_mems[i].keys():
      mems_i_np[key] = initialize_mems_np(bsz_per_core)
    tower_mems_np.append(mems_i_np)

  # 保存checkpoint的Saver
  saver = tf.train.Saver()
  # GPUOptions,允许它动态的使用更多GPU内存
  gpu_options = tf.GPUOptions(allow_growth=True)
  # 从之前的checkpoint恢复参数
  model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
      gpu_options=gpu_options)) as sess:
    # 允许初始化所有变量的操作
    sess.run(tf.global_variables_initializer())
    # 需要运行的ops:包括loss,新的mem,global_step,gnorm,learning_rate和train_op
    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    total_loss, prev_step = 0., -1
    while True:
      feed_dict = {}
      # 用上一个数据的96个mem作为当前的mem
      for i in range(FLAGS.num_core_per_host):
        for key in tower_mems_np[i].keys():
          for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
            feed_dict[m] = m_np
      # 进行一次随机梯度下降
      fetched = sess.run(fetches, feed_dict=feed_dict)
      # 拿到loss;新的mems和当前的steps
      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            curr_step, fetched[-3], fetched[-2],
            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Model saved in path: {}".format(save_path))

      if curr_step >= FLAGS.train_steps:
        break

如果忽略多设备(GPU)训练的细节,train的代码结构其实并不复杂,它大致可以分为3部分:

  • 调用data_utils.get_input_fn得到train_input_fn

  • 调用single_core_graph构造XLNet网络

  • 使用session运行fetches进行训练

我们在计算当前句子时会用到之前的96个Token作为上下文,因此需要创建mems:

def create_mems_tf(bsz_per_core):
  mems = [tf.placeholder(dtype=tf.float32,
                         shape=[FLAGS.mem_len, bsz_per_core, FLAGS.d_model])
          for layer in range(FLAGS.n_layer)]

  return mems

在此它是创建了长度为n_layer的list,每一个元素的shape是[96, 8, 1024],96是mem的长度,8是batch_size,1024是隐单元个数。

第二部分:data_utils.get_input_fn

接下面我们详细的来介绍train函数处理Dataset输入以及构建计算图的代码。

get_input_fn函数主要代码

这个函数的返回值是Estimator的一个函数和一个record_info对象,这个record_info对象记录读取的TFRecord的num_batch和file_names等数据;而第一个函数执行后的返回值是一个tf.data.Dataset,这是Tensorflow标准的输入方式。

返回值知道了,我们再来看这个函数的参数:

  • tfrecord_dir 训练数据目录,如果有多个用逗号”,”分开,这里是’traindata/tfrecords’,参考第一部分。
  • split 这里是”train”表示训练数据。
  • bsz_per_host 每个host(机器)的batch大小,这里是8。
  • seq_len 句子长度,这里是128。
  • reuse_len 句子中前面不变的部分,这里是64,请参考第一部分。
  • bi_data 是否双向的模型,这里是True。
  • num_hosts=1 有多少台机器,这里是1。
  • num_core_per_host=1 每天机器有多少个设备(GPU),这里是1。
  • perm_size=None, 排列打散的长度,这里是32,后面会讲到。
  • mask_alpha=None, 这里是6
  • mask_beta=None, 这里是1
  • uncased=False, 是否不区分大小,这里是True,表示不区分大小写
  • num_passes=None, 训练的趟数,这里是1表示这是第一趟训练
  • use_bfloat16=False, 是否使用16位表示的浮点数,这里是False
  • num_predict=None: 用于预测的Token的个数,这里是21
def get_input_fn(
    tfrecord_dir,
    split,
    bsz_per_host,
    seq_len,
    reuse_len,
    bi_data,
    num_hosts=1,
    num_core_per_host=1,
    perm_size=None,
    mask_alpha=None,
    mask_beta=None,
    uncased=False,
    num_passes=None,
    use_bfloat16=False,
    num_predict=None):
  # 把所有的record info合并成一个
  record_glob_base = format_filename(
      prefix="record_info-{}-*".format(split),
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="json",
      mask_alpha=mask_alpha,
      mask_beta=mask_beta,
      reuse_len=reuse_len,
      uncased=uncased,
      fixed_num_predict=num_predict)

  record_info = {"num_batch": 0, "filenames": []}

  tfrecord_dirs = tfrecord_dir.split(",")
  tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
  
  # 变量目录,这些目录是用逗号分开的 
  for idx, record_dir in enumerate(tfrecord_dirs):
    record_glob = os.path.join(record_dir, record_glob_base)
    tf.logging.info("[%d] Record glob: %s", idx, record_glob)

    record_paths = sorted(tf.gfile.Glob(record_glob))
    tf.logging.info("[%d] Num of record info path: %d",
                    idx, len(record_paths))

    cur_record_info = {"num_batch": 0, "filenames": []}
    # 遍历record_info-train-*.bsz-8.seqlen-128.reuse-64.uncased.bi.alpha-6.beta-1.fnp-21.json
    # 能匹配的所有json文件,这里只匹配了:
    # record_info-train-0-0.bsz-8.seqlen-128.reuse-64.uncased.bi.alpha-6.beta-1.fnp-21.json
    for record_info_path in record_paths:
      if num_passes is not None:
        # record_info_path为record_info-train-0-0.bsz-8.seqlen-128.....
        # record_info-train-0-0代表这是pretraining的训练数据,最后一个0代表这是它的pass_id
        # 如果pass_id >= num_passes,那么就会跳过,比如某个数据是record_info-train-0-3,
        # 则如果当前训练的趟数num_passes是1,2,3,都会跳过,只有当num_passes大于3时才会训练这个文件
        record_info_name = os.path.basename(record_info_path)
        fields = record_info_name.split(".")[0].split("-")
        pass_id = int(fields[-1])
        if len(fields) == 5 and pass_id >= num_passes:
          tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
          continue

      with tf.gfile.Open(record_info_path, "r") as fp:
        # 打开这个json文件
        info = json.load(fp)
        if num_passes is not None:
          eff_num_passes = min(num_passes, len(info["filenames"]))
          ratio = eff_num_passes / len(info["filenames"])
          cur_record_info["num_batch"] += int(info["num_batch"] * ratio)
          cur_record_info["filenames"] += info["filenames"][:eff_num_passes]
        else:
          cur_record_info["num_batch"] += info["num_batch"]
          cur_record_info["filenames"] += info["filenames"]

    # 根据json文件找到对应的tfrecord文件,放到`cur_record_info`
    new_filenames = []
    for filename in cur_record_info["filenames"]:
      basename = os.path.basename(filename)
      new_filename = os.path.join(record_dir, basename)
      new_filenames.append(new_filename)
    cur_record_info["filenames"] = new_filenames

    tf.logging.info("[Dir %d] Number of chosen batches: %s",
                    idx, cur_record_info["num_batch"])
    tf.logging.info("[Dir %d] Number of chosen files: %s",
                    idx, len(cur_record_info["filenames"]))
    tf.logging.info(cur_record_info["filenames"])

    # 把`cur_record_info`放到全局的`record_info`里
    record_info["num_batch"] += cur_record_info["num_batch"]
    record_info["filenames"] += cur_record_info["filenames"]

  tf.logging.info("Total number of batches: %d",
                  record_info["num_batch"])
  tf.logging.info("Total number of files: %d",
                  len(record_info["filenames"]))
  tf.logging.info(record_info["filenames"])

  # 返回的input function,调用这个函数后就会得到Dataset对象。
  def input_fn(params):
    # 一个host(机器)的batch大小 = 每个设备(GPU)的batch大小 * 设备的个数
    assert params["batch_size"] * num_core_per_host == bsz_per_host
    # 使用get_datset函数构造Dataset对象,后面我们会详细介绍这个函数
    dataset = get_dataset(
        params=params,
        num_hosts=num_hosts,
        num_core_per_host=num_core_per_host,
        split=split,
        file_names=record_info["filenames"],
        num_batch=record_info["num_batch"],
        seq_len=seq_len,
        reuse_len=reuse_len,
        perm_size=perm_size,
        mask_alpha=mask_alpha,
        mask_beta=mask_beta,
        use_bfloat16=use_bfloat16,
        num_predict=num_predict)

    return dataset

  return input_fn, record_info

这个函数主要的代码就是遍历目录下的json文件,然后找到对应的tfrecord文件,放到record_info里,同时返回input_fn函数,如果执行input_fn函数,则它会调用get_dataset函数返回Dataset,下面我们来看get_dataset函数。

get_dataset函数

def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
                num_batch, seq_len, reuse_len, perm_size, mask_alpha,
                mask_beta, use_bfloat16=False, num_predict=None):

  bsz_per_core = params["batch_size"]
  # 我们这里值考虑一台服务器(host_id=0)的情况
  if num_hosts > 1:
    host_id = params["context"].current_host
  else:
    host_id = 0

  #### parse tfrecord的函数
  def parser(record):
    # 这个可以参考第一部分生成数据的内容
    # input是输入的句子的id,长度为seq_len=128
    # target是预测的id序列,长度也是seq_len
    # seg_id 64个reused的Token,A和A后的SEP是0,B和B后的SEP是1,最后的CLS是2
    # label 如果两个句子是连续的,那么是1,否则是0
    # is_masked表示某个Token是否是masked(用于预测)
    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # parse
    example = tf.parse_single_example(
        serialized=record,
        features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len
    
    # 先处理前64(reuse),后面会详细介绍这个函数
    # 现在我们知道它的输入是inputs的前64个,target的前64个,is_masked的前64个
    # perm_size是32,reuse_len是64。
    # 它返回的是:
    # 1. perm_mask,64x64,表示经过重新排列后第i个token能否attend to 第j个token,1表示不能attend
    # 2. target,64,表示真实的目标值,之前生成的target是预测下一个词,但是XLNet是预测当前词
    # 3. target_mask,64,哪些地方是Mask的(需要预测的)
    # 4. input_k, 64,content stream的初始值
    # 5. input_q, 64, 哪些位置是需要计算loss的,如果不计算loss,也就不计算Query Stream。
    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
        inputs[:reuse_len],
        target[:reuse_len],
        is_masked[:reuse_len],
        perm_size,
        reuse_len)

    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
        inputs[reuse_len:],
        target[reuse_len:],
        is_masked[reuse_len:],
        perm_size,
        non_reuse_len)
    # tf.ones(reuse_len, non_reuse_len)表示前64个reuse的不能attend to 后面64个(1表示不能attend)
    # concat起来就变成(reuse_len=64, 128),perm_mask_0(i,j)表示i能不能attend to j(128)
    perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
                            axis=1)
    # tf.zeros(non_reuse_len, reuse_len)表示后面的64个可以attend to 前面的64(reuse_len)个
    # concat变成(64,128),perm_mask_0(i,j)表示i(后面的64个)能不能attent to j(128)
    perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
                            axis=1)
    # 把perm_mask_0和perm_mask_1 concat起来变成(128,128)
    # perm_mask(i,j)表示i(128)能不能attend to j(128)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
    # target也concat变成(128,)
    target = tf.concat([target_0, target_1], axis=0)
    # target_mask也concat成(128,)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
    # input_k也concat
    input_k = tf.concat([input_k_0, input_k_1], axis=0)
    # input_q也concat
    input_q = tf.concat([input_q_0, input_q_1], axis=0)

    if num_predict is not None:
      # indices是[0,1,...,127]
      indices = tf.range(seq_len, dtype=tf.int64)
      # target_mask中1表示MASK的值,这里把它变成boolean
      bool_target_mask = tf.cast(target_mask, tf.bool)
      # 找到Mask对应的下标,比如MASK的Token的下标是2和3,那么bool_target_mask=[F,F,T,T,...]
      # tf.boolean_mask函数返回indices里为True的值,因此返回[2,3]
      indices = tf.boolean_mask(indices, bool_target_mask)

      # 因为随机抽样的MASK可能是CLS/SEP,这些是不会被作为预测值的,因此
      # 我们之前生成的数据有num_predict(21)个需要预测的,但实际需要预测的只有actual_num_predict
      # 所以还需要padding num_predict - actual_num_predict个。
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      # target_mapping
      # 假设indices=[2,3]的话
      # 则target_mapping就变成 [[0,0,1,0,....],[0,0,0,1,....]]
      # 也就是把2和3用one-hot的方法来表示
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      # padding的部分也表示成向量,但是它是"zero-hot"的表示。
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      # concat成[num_redict, seq_len]的向量
      # target_mapping(i,j) = 1 表示第i个要预测的Token的目标(真实)值是j
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      # 其实不reshape也是可以的。因为除非代码有bug,否则target_mapping就是[num_redict, seq_len]
      # reshape的好处是显式的说明target_mapping的shape,调试方便一点。
      # 读者可能会问,pad_len = num_predict - actual_num_predict,然后我又
      # 把padding的和actual_num_predict的concat起来,为什么TF不知道它的shape呢?
      # 因为TF是一种静态图,pad_len只是一个Operation,还没有执行,TF并不知道它的值。
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      # 拿到target的Token ID
      target = tf.boolean_mask(target, bool_target_mask)
      # 同样需要padding
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      # 长度为21(num_predict)的向量,1表示是真正需要预测的Token;0表示是padding的,是不计算loss的
      target_mask = tf.concat(
          [tf.ones([actual_num_predict], dtype=tf.float32),
           tf.zeros([pad_len], dtype=tf.float32)],
          axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for k, v in example.items():
      tf.logging.info("%s: %s", k, v)

    return example

  # Get dataset
  dataset = parse_files_to_dataset(
      parser=parser,
      file_names=file_names,
      split=split,
      num_batch=num_batch,
      num_hosts=num_hosts,
      host_id=host_id,
      num_core_per_host=num_core_per_host,
      bsz_per_core=bsz_per_core)

  return dataset

主要的代码都在parser函数里,它定义了怎么读取我们在第一部分代码里生成的tfrecord文件,然后进行排列打散(我们还没有具体介绍的_local_perm函数),最后把处理的结果放到example里。接着使用parse_files_to_dataset来得到Dataset,我们来看这个函数。

def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
                           host_id, num_core_per_host, bsz_per_core):
  # list of file pathes
  num_files = len(file_names)
  num_files_per_host = num_files // num_hosts
  my_start_file_id = host_id * num_files_per_host
  my_end_file_id = (host_id + 1) * num_files_per_host
  if host_id == num_hosts - 1:
    my_end_file_id = num_files
  file_paths = file_names[my_start_file_id: my_end_file_id]
  tf.logging.info("Host %d handles %d files", host_id, len(file_paths))

  assert split == "train"
  dataset = tf.data.Dataset.from_tensor_slices(file_paths)

  # 文件级别的shuffle
  if len(file_paths) > 1:
    dataset = dataset.shuffle(len(file_paths))

  # 注意:这里我们不能对每一个sample进行打散,这样会破坏句子的Token的顺序。
  dataset = tf.data.TFRecordDataset(dataset)

  # (zihang): 因为我们是online的随机排列,因此每次session.run的时候排列的顺序都是不同的,
  # 所以cache是没有作用的,它会导致OOM。因此我们只cache parser之前的数据(cache函数在map(parser)之前)
  # map(parser)就是对每一个tfrecord使用前面介绍的parser函数来处理
  dataset = dataset.cache().map(parser).repeat()
  dataset = dataset.batch(bsz_per_core, drop_remainder=True)
  dataset = dataset.prefetch(num_core_per_host * bsz_per_core)

  return dataset

_local_perm函数

这个函数比较难懂,这里介绍一个小技巧。如果读者对比过PyTorch和Tensorflow就会感觉使用PyTorch调试会简单很多,原因就是PyTorch是动态的计算图,因此我们把两个Tensor一相加,结果马上就出来了;但是Tensoflow是静态图,我们定义了两个Tensor相加后得到的不是它们的计算结果,而是一个Operation,我们还要用session来run它才能得到结果。如果计算简单还好,一个复杂的的函数经过一系列变换后就完全不知道它的shape是什么了。因此调试PyTorch的代码就像调试普通的Python代码一样;而调试Tensorflow的代码就像”阅读”Pyton代码一样——你看不到执行的结果。不过还有Tensorflow引入了Eager Execution。但是很多代码(包括这里的XLNet)还是习惯用之前的静态构造方法。不过没有关系,如果我们遇到一个函数看不到,那么我们可以对这个函数进行Eager Execution来调试它。比如我们如果看不到_local_perm,则我们可以这样来调试它:

import tensorflow as tf
# 开启Eager Execution
tf.enable_eager_execution()
seq_len = 16
reuse_len = 8
perm_size = 8
inputs=tf.constant([10,13,15,20,21,22,4,16,33,34,35,36,37,38,4,3])
targets=tf.constant([13,15,20,21,22,4,16,33,34,35,36,37,38,10,3,3])
is_masked=tf.constant([False, False, False, False, True, True,False,
                   False, False,False, False, False,
                   True, True, False, False])
_local_perm(inputs,targets, is_masked, perm_size, seq_len)

其中seq_len表示输入的长度,perm_size表示排列打散的长度。inputs表示输入,targets表示输出,其中4和3是特殊的SEP和CLS。is_masked表示某个输入是否是MASK(用于预测的),上面的例子里第5和6个位置、第13和14个位置都是Masked,也就是需要模型预测计算loss的位置。

接下来我们就可以增加断点单步执行从而来了解这个函数的作用了。

def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    随机的采样一种排列方式,然后创建对应的attention mask 

    参数:
      inputs: int64 Tensor shape是[seq_len],输入的id
      targets: int64 Tensor shape是[seq_len],目标值的id
      is_masked: bool Tensor shape是[seq_len],True代表用于预测
      perm_size: 最长排列的长度,具体含义参考下面的代码
      seq_len: int, 序列长度

#生成随机的排列

# 随机生成一个下标的排列
index = tf.range(seq_len, dtype=tf.int64)
index = tf.transpose(tf.reshape(index, [-1, perm_size]))
index = tf.random_shuffle(index)
index = tf.reshape(tf.transpose(index), [-1])

根据上面的输入,首先用tf.range生成[0, 15]的序列,然后第二行代码首先把它reshape成[2, 8],然后transpose成[8, 2],从而得到:

0 8
1 9
2 10
3 11
4 12
5 13
6 14
7 15

然后使用tf.random_shuffle对第一列进行随机打散,得到:

4 12
6 14
7 15
2 10
3 11
5 13
0  8
1  9

最后transpose然后reshape成长度为16的向量:

[4  6  7  2  3  5  0  1 12 14 15 10 11 13  8  9]

总结一下代码的效果:把长度为seq_len(=16)的向量分成seq_len/perm_size(=2)段,每段进行随机打散。

#得到特殊Token

1.首先是得到non_func_tokens,所谓的non_func_tokens是指SEP和CLS之外的”正常”的Token。

non_func_tokens = tf.logical_not(tf.logical_or(
      tf.equal(inputs, data_utils.SEP_ID),
      tf.equal(inputs, data_utils.CLS_ID)))
# 计算后的non_func_tokens为:

[True  True  True  True  True  True False  True  True  True  True  True True  True False False]

根据前面的输入inputs,第7个位置为SEP,15和16为SEP和CLS,因此这三个位置为False,其余都为True。

2.然后是non_mask_tokens,non_mask_tokens指的是”正常”的并且没有被Mask的Token。

non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
# non_mask_tokens的值为:
# [True  True  True  True False False False  True  True  True  True  True False False False False]

因此如果一个Token是non_mask_tokens,那么它首先是”正常”的Token(non_func_tokens为True),然后还得没有被Mask(is_masked为False)。

3.masked_or_func_tokens,它和non_mask_tokens相反,包括Masked的Token和SEP与CLS。

masked_or_func_tokens = tf.logical_not(non_mask_tokens)
# [False False False False True True True False False False False False True True True True]

#计算rev_index

# 把非Mask(也不是CLS和SEP)的Token的排列下标设置为最小的-1,这样:
# (1) 它们可以被所有其它的位置看到 
# (2) 它们看不到Masked位置,从而不会有信息泄露
smallest_index = -tf.ones([seq_len], dtype=tf.int64)
rev_index = tf.where(non_mask_tokens, smallest_index, index)
# [-1 -1 -1 -1  3  5  0 -1 -1 -1 -1 -1 11 13  8  9]

tf.where函数的作用等价于:

for i in range(16):
    if non_mask_tokens[i]:
        rev_index[i]=smallest_index[i]
    else:
        smallest_index[i]=index[i]

这样得到的rev_index为:如果某个位置是非Mask的,则其值为-1;反正如果某个位置是Mask(5,6,13,14)或者为特殊的SEP/CLS(7,15,16),则值为前面随机生成的下标。

#target_mask

# 创建`target_mask`: 它是"普通"的并且被Masked的Token,它的值代表:
# 1: 值为1代表使用mask作为输入并且计算loss
# 0: 使用token(或者SEP/CLS)作为输入并且不计算loss
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
target_mask = tf.cast(target_tokens, tf.float32)
# [0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0.]

我们看到,is_masked为True的下标对应的target_mask为1,其余为0。

perm_mask

# `target_tokens` 不能看到自己
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
# [0  0  0  0  3  5  1  0  0  0  0  0 11 13  9 10]

因为target_tokens不能看到自己,因此它的值就是rev_index,而其余的加1(如果原来是-1,那么变成0,否则就是排列下标加一)。

# 1: 如果i <= j并且j不是非masked(masked或者特殊的SEP/CLS)则不能attend,因此值为1
# 0: 如果i > j或者j非masked,则为0
perm_mask = tf.logical_and(
	self_rev_index[:, None] <= rev_index[None, :],
	masked_or_func_tokens)
perm_mask = tf.cast(perm_mask, tf.float32)
# 值为:
[[0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0.]]

上面的代码”self_rev_index[:, None] <= rev_index[None, :]”我们来详细分析一下。

首先self_rev_index是一个长度为16的Tensor,我们使用self_rev_index[:, None]把它变成(16,1)的Tensor。

rev_index也是长度为的Tensor,rev_index[None, :]变成了(1,16)的Tensor。然后它们比较(“<=”)时会使用Broadcasting,我们可以了解为self_rev_index[:, None]变成了(16,16)的,每行都是相同的:

Broadcasting之前:
[[ 0] 
 [ 0]
 [ 0]
 [ 0]
 [ 3]
 [ 5]
 [ 1]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [11]
 [13]
 [ 9]
 [10]]
Broadcasting之后:
[[ 0, 0, 0, ....] 
 [ 0, 0, 0, ....]
 [ 0, 0, 0, ....]
 [ 0, 0, 0, ....]
 [ 3, 3, 3, ....]
 [ 5, 5, 5, ....]
 [ 1, 1, 1, ....]
 [ 0, 0, 0, ....] 
 [ 0, 0, 0, ....] 
 [ 0, 0, 0, ....] 
 [ 0, 0, 0, ....] 
 [ 0, 0, 0, ....] 
 [11, 11,11,....]
 [13,13, 13,....]
 [ 9, 9, 9, ....]
 [10, 10,10,....]]

类似的,rev_index[None, :]在Broadcasting之前为:

[-1 -1 -1 -1  3  5  0 -1 -1 -1 -1 -1 11 13  8  9]

而Broadcasting之后变成:

[[-1 -1 -1 -1  3  5  0 -1 -1 -1 -1 -1 11 13  8  9],
 [-1 -1 -1 -1  3  5  0 -1 -1 -1 -1 -1 11 13  8  9],
....

]

因此,最终得到的perm_mask(i,j)=1,则表示i不能attend to j。有两种情况i能attend to j:i的排列下标大于j(后面的可以attend to前面的);j没有被Mask在i也可以attend to j。而i不能attend to j需要同时满足:i<=j并且j被Mask了我们来看一个例子:perm_mask(3,4)=1,因为第3个Token的排列下标是2,第4个的排列下标是3,所以满足”2<3”。请读者检查确认一下j确实是被Mask了。

#new target

对于常规的语言模型来说,我们是预测下一个词,而XLNet是根据之前的状态和当前的位置预测被MASK的当前词。所以真正的new_targets要前移一个。

# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0: 1], targets[: -1]],
                  axis=0)
# [10 13 15 20 21 22  4 16 33 34 35 36 37 38 10 3]
inputs=tf.constant([10,13,15,20,21,22,4,16,33,34,35,36,37,38,4,3])
targets=tf.constant([13,15,20,21,22,4,16,33,34,35,36,37,38,10,3,3])

最终的返回值:

  # construct inputs_k
  inputs_k = inputs

  # construct inputs_q
  inputs_q = target_mask

  return perm_mask, new_targets, target_mask, inputs_k, inputs_q

#_local_perm的实现和论文的对比

如果读者还没有完全明白,我们再来把代码和论文的描述对比一下。论文在Partial Prediction部分提到:为了提高效率,我们只预测排列最后几个词。比如假设总共4个词,并且假设某次随机排列的顺序为3→2→4→13→2→4→1,我们假设预测(Mask)最后两个词,也就是预测第4和第1个词。那么1可以attend to [2,3,4];4可以attend to [2,3]。而非Mask(预测)的第2和3个词使用普通的Self-Attention,也就是可以互相看到所有的非Mask词(包括自己),因此2可以attend to [2,3],3也可以attend to [2,3]。

上面按照论文我们随机排列,然后选择最后的几个词作为Mask;而前面的代码我们先已经选定了要Mask的值,然后再排列,这就可能出现非Mask值的排列下标有可能Mask的值,比如假设我们选定Mask第2个和第3个词,但是随机的排列为:

2 3 1 4

也就是说第2个词在排列里的顺序是1,第3个词是2。那么按照论文应该是预测第1个和第4个词。那怎么办呢?代码使用了一个tricky,把所有的非Mask的词(第1个和第4个都变成-1),而Mask的不变(总是大于0的),因此Mask的词就排在后面了。非Mask的词互相都可以attend to但是非Mask的词不能attend to Mask的词。Mask的词可以attend to 非Mask的词而且后面的Mask的词也能attend to 前面的Mask的词。比如上面的例子2和3都是Mask的词,因为3在2后面,所以3可以attend to 2,但2不能attend to 3。同时,Mask的词不能attend to 自己(否则就是用自己预测自己了)。

如果读者还是没有理解这个函数也没有太多关系,但是至少要知道这个函数的干什么的,也就是输入输出是什么,具体怎么实现的可以当成黑盒。

总结一下,_local_perm返回的值:

  • perm_mask,64x64,表示经过重新排列后第i个token能否attend to 第j个token,1表示不能attend
  • target,64,表示真实的目标值,之前生成的target是预测下一个词,但是XLNet是预测当前词
  • target_mask,64,哪些地方是Mask的(需要预测的)
  • input_k, 64,content stream的初始值
  • input_q, 64, 哪些位置是需要计算loss的,如果不计算loss,也就不计算Query Stream。

第三部分:single_core_graph

这个函数的作用构建XLNet计算图的,是我们需要重点理解的部分。不过这个函数没几行代码:

def single_core_graph(is_training, features, mems):
  model_fn = get_model_fn()

  model_ret = model_fn(
      features=features,
      labels=None,
      mems=mems,
      is_training=is_training)

  return model_ret

它调用get_model_fn得到model function,然后调用这个函数返回total_loss, new_mems, grads_and_vars这3个Operation,下面我们来看get_model_fn函数。

get_model_fn

def get_model_fn():
  def model_fn(features, labels, mems, is_training):
    #### 根据输入计算loss
    total_loss, new_mems, monitor_dict = function_builder.get_loss(
        FLAGS, features, labels, mems, is_training)

    #### 打印模型的参数(便于调试)
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('#params: {}'.format(num_params))

    # GPU
    assert is_training
    # 得到所有可能训练的变量
    all_vars = tf.trainable_variables()
    # 计算梯度
    grads = tf.gradients(total_loss, all_vars)
    # 把梯度和变量配对,最终返回一个list,这个list的每一个元素都是(grad, var)的pair
    grads_and_vars = list(zip(grads, all_vars))

    return total_loss, new_mems, grads_and_vars

  return model_fn

这个函数主要是调用function_builder.get_loss得到total_loss, new_mems,然后计算梯度并且返回。

function_builder.get_loss

def get_loss(FLAGS, features, labels, mems, is_training):
  """Pretraining loss with two-stream attention Transformer-XL."""
  if FLAGS.use_bfloat16:
    with tf.tpu.bfloat16_scope():
      return two_stream_loss(FLAGS, features, labels, mems, is_training)
  else:
    return two_stream_loss(FLAGS, features, labels, mems, is_training)

我们这里使用普通的float(float32)而不是压缩的bfloat16,关于bfloat16,有兴趣的读者可以参考Using bfloat16 with TensorFlow models。最终它们都是调用two_stream_loss函数。

two_stream_loss

def two_stream_loss(FLAGS, features, labels, mems, is_training):
  # two-stream attention Transformer-XL模型的Pretraining loss

  #### Unpack input
  mem_name = "mems" 
  # mems是长度为6的list,每个元素是[96, 8(batch), 1024(hidden state)]
  mems = mems.get(mem_name, None)
  
  # input_k是(8,128)变成(128,8)
  inp_k = tf.transpose(features["input_k"], [1, 0])
  # input_q也是(8,128),变成(128,8) 
  inp_q = tf.transpose(features["input_q"], [1, 0])
  # seg_id也是(8,128),变成(128,8)
  seg_id = tf.transpose(features["seg_id"], [1, 0])

  inp_mask = None
  # 从(8,128,128)变成(128,128,8)
  perm_mask = tf.transpose(features["perm_mask"], [1, 2, 0])

  if FLAGS.num_predict is not None:
    # 从(8, 21, 128)变成(21(num_predict), 128, 8)
    target_mapping = tf.transpose(features["target_mapping"], [1, 2, 0])
  else:
    target_mapping = None

  # 语言模型loss的target,从(8,21)变成(21,8)
  tgt = tf.transpose(features["target"], [1, 0])

  # 语言模型losss的target mask,从(8,21)变成(21,8)
  tgt_mask = tf.transpose(features["target_mask"], [1, 0])

  # 构造xlnet的config然后保存到model_dir目录下
  # XLNetConfig包含某个checkpoint特定的超参数
  # 也就是pretraining和finetuing都相同的超参数
  xlnet_config = xlnet.XLNetConfig(FLAGS=FLAGS)
  xlnet_config.to_json(os.path.join(FLAGS.model_dir, "config.json"))

  # 根据FLAGS构造run config,它是XLNet模型Pretraining的超参数。
  run_config = xlnet.create_run_config(is_training, False, FLAGS)
  
  xlnet_model = xlnet.XLNetModel(
      xlnet_config=xlnet_config,
      run_config=run_config,
      input_ids=inp_k,
      seg_ids=seg_id,
      input_mask=inp_mask,
      mems=mems,
      perm_mask=perm_mask,
      target_mapping=target_mapping,
      inp_q=inp_q)

  output = xlnet_model.get_sequence_output()
  new_mems = {mem_name: xlnet_model.get_new_memory()}
  lookup_table = xlnet_model.get_embedding_table()

  initializer = xlnet_model.get_initializer()

  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    # LM loss
    lm_loss = modeling.lm_loss(
        hidden=output,
        target=tgt,
        n_token=xlnet_config.n_token,
        d_model=xlnet_config.d_model,
        initializer=initializer,
        lookup_table=lookup_table,
        tie_weight=True,
        bi_data=run_config.bi_data,
        use_tpu=run_config.use_tpu)

  #### Quantity to monitor
  monitor_dict = {}

  if FLAGS.use_bfloat16:
    tgt_mask = tf.cast(tgt_mask, tf.float32)
    lm_loss = tf.cast(lm_loss, tf.float32)

  total_loss = tf.reduce_sum(lm_loss * tgt_mask) / tf.reduce_sum(tgt_mask)
  monitor_dict["total_loss"] = total_loss

  return total_loss, new_mems, monitor_dict

XLNetConfig

XLNetConfig包含某个checkpoint特定的超参数,也就是pretraining和finetuing都相同的超参数。它包括:

  • n_layer: int, XLNet的层数,这里是6。
  • d_model: int, 隐单元个数,这里是1024。
  • n_head: int, attention head的个数,这里是16。
  • d_head: int, 每个attention head的大小,要求n_head*d_head=d_model,这里是64。
  • d_inner: int, 全连接网络的隐单元个数,这里是4096。
  • ff_activation: str, “relu”或者”gelu”。
  • untie_r: bool, 是否不同层不共享bias,这里为True,也就是每层都有独立的bias。
  • n_token: int, 词典大小,这里是32000。

RunConfig

这是Pretraining的一些超参数:

bi_data = {bool} True 表示一个句子会变成两个:一个是正常的,一个是逆序的(反过来)。
clamp_len = {int} -1
dropatt = {float} 0.1 attention的dropout
dropout = {float} 0.1 
init = {str} 'normal' 参数初始化方法,这里是正态分布
init_range = {float} 0.1 对于normal无效果
init_std = {float} 0.02 正态分布的方差
is_training = {bool} True 是否Pretraining
mem_len = {int} 96 
reuse_len = {int} 64
same_length = {bool} False
use_bfloat16 = {bool} False
use_tpu = {bool} False

XLNetModel

这是真正定义XLNet模型的类。它定义XLNet的代码都在构造函数里,其余的一些函数都是一些getter类函数,比如get_sequence_output可以拿到最后一层的隐状态。我们首先来看它的构造函数的参数。

#构造函数的参数

  • xlnet_config: XLNetConfig,XLNet模型结构的超参数,比如层数,head数量等等
  • run_config: RunConfig,运行时的超参数,包括dropout、初始范围等等。
  • input_ids: int32 Tensor,shape是[len, bsz], 输入token的ID
  • seg_ids: int32 Tensor,shape是[len, bsz], 输入的segment ID
  • input_mask: float32 Tensor,shape是[len, bsz], 输入的mask,0是真正的tokens而1是padding的
  • mems: list,每个元素是float32 Tensors,shape是[mem_len, bsz, d_model], 上一个batch的memory
  • perm_mask: float32 Tensor,shape是[len, len, bsz]。
    • 如果perm_mask[i, j, k] = 0,则batch k的第i个Token可以attend to j
    • 如果perm_mask[i, j, k] = 1, 则batch k的第i个Token不可以attend to j
    • 如果是None,则每个位置都可以attend to 所有其它位置(包括自己)
  • target_mapping: float32 Tensor,shape是[num_predict, len, bsz]
    • 如果target_mapping[i, j, k] = 1,则batch k的第i个要预测的是第j个Token,这是一种one-hot表示
    • 只是在pretraining的partial prediction时使用,finetuning时设置为None
  • inp_q: float32 Tensor,shape是[len, bsz]
    • 需要计算loss的(Mask的位置)为1,不需要的值为0,只在pretraining使用,finetuning时应为None

注意:这里的input_mask不是我们之前介绍的is_masked,is_masked是表示这个位置的Token是被预测的。而这里的input_mask其实指的是padding,我们这里是Pretraining,所有的Token都是真实的,因此传入的为None,后面的代码会把None当成没有padding处理。

#构造函数

  def __init__(self, xlnet_config, run_config, input_ids, seg_ids, input_mask,
               mems=None, perm_mask=None, target_mapping=None, inp_q=None,
               **kwargs):
    # 初始化类
    initializer = _get_initializer(run_config)
    # 所有超参数
    tfm_args = dict(
        n_token=xlnet_config.n_token,
        initializer=initializer,
        attn_type="bi",
        n_layer=xlnet_config.n_layer,
        d_model=xlnet_config.d_model,
        n_head=xlnet_config.n_head,
        d_head=xlnet_config.d_head,
        d_inner=xlnet_config.d_inner,
        ff_activation=xlnet_config.ff_activation,
        untie_r=xlnet_config.untie_r,

        is_training=run_config.is_training,
        use_bfloat16=run_config.use_bfloat16,
        use_tpu=run_config.use_tpu,
        dropout=run_config.dropout,
        dropatt=run_config.dropatt,

        mem_len=run_config.mem_len,
        reuse_len=run_config.reuse_len,
        bi_data=run_config.bi_data,
        clamp_len=run_config.clamp_len,
        same_length=run_config.same_length
    )
    # 所有输入
    input_args = dict(
        inp_k=input_ids,
        seg_id=seg_ids,
        input_mask=input_mask,
        mems=mems,
        perm_mask=perm_mask,
        target_mapping=target_mapping,
        inp_q=inp_q)
    tfm_args.update(input_args)

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
      (self.output, self.new_mems, self.lookup_table
          ) = modeling.transformer_xl(**tfm_args)

    self.input_mask = input_mask
    self.initializer = initializer
    self.xlnet_config = xlnet_config
    self.run_config = run_config

上面的构造函数核心的代码其实只有一行:

(self.output, self.new_mems, self.lookup_table) = modeling.transformer_xl(**tfm_args)

#transformer_xl构造函数的参数

它的构造函数的参数是前面XLNetModel的构造函数传过来的,不过我们还是列举一下。

  • inp_k: int32 Tensor,shape是[len, bsz], 输入token的ID
  • seg_ids: int32 Tensor,shape是[len, bsz], 输入的segment ID
  • input_mask: float32 Tensor,shape是[len, bsz], 输入的mask,0是真正的tokens而1是padding的
  • mems: list,每个元素是float32 Tensors,shape是[mem_len, bsz, d_model], 上一个batch的memory
  • perm_mask: float32 Tensor,shape是[len, len, bsz]。
    • 如果perm_mask[i, j, k] = 0,则batch k的第i个Token可以attend to j
    • 如果perm_mask[i, j, k] = 1, 则batch k的第i个Token不可以attend to j
    • 如果是None,则每个位置都可以attend to 所有其它位置(包括自己)
  • target_mapping: float32 Tensor,shape是[num_predict, len, bsz]
    • 如果target_mapping[i, j, k] = 1,则batch k的第i个要预测的是第j个Token,这是一种one-hot表示
    • 只是在pretraining的partial prediction时使用,finetuning时设置为None
  • inp_q: float32 Tensor,shape是[len, bsz]
    • 需要计算loss的(Mask的位置)为1,不需要的值为0,只在pretraining使用,finetuning时应为None
  • n_layer: int, XLNet的层数,这里是6。
  • d_model: int, 隐单元个数,这里是1024。
  • n_head: int, attention head的个数,这里是16。
  • d_head: int, 每个attention head的大小,要求n_head*d_head=d_model,这里是64。
  • d_inner: int, 全连接网络的隐单元个数,这里是4096。
  • ff_activation: str, “relu”或者”gelu”。
  • untie_r: bool, 是否不同层不共享bias,这里为True,也就是每层都有独立的bias。
  • n_token: int, 词典大小,这里是32000。

  • is_training: bool, 是否是Training
  • use_tpu: bool, 是否使用TPU
  • use_bfloat16: bool, 是否用bfloat16替代float32
  • dropout: float, dropout大小.
  • dropatt: float, attention概率的dropout
  • init: str, 初始化方法,值为”normal”或者”uniform”
  • init_range: float, 均匀分布的范围,[-init_range, init_range]。只有init=”uniform”时有效
  • init_std: float, 正态分布的方程。只有init=”normal”时有效
  • mem_len: int, cache的token个数
  • reuse_len: int, 当前batch里reuse的数量,参考第一部分。
  • bi_data: bool, 是否双向处理输入,通常在pretraining设置为True,而finetuning为False
  • clamp_len: int, clamp掉相对距离大于clamp_len的attention,-1表示不clamp
  • same_length: bool, 是否对于每个Token使用相同的attention长度
  • summary_type: str, “last”, “first”, “mean”和”attn”。怎么把最后一层的多个向量整合成一个向量
  • initializer: 初始化器
  • scope: 计算图的scope

#transformer_xl构造函数

第一段

  with tf.variable_scope(scope):
    if untie_r:
      # 我们这里是不共享bias
      # r_w_bias和r_r_bias都是[6, 16, 64]
      r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                 dtype=tf_float, initializer=initializer)
      r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                 dtype=tf_float, initializer=initializer)
    else:
      r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                 dtype=tf_float, initializer=initializer)
      r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                 dtype=tf_float, initializer=initializer)

    bsz = tf.shape(inp_k)[1]    # 8
    qlen = tf.shape(inp_k)[0] # 128
    mlen = tf.shape(mems[0])[0] if mems is not None else 0 # 96
    klen = mlen + qlen # 224

上面的代码注意是定义r_w_bias和r_r_bias,以及读取一些超参数。

第二段

    ##### Attention mask
    # 因果关系的(causal)attention mask
    if attn_type == 'uni':
      attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
      attn_mask = attn_mask[:, :, None, None]
    elif attn_type == 'bi':
      # 我们走的这个分支
      attn_mask = None
    else:
      raise ValueError('Unsupported attention type: {}'.format(attn_type))

    # data mask: input mask & perm mask
    if input_mask is not None and perm_mask is not None:
      data_mask = input_mask[None] + perm_mask
    elif input_mask is not None and perm_mask is None:
      data_mask = input_mask[None]
    elif input_mask is None and perm_mask is not None:
      # data_mask=perm_mask [128,128,8]
      data_mask = perm_mask
    else:
      data_mask = None

    if data_mask is not None:
      # 所有的mems的可以attended to,因此构造[128,96,8]的zeros
      mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                           dtype=tf_float)
      # 然后拼接成[128,224,8]的data_mask。
      # data_mask[i,j,k]=0则表示batch k的第i(0-128)个Token可以attend to 
      # 第j(0-224,包括前面的mem)个Token。
      # 注意j的下标范围,0-95表示mem,96-223表示128个输入。
      data_mask = tf.concat([mems_mask, data_mask], 1)
      if attn_mask is None:
        # attn_mask为[128, 224, 8, 1]
        attn_mask = data_mask[:, :, :, None]
      else:
        attn_mask += data_mask[:, :, :, None]

    if attn_mask is not None:
      # 把attn_mask里大于0的变成1(前面相加有可能大于1)
      attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

    if attn_mask is not None:
      # 参考下面
      non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
      non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
                                non_tgt_mask], axis=-1)
      non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                             dtype=tf_float)
    else:
      non_tgt_mask = None

下面来看一下non_tgt_mask,为了简单,我们假设qlen是4, mlen是3,前两行的结果为:

0 0 0   -1 0 0 0
0 0 0   0 -1 0 0
0 0 0   0 0 -1 0
0 0 0   0 0 0 -1

attn_mask是(qlen, qlen+mlen, batch, 1),它和non_tgt_mask[:,:,None,None]相加。它的作用是让Token能attend to 自己,除后面的对角线外,non_tgt_mask等于attn_mask,而对角线的位置由1变成了0,从而让它可以attend自己。注意:在XLNet的代码里,0表示可以attend to,而1表示不能attend to。non_tgt_mask用于Content Stream,它不是预测的目标(target),所以可以利用自己的信息。而Query Stream就必须使用attn_mask,预测第i个Token时不能利用自己的内容。

第三段

    ##### Word embedding
    # 输入inp_k是(128,8),embedding后的word_emb_k是(128,8,1024)
    # lookup_table是(32000, 1024)
    word_emb_k, lookup_table = embedding_lookup(
        x=inp_k,
        n_token=n_token,
        d_embed=d_model,
        initializer=initializer,
        use_tpu=use_tpu,
        dtype=tf_float,
        scope='word_embedding')
    # inp_q是(128,8)表示某个位置是否要计算loss
    if inp_q is not None:
      with tf.variable_scope('mask_emb'):
        # mask_emb是[1, 1, 1024]
        mask_emb = tf.get_variable('mask_emb', [1, 1, d_model], dtype=tf_float)
        if target_mapping is not None:
          # tile(复制)成[21, 8, 1024]
          word_emb_q = tf.tile(mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
        else:
          inp_q_ext = inp_q[:, :, None]
          word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
    # output_h是word_emb_k的dropout的结果,shape也是[128, 8, 1024]
    output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)
    if inp_q is not None:
      # output_g是word_emb_q的dropout,shape也是[21, 8, 1024]
      output_g = tf.layers.dropout(word_emb_q, dropout, training=is_training)

上面的output_h和output_g分别是two-stream的初始输入。

第四段

    ##### Segment embedding
    if seg_id is not None:
      if untie_r:
        # [6, 16, 64]
        r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head],
                                   dtype=tf_float, initializer=initializer)
      else:
        # default case (tie)
        r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                   dtype=tf_float, initializer=initializer)
      # [6, 2, 16, 64]
      seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
                                  dtype=tf_float, initializer=initializer)

      # Convert `seg_id` to one-hot `seg_mat`
      mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
      # 输入需要padding 96个0,最终变成(224, 8)
      cat_ids = tf.concat([mem_pad, seg_id], 0)

      # `1` indicates not in the same segment [qlen x klen x bsz]
      # 参考下面的说明。
      seg_mat = tf.cast(
          tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
          tf.int32)
      seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
    else:
      seg_mat = None

seg_embed是Segment embedding,XLNet是相对Segment编码:如果两个Token在同一个Segment则是0;否则是1。所以第二维是2。

在阅读seg_mat那段时我们首先需要熟悉Tensorflow增加维度的方法。

seg_id是[128, 8],那么seg_id[:, None]的shape呢?有的读者可能会猜测是[128, 8, 1](我一开始也是这么以为),但这是不对的。[:, None]的意思是在第二个维度增加一维,而原来的第二维(8)被往后推到第三维了,因此seg_id[:, None]的shape是[128, 1, 8],它等价于seg_id[:, None, :]。而cat_ids[None, :]是在第一个维度增加一维,因此是[1, 224, 8]。

接下来tf.equal(seg_id[:, None], cat_ids[None, :])会首先进行broadcasting:

seg_id[:, None]: [128, 1, 8] -> [128, 224, 8] 
cat_ids[None, :]: [1, 224, 8] -> [128, 224, 8]

计算的结果是:如果(i,j)的seg_id相同,则为True,否则为False。注意i的取值范围是0-127;而j是0-223。因为我们计算Attention时是用128个去attend to 224(包括96个mem)。最后在用tf.logical_not反过来:1表示在不同Segment而0表示同一个Segment。

最后表示成one-hot的方式(在同一个Segment为[1,0];不同的Segment为[0,1]),变成[128, 224, 8, 2],第四位就是one-hot。

第五段

相对位置编码的基本概念、相对位置编码的详细计算过程

    ##### 位置编码,下面我们会详细介绍,它的输出是(352, 8, 1024)
    pos_emb = relative_positional_encoding(
        qlen, klen, d_model, clamp_len, attn_type, bi_data,
        bsz=bsz, dtype=tf_float)
    # dropout
    pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

我们回到前面的例子,context大小是96,而输入的序列长度是128。因此query的下标i的取值范围是96-223,而key的下标j的取值范围是0-223,所以它们的差的取值范围是(96-223)-(223-0),所以位置差总共有352种可能的取值,所以上面返回的pos_emb的shape是(352, 8, 1024)。

relative_positional_encoding函数

def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
                                 bi_data, bsz=None, dtype=None):
  """创建相对位置编码"""
  # [0,2,...,1022] 长度为d_model/2=512
  freq_seq = tf.range(0, d_model, 2.0)
  if dtype is not None and dtype != tf.float32:
    freq_seq = tf.cast(freq_seq, dtype=dtype)
  # inv_freq的大小还是512
  inv_freq = 1 / (10000 ** (freq_seq / d_model))

  if attn_type == 'bi':
    # 我们这里attn_type == 'bi'
    # beg, end = 224, -128
    beg, end = klen, -qlen
  elif attn_type == 'uni':
    # beg, end = klen - 1, -1
    beg, end = klen, -1
  else:
    raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

  if bi_data:
    # [224, -127]
    fwd_pos_seq = tf.range(beg, end, -1.0)
    # [-224, 127]
    bwd_pos_seq = tf.range(-beg, -end, 1.0)

    if dtype is not None and dtype != tf.float32:
      fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
      bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)

    if clamp_len > 0:
      # 把两个词的最大距离限制在-clamp_len和clamp_len之间,我们这里没有限制
      fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
      bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len)

    if bsz is not None:
      # 参考下面,它的返回是(352, 4, 1024)
      fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
      # 返回(352, 4, 1024)
      bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
    else:
      fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
      bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
    # (352, 8, 1024)
    pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
  else:
    fwd_pos_seq = tf.range(beg, end, -1.0)
    if dtype is not None and dtype != tf.float32:
      fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
    if clamp_len > 0:
      fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
    pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)

  return pos_emb

这个函数返回相对位置的编码,它返回一个(352, 8, 1024)的Tensor。8表示batch,1024表示位置编码的维度,它是和隐单元相同大小,这样可以相加(残差连接)。

前面我们介绍过,这里再复述一下。context大小是96,而输入的序列长度是128。因此query的下标i的取值范围是96-223,而key的下标j的取值范围是0-223,所以它们的差的取值范围是(96-223)-(223-0),所以位置差总共有352种可能的取值,所以上面返回的pos_emb的shape是(352, 8, 1024)。fwd_pos_seq的范围是[224, -127],表示当i>=j时(从后往前)从i attend to j的最大范围是223 attend to 0(加上attend to 自己总共224个可能取值);而i<j时最小值是-127。参考下图:

计算正弦的位置编码是通过下面介绍的函数positional_embedding来实现的,读者可以对照Transformer代码阅读·位置编码来阅读,那里是PyTorch的实现。

tf.einsum函数

在介绍positional_embedding函数之前我们介绍一下Tensorflow的einsum函数,这个函数在下面很多地方都会用到,所以这里单独作为一个小节来介绍一下。

这个函数返回一个Tensor,这个Tensor的元素由Einstein summation习惯的简写公式定义。所谓的Einstein summation指的是由Albert Einstein引入的一种公式简写法。比如计算矩阵A和B的乘积得到C。则C的元素的定义为:

    C[i,k] = sum_j A[i,j] * B[j,k]

如果读者熟悉Latex的公式,则上面内容在Latex里会显示为:

上面的这个公式简写为:

    ij,jk->ik

那这种简写方法是什么意思呢?或者说怎么从上面的公式得到简写的呢?它的过程如下:

  • 删除变量、括号和逗号
  • 把”*“变成”,”
  • 去掉求和符号
  • 把输出移动右边,把”=”变成”->”

下面我们按照前面的4个步骤看看这个过程:

C[i,k] = sum_j A[i,j] * B[j,k] 删除变量、括号和逗号变成,
ik = sum_j ij * jk 把"*"变成","
ik = sum_j ij , jk 去掉求和符号
ik = ij,jk 把输出放到右边,把"="变成"->"
ij,jk->ik

下面我们来看一些例子,请读者理解它们的含义,如果不理解,可以按照前面的方法一步一步来。

# 矩阵乘法,这个前面已经介绍过
>>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

# 向量的内积,注意输出是空,表示输出没有下标,因为输出是一个数值。
>>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

# 向量外积
>>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

# 矩阵转置
>>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

# Batch 矩阵乘法
>>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]

positional_embedding

def positional_embedding(pos_seq, inv_freq, bsz=None):
  # 计算pos_seq和inv_freq的外积
  # 352 x 512
  sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
  # 计算sin和cos,然后拼接成(352, 1024)
  pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
  # 变成(352, 1, 1024) 如果不懂None的含义请参考第二部分
  # 它等价于tf.reshape(pos_emb, [352, 1, 1024])
  pos_emb = pos_emb[:, None, :]

  if bsz is not None:
    # 对每个batch复制,变成(352, 8, 1024)
    pos_emb = tf.tile(pos_emb, [1, bsz, 1])

  return pos_emb

相对位置编码向量的获得就介绍到这,后面在Attention的计算是会用到pos_emb,我们到时候再看具体的计算。

第六段

这一对是核心的计算two stream attention的地方。

    ##### Attention layers
    if mems is None:
      mems = [None] * n_layer

    for i in range(n_layer):
      # 通过当前的输出计算新的mem,也就是保留128个中的后96个
      # 当然如果cache的大小大于当前的输入,那么mem里可能包含当前输入的全部以及之前的mem
      new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))

      # segment bias
      if seg_id is None:
        r_s_bias_i = None
        seg_embed_i = None
      else:
        # 这里是不共享bias的,所以从r_s_bias(6, 16, 64)取出
        # 当前层的bias(16, 64)
        r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
        # 从(6, 2, 16, 64)取出当前层的segment embedding
        # seg_embed_i为(2, 16, 64)
        seg_embed_i = seg_embed[i]

      with tf.variable_scope('layer_{}'.format(i)):
        # inp_q表示哪些位置是Mask从而计算loss,也表示这是Pretraining
        if inp_q is not None:
          # 这是计算two stream attention,下面会详细介绍
          # 它的返回值output_h表示
          # output_g表示
          output_h, output_g = two_stream_rel_attn(
              h=output_h,
              g=output_g,
              r=pos_emb,
              r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
              r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
              seg_mat=seg_mat,
              r_s_bias=r_s_bias_i,
              seg_embed=seg_embed_i,
              attn_mask_h=non_tgt_mask,
              attn_mask_g=attn_mask,
              mems=mems[i],
              target_mapping=target_mapping,
              d_model=d_model,
              n_head=n_head,
              d_head=d_head,
              dropout=dropout,
              dropatt=dropatt,
              is_training=is_training,
              kernel_initializer=initializer)
          reuse = True
        else:
          reuse = False

          output_h = rel_multihead_attn(
              h=output_h,
              r=pos_emb,
              r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
              r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
              seg_mat=seg_mat,
              r_s_bias=r_s_bias_i,
              seg_embed=seg_embed_i,
              attn_mask=non_tgt_mask,
              mems=mems[i],
              d_model=d_model,
              n_head=n_head,
              d_head=d_head,
              dropout=dropout,
              dropatt=dropatt,
              is_training=is_training,
              kernel_initializer=initializer,
              reuse=reuse)

这段代码把下一层的输出更新mem,然后把它作为上一层的输入。核心是使用two_stream_rel_attn函数计算Attention。我们先看mem的更新。

_cache_mem函数

_cache_mem函数代码如下:

def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
  """把隐状态chache进memory"""
  if mem_len is None or mem_len == 0:
    return None
  else:
    if reuse_len is not None and reuse_len > 0:
      # reuse的部分cache,curr_out从(128,...)变成(64,....)
      # 原因是生成数据时我们每次后移64,请参考第一部分代码
      curr_out = curr_out[:reuse_len]

    if prev_mem is None:
      new_mem = curr_out[-mem_len:]
    else:
      #之前的mem(96)和当前的输出(64)拼起来然后取后面mem_len(96)个
      new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]

  # cache的mem不参与梯度计算
  return tf.stop_gradient(new_mem)

two_stream_rel_attn函数

我们先看输入参数:

  • h (128, 8, 1024) 上一层的内容(context)隐状态
  • g (21, 8, 1204) 上一层的查询(query)隐状态,只需要计算Mask的哪些位置
  • r (352, 8, 1024) 相对位置编码,函数relative_positional_encoding的返回值
  • r_w_bias (16, 64)
  • r_r_bias (16, 64)
  • seg_mat (128, 224, 8, 2) 表示i(范围是0-127)和j(被attend to的,包括mem因此范围是0-223)是否在同一个segment里,1是True
  • r_s_bias (16, 64)
  • seg_embed (2, 16, 64) 表示两个Token处于相同或者不同Segment(只有两种可能)的embedding
  • attn_mask_h (128, 224, 8, 1) 这是内容stream的mask,attn_mask_h(i,j)表示i能否attend to j,1表示不能
  • attn_mask_g (128, 224, 8, 1) 这是查询stream的mask,attn_mask_h(i,j)表示i能否attend to j,1表示不能
  • target_mapping (21, 128, 8) 表示21个Mask的Token下标,使用one-hot编码(1 out of 128)的方式,注意如果Mask的不够21个,则padding的内容全是零(而不是有且仅有一个1)。
  • d_model 1024 隐状态大小
  • n_head 16 attention head个数
  • d_head 64 每个head的隐状态大小
  • dropout dropout
  • dropatt Attention的dropout
  • is_traning True,表示Pretraining
  • kernel_initializer 初始化类

完整的代码如下:

def two_stream_rel_attn(h, g, r, mems, r_w_bias, r_r_bias, seg_mat, r_s_bias,
                        seg_embed, attn_mask_h, attn_mask_g, target_mapping,
                        d_model, n_head, d_head, dropout, dropatt, is_training,
                        kernel_initializer, scope='rel_attn'):
  """基于相对位置编码的Two-stream attention"""
  # scale,参考《Transformer图解》
  scale = 1 / (d_head ** 0.5)
  with tf.variable_scope(scope, reuse=False):

    # 内容attention score
    if mems is not None and mems.shape.ndims > 1:
      # 把之前的mem加上得到\hat{h}
      # shape是(96+128=224, 8, 1024)
      cat = tf.concat([mems, h], 0)
    else:
      cat = h

    # 计算内容stream的attention head的key
    # 输入是(224, 8, 1024),使用key的变换矩阵后,输出是(224, 8, 16, 64)
    # head_projection的讲解在下一小节
    k_head_h = head_projection(
        cat, d_model, n_head, d_head, kernel_initializer, 'k')

    # 类似的的计算内容stream的value
    # 输入是(224, 8, 1024),输出是(224, 8, 16, 64)
    v_head_h = head_projection(
        cat, d_model, n_head, d_head, kernel_initializer, 'v')

    # 相对位置的key也做投影变换,因此输入的第一维表示位置差
    # 输入是(352, 8, 1024),输出是(352, 8, 16, 64)
    k_head_r = head_projection(
        r, d_model, n_head, d_head, kernel_initializer, 'r')

    ##### 计算内容stream的attention
    # 内容stream的query
    # query的范围是当前输入,因此输入是(128, 8, 1024)
    # 输出是(128, 8, 16, 64) 
    q_head_h = head_projection(
        h, d_model, n_head, d_head, kernel_initializer, 'q')

    # 计算Attention的核心函数,下面详细介绍
    # 输出是attention score, shape是(128, 8, 16, 64)
    # 表示128个输入(8是batch)的隐状态,每个隐状态是16(个head)x64(head大小)
    attn_vec_h = rel_attn_core(
        q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
        r_r_bias, r_s_bias, attn_mask_h, dropatt, is_training, scale)

    # 后处理,残差连接+LayerNorm,详见后文
    output_h = post_attention(h, attn_vec_h, d_model, n_head, d_head, dropout,
                              is_training, kernel_initializer)

  with tf.variable_scope(scope, reuse=True):
    ##### 查询的stream
    # query向量,shape是(21, 8, 16, 64)
    # 只需要计算Mask的(21)
    q_head_g = head_projection(
        g, d_model, n_head, d_head, kernel_initializer, 'q')

    # 核心的attention运算
    if target_mapping is not None:
      # q_head_g是(21, 8, 16, 64) target_mapping是(21, 128, 8)
      # q_head_g变成(128, 8, 16, 64),也就是根据target_mapping把输入从21变成128
      # 当然21个的位置是有query head的,而128-21个位置的query都是0。
      # 这么做的原因是因为模型的定义里输入都是128。
      q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)

      # Attention计算,attn_vec_g是(128, 8, 16, 64),这和前面基本一样
      attn_vec_g = rel_attn_core(
          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
          r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale)
      # 但是我们需要的只是21个Mask的位置的输出,因此我们又用target_mapping变换回来
      # attn_vec_g是(128, 8, 16, 64) target_mapping是(21, 128, 8)
      # 最终的attn_vec_g是(21, 8, 16, 64)
      attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
    else:
      attn_vec_g = rel_attn_core(
          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
          r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale)

    # 后处理,残差+layernorm,和前面content流的一样
    # 最终的output_g是(21, 8, 1024)
    output_g = post_attention(g, attn_vec_g, d_model, n_head, d_head, dropout,
                              is_training, kernel_initializer)

    return output_h, output_g

head_projection函数

def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
  # proj_weight的shape是(1024, 16, 64)
  proj_weight = tf.get_variable('{}/kernel'.format(name),
                                [d_model, n_head, d_head], dtype=h.dtype,
                                initializer=kernel_initializer)
  # h是(224, 8, 1024),proj_weight是(1024, 16, 64),head是(224, 8, 16, 64)
  head = tf.einsum('ibh,hnd->ibnd', h, proj_weight)

  return head

head_projection函数的作用是把输入的1024维的向量乘以proj_weight变成(16, 64)。为了实现上面的计算,我们可以使用reshape:

首先把proj_weight从(1024, 16, 64)->(1024, 1024(16x64))
然后把h从(224, 8, 1024) -> (224*8,1024)
然后h乘以proj_weight得到(224*8, 1024(16x64))
最后reshape成(224, 8, 16, 64)

而通过einsum可以一步到位,实现上面的计算过程(当然阅读起来稍微难以理解一点)。

我们可以这样解读’ibh,hnd->ibnd’:输入h的shape是(i/224, b/8, h/1024),proj_weight是(h/1024, n/16, d/64),输出的shape是(i/224,b/8,n/16,d/64)。并且:

抛开那些无关的i,b,n,d维度,它的核心就是把输入的1024/h使用一个矩阵(proj_weight)变换成(16,64)。请读者仔细理解einsum的含义,下文还会阅读很多tf.einsum函数,作者不会这样详细分解了,请确保理解后再往下阅读。

rel_attn_core函数

def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training,
                  scale):
  """基于相对位置编码的Attention"""

  # 基于内容的attention score
  # q_head是(128, 8, 16, 64) k_head_h是(224, 8, 16, 64)
  # einsum是把最后的64x64求内积,得到ac是(128, 224, 8, 16)
  # 表示i(128)->j(224)的attention score
  ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

  # 基于位置的attention score
  # q_head是(128, 8, 16, 64) k_head_r是(352, 8, 16, 64)
  # 得到的bd是(128, 352, 8, 16)
  bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
  bd = rel_shift(bd, klen=tf.shape(ac)[1])

  # segment的embedding
  if seg_mat is None:
    ef = 0
  else:
    # q_head是(128, 8, 16, 64) seg_embed是(2, 16, 64)
    # ef是(128, 8, 16, 2),也就是把它们的最后一维做内积
    ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
    # seg_mat是(128, 224, 8, 2) ef是(128, 8, 16, 2)
    # 最终的ef是(128, 224, 8, 16) 也是把最后一维做内积
    # ef(i,j)表示i attend to j时的segment embedding。
    ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

  # ac+bd+ef得到最终的attention score
  attn_score = (ac + bd + ef) * scale
  if attn_mask is not None:
    # 回忆一下,attn_mask(i,j)为1表示attention mask,也就是i不能attend to j。
    # 下面的式子中如果attn_mask里为0的不变,为1的会减去一个很大的数,这样就变成很大的负数
    # 从而后面softmax的时候概率基本为0,从而实现Mask
    attn_score = attn_score - 1e30 * attn_mask

  # 把score变成概率
  # attn_prob是(128, 224, 8, 16)
  attn_prob = tf.nn.softmax(attn_score, 1)
  # 使用dropatt进行dropout
  attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

  # 根据attn_prob和value向量计算最终的输出
  # attn_prob是(128, 224, 8, 16), v_head_h是(224, 8, 16, 64)
  # attn_vec是(128, 8, 16, 64),也就是把224个attend to的向量使用attn_prob加权求和
  attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

  return attn_vec

变量ac对应公式的(a)和(c),变量r_w_bias对应公式里的uu;而bd对应公式的(b)和(d),变量r_r_bias对应公式里的vv。请读者对照公式进行理解。而ef是相对Segment Embedding。最终把它们加起来就是Attention Score。

rel_shift函数

def rel_shift(x, klen=-1):
  # x是(128, 352, 8, 16),klen=224
  x_size = tf.shape(x)
  # reshape成(352, 128, 8, 16)
  x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
  # 第一维从下标1开始slice到最后,其余的都完全保留,
  # tf.slice函数的介绍在下面
  # x变成(351, 128, 8, 16)
  x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
  # reshape成(128, 351, 8, 16)
  x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
  # slice得到(128, 224, 8, 16)
  x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])

  return x

tf.slice的作用是从一个Tensor里截取(slice)一部分,它的主要参数是第二个和第三个,分别表示每一维开始的下标和每一维的长度,如果长度为-1,则表示一直截取从开始下标到最后所有的。下面是几个例子:

t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
                                   #   [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
                                   #  [[5, 5, 5]]]

x是(128, 352, 8, 16),这个x是前面bd,它是relative_positional_encoding函数的输出经过经过head_projection得到的k_head_r和q_head计算得到的。relative_positional_encoding函数的输出是(352, 8, 1024),而第一维的128是tile(复制)出来的,因此都是相同的(包括batch/8的维度也是tile出来的)。

x(i,d,…)表示i(输入下标,0-127)的距离为d(0-351)的相对位置编码。但是我们实际需要的是x(i,j,…),j是实际的下标,因此需要把距离d变成真实的下标j(i-d),这就是这个函数的作用。注意:d为0表示最大的距离224(其实最大只能是223)。

如果读者没有办法理解其中的细节也可以暂时跳过,但是至少需要知道输出(128,224)的含义:(i,j)的含义是位置i attend to j的位置编码,x(1,4)的值和x(3,6)的值是相同的,因为它们的相对位置都是-3(1-3和4-6)。

post_attention

需要注意的地方都在注释里,请仔细阅读:

def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training,
                   kernel_initializer, residual=True):
  """attention的后处理"""
  # 把attention得到的向量(16x64)投影回d_model(1024)
  # proj_o是(1024, 16, 64)
  proj_o = tf.get_variable('o/kernel', [d_model, n_head, d_head],
                           dtype=h.dtype, initializer=kernel_initializer)
  # attn_vec是(128, 8, 16, 64) proj_o是(1024, 16, 64)
  # attn_out是(128, 8, 1024)
  attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o)

  attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
  if residual:
    # 残差连接然后是layer norm
    # 输出大小不变,仍然是(128, 8, 1024)
    output = tf.contrib.layers.layer_norm(attn_out + h, begin_norm_axis=-1,
                                          scope='LayerNorm')
  else:
    output = tf.contrib.layers.layer_norm(attn_out, begin_norm_axis=-1,
                                          scope='LayerNorm')

  return output

第六段

这一段是Self-Attention之后的全连接层+LayerNorm。

        if inp_q is not None:
          output_g = positionwise_ffn(
              inp=output_g,
              d_model=d_model,
              d_inner=d_inner,
              dropout=dropout,
              kernel_initializer=initializer,
              activation_type=ff_activation,
              is_training=is_training)

        output_h = positionwise_ffn(
            inp=output_h,
            d_model=d_model,
            d_inner=d_inner,
            dropout=dropout,
            kernel_initializer=initializer,
            activation_type=ff_activation,
            is_training=is_training,
            reuse=reuse)

代码比较简单,就是调用positionwise_ffn函数,下面是positionwise_ffn的代码:

def positionwise_ffn(inp, d_model, d_inner, dropout, kernel_initializer,
                     activation_type='relu', scope='ff', is_training=True,
                     reuse=None):
  if activation_type == 'relu':
    activation = tf.nn.relu
  elif activation_type == 'gelu':
    activation = gelu
  else:
    raise ValueError('Unsupported activation type {}'.format(activation_type))

  output = inp
  with tf.variable_scope(scope, reuse=reuse):
    # 第一个全连接层,输入output是(21,8,1024),输出output是(21,8,4096)
    # 激活函数是relu
    output = tf.layers.dense(output, d_inner, activation=activation,
                             kernel_initializer=kernel_initializer,
                             name='layer_1')
    output = tf.layers.dropout(output, dropout, training=is_training,
                               name='drop_1')
    # 用一个线性变换(activation是None)把4096->1024
    # output为(21,8,1024)
    output = tf.layers.dense(output, d_model,
                             kernel_initializer=kernel_initializer,
                             name='layer_2')
    output = tf.layers.dropout(output, dropout, training=is_training,
                               name='drop_2')
    # 残差连接和LayerNorm,最终的output还是(21, 8, 1024)
    output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1,
                                          scope='LayerNorm')
  return output

第七段

    if inp_q is not None:
      # 因为是pretraining,所有返回query stream的结果
      output = tf.layers.dropout(output_g, dropout, training=is_training)
    else:
      output = tf.layers.dropout(output_h, dropout, training=is_training)

    return output, new_mems, lookup_table

这一段非常简单,只是返回结果。如果是pretraining,则返回查询stream的输出output_g(21, 8, 1024),否则返回内容stream的输出output_h(128, 8, 1024),返回之前会再做一个dropout,最终返回的是output, new_mems和lookingup_table(32000, 1024)。

返回two_stream_loss函数

终于把最复杂的transformer_xl函数分析完了,我们返回上一层回到XLNetModel的构造函数,再往上返回two_stream_loss函数:

def two_stream_loss(FLAGS, features, labels, mems, is_training):
  .........................
  # 这是我们之前分析的"断点",我们接着
  xlnet_model = xlnet.XLNetModel(
      xlnet_config=xlnet_config,
      run_config=run_config,
      input_ids=inp_k,
      seg_ids=seg_id,
      input_mask=inp_mask,
      mems=mems,
      perm_mask=perm_mask,
      target_mapping=target_mapping,
      inp_q=inp_q)
  
  # 得到pretraining的输出,shape是(21, 8, 1024)
  output = xlnet_model.get_sequence_output()
  # 新的mem,key是"mems",value是(96, 8, 1024)
  new_mems = {mem_name: xlnet_model.get_new_memory()}
  # lookup_table是(32000, 1024)
  lookup_table = xlnet_model.get_embedding_table()

  initializer = xlnet_model.get_initializer()

  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    # pretraining的loss,详细参考下面的分析
    lm_loss = modeling.lm_loss(
        hidden=output,
        target=tgt,
        n_token=xlnet_config.n_token,
        d_model=xlnet_config.d_model,
        initializer=initializer,
        lookup_table=lookup_table,
        tie_weight=True,
        bi_data=run_config.bi_data,
        use_tpu=run_config.use_tpu)

  #### 监控的量
  monitor_dict = {}

  if FLAGS.use_bfloat16:
    tgt_mask = tf.cast(tgt_mask, tf.float32)
    lm_loss = tf.cast(lm_loss, tf.float32)
  # 所有的平均loss
  total_loss = tf.reduce_sum(lm_loss * tgt_mask) / tf.reduce_sum(tgt_mask)
  monitor_dict["total_loss"] = total_loss

  return total_loss, new_mems, monitor_dict

上面的代码注意就是调用modeling.lm_loss根据xlnet_model的输出计算loss。

lm_loss函数

def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None,
            tie_weight=False, bi_data=True, use_tpu=False):

  with tf.variable_scope('lm_loss'):
    if tie_weight:
      assert lookup_table is not None, \
          'lookup_table cannot be None for tie_weight'
      softmax_w = lookup_table
    else:
      softmax_w = tf.get_variable('weight', [n_token, d_model],
                                  dtype=hidden.dtype, initializer=initializer)

    softmax_b = tf.get_variable('bias', [n_token], dtype=hidden.dtype,
                                initializer=tf.zeros_initializer())
    # hidden (21, 8, 1024) softmax_w是(32000, 1024)
    # tf.einsum得到(21, 8, 32000),然后加上32000的bias,shape不变
    logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b

    if use_tpu:
      # TPU上稀疏Tensor的效率不高
      one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype)
      loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
    else:
      # 计算loss
      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
                                                            logits=logits)

    return loss

上面的代码比较简单,使用tf.einsum把隐状态(21, 8, 1024)变换成(21, 8, 32000)的logits,然后使用sparse_softmax_cross_entropy_with_logits计算交叉熵损失。不熟悉sparse_softmax_cross_entropy_with_logits的读者可以参考官方文档,也可以购买作者的《深度学习理论与实战:基础篇》的第六章,里面详细的介绍了Tensorflow的基础知识。

返回model_fn

  def model_fn(features, labels, mems, is_training):
    # 前面我们在这个地方进入
    total_loss, new_mems, monitor_dict = function_builder.get_loss(
        FLAGS, features, labels, mems, is_training)

    #### 检查模型的参数
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('#params: {}'.format(num_params))

    # GPU
    assert is_training
    # 得到所有可以训练的变量
    all_vars = tf.trainable_variables()
    # 计算梯度
    grads = tf.gradients(total_loss, all_vars)
    # 把梯度和变量组成pair
    grads_and_vars = list(zip(grads, all_vars))
    
    return total_loss, new_mems, grads_and_vars

返回train函数

  for i in range(FLAGS.num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      # The mems for each tower is a dictionary
      mems_i = {}
      if FLAGS.mem_len:
        mems_i["mems"] = create_mems_tf(bsz_per_core)
      # 我们从这里进去,然后返回
      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          is_training=True,
          features=examples[i],
          mems=mems_i)
      # 下面都是多GPU训练相关的,把多个GPU的结果放到一起,我们暂时忽略
      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## 多个GPU的loss平均,我们这里只有一个GPU,走else分支
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]

  ## 得到训练的operation
  train_op, learning_rate, gnorm = model_utils.get_train_op(FLAGS, None,
      grads_and_vars=grads_and_vars)
  global_step = tf.train.get_global_step()

  ##### 训练循环
  # 初始化mems
  tower_mems_np = []
  for i in range(FLAGS.num_core_per_host):
    mems_i_np = {}
    for key in tower_mems[i].keys():
      # 这个函数把返回一个list(长度为层数6),每一个元素都是0,shape是(96, 8, 1024)
      mems_i_np[key] = initialize_mems_np(bsz_per_core)
    tower_mems_np.append(mems_i_np)

  # Saver
  saver = tf.train.Saver()

  gpu_options = tf.GPUOptions(allow_growth=True)

  # 从checkpoit恢复模型参数,我们这里是从零开始Pretraining,因此没有任何可以恢复的
  # 后面Fine-tuning再介绍这个函数
  model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
      gpu_options=gpu_options)) as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())
    # session.run时的返回值
    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    total_loss, prev_step = 0., -1
    while True:
      feed_dict = {}
      # 需要把mems feed进去,它来自下面sess.run返回的tower_mems_np。
      for i in range(FLAGS.num_core_per_host):
        for key in tower_mems_np[i].keys():
          for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
            feed_dict[m] = m_np
      # 进行训练
      fetched = sess.run(fetches, feed_dict=feed_dict)
      # 得到loss、新的mems(作为下一个的输入)和当前训练的step数
      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            curr_step, fetched[-3], fetched[-2],
            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        # 每隔FLAGS.save_steps(10,000)保存一下模型
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Model saved in path: {}".format(save_path))

      if curr_step >= FLAGS.train_steps:
        break

get_train_op函数


def get_train_op(FLAGS, total_loss, grads_and_vars=None):
  global_step = tf.train.get_or_create_global_step()

  # 线性的增加学习率,我们这里走else分支
  if FLAGS.warmup_steps > 0:
    warmup_lr = (tf.cast(global_step, tf.float32)
                 / tf.cast(FLAGS.warmup_steps, tf.float32)
                 * FLAGS.learning_rate)
  else:
    warmup_lr = 0.0

  # 学习率的decay
  if FLAGS.decay_method == "poly":
    # 多项式的学习率decay,我们这里不介绍,读者可以参考官方文档 
    decay_lr = tf.train.polynomial_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio)
  elif FLAGS.decay_method == "cos":
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)
  else:
    raise ValueError(FLAGS.decay_method)
  # 如果step< warmup_steps则使用warmup_lr,否则使用decay_lr
  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                           warmup_lr, decay_lr)

  if FLAGS.weight_decay == 0:
    # 如果没有weight_decay,那么使用AdamOptimizer,我们走这个分支
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate,
        epsilon=FLAGS.adam_epsilon)
  elif FLAGS.weight_decay > 0 and FLAGS.num_core_per_host == 1:
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        epsilon=FLAGS.adam_epsilon,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        weight_decay_rate=FLAGS.weight_decay)
  else:
    raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
                     "training so far.")

  if FLAGS.use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  if grads_and_vars is None:
    grads_and_vars = optimizer.compute_gradients(total_loss)
  gradients, variables = zip(*grads_and_vars)
  # 把梯度clip到最大绝对值为FLAGS.clip(0.25)
  clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip)
  
  # 这个分支没走,这里跳过 
  if getattr(FLAGS, "lr_layer_decay_rate", 1.0) != 1.0:
    n_layer = 0
    for i in range(len(clipped)):
      m = re.search(r"model/transformer/layer_(\d+?)/", variables[i].name)
      if not m: continue
      n_layer = max(n_layer, int(m.group(1)) + 1)

    for i in range(len(clipped)):
      for l in range(n_layer):
        if "model/transformer/layer_{}/".format(l) in variables[i].name:
          abs_rate = FLAGS.lr_layer_decay_rate ** (n_layer - 1 - l)
          clipped[i] *= abs_rate
          tf.logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
              abs_rate, l, variables[i].name))
          break
  
  # optimizer应用梯度的训练op
  train_op = optimizer.apply_gradients(
      zip(clipped, variables), global_step=global_step)

  # Manually increment `global_step` for AdamWeightDecayOptimizer
  if isinstance(optimizer, AdamWeightDecayOptimizer):
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])

  return train_op, learning_rate, gnorm
发布了104 篇原创文章 · 获赞 97 · 访问量 26万+

猜你喜欢

转载自blog.csdn.net/weixin_37947156/article/details/100315836
今日推荐