[pytorch]——torch.gather(以BERT中的MLM为例)

前言

都知道BERT中有MLM的任务,假设此时ENCODER的输出output的大小为: batch_size x max_len x d_model,而对于每一个句子,都有对应的数个被mask掉的单词,所以假设被mask掉的单词下标矩阵大小为:batch_size x mask_num。那么,我们要做的,就是在output的每一个句子中,按照下标,选择mask_num个单词,从而组成:batch_size x mask_num x d_model大小的矩阵,再通过线性层,进行分类。

具体代码如下:

mask = mask[:, :, None].expand(-1, -1, d_model) # [batch_size, mask_num, d_model]
output2 = torch.gather(input=output, dim=1, index=mask ) # masking position [batch_size, max_pred, d_model]

解释

不妨举个简单的例子,假设

  • batch_size = 3
  • max_len = 3
  • d_model = 4
  • mask_num = 2

我们手写一个output矩阵如下:

output = torch.tensor(
    [
        [[1,2,3,4],[5,6,7,8],[9,10,11,12]],
        [[13,14,15,16],[17,18,19,20],[21,22,23,24]],
        [[25,26,27,28],[29,30,31,32],[33,34,35,36]]
    ]
)

我们再手写一个mask矩阵如下:

"""
# 意思是:
对于第0个句子,mask掉它的第0、第1个单词;
对于第1个句子,mask掉它的第1、第2个单词;
对于第1个句子,mask掉它的第0、第2个单词。
这些被mask掉的单词,就是模型需要预测的部分。
"""
mask = torch.tensor(
    [
        [0,1],
        [1,2],
        [0,2]
    ]
)

按照上面给出的代码,首先对mask进行拓展

mask = mask[:, :, None].expand(-1, -1, 4) # [batch_size, mask_num, d_model]

得到:
在这里插入图片描述

这个矩阵用作torch.gather函数的第三个(index)参数,该如何理解这个矩阵呢?——比如,对于第0个句子的第0个单词,它是被mask掉的,所以我们需要取出它的向量,而这个向量是4(d_model)维的,所以上面index矩阵中的[0,0,0,0](从上到下看第0行)的意思就是:每一维上的值,都是第0个单词的; 再想一例,对于第0个句子的第1个单词,它也是被mask掉的,所以我们也需要取出它的向量,则index矩阵中的[1,1,1,1](从上到下看第1行)的意思也就是:每一维上的值,都是来自于第1个单词的。。。

下面结合torch.gather函数来讲解:

output2 = torch.gather(output, 1, mask) # masking position [batch_size, max_pred, d_model]

torch.gather可以分为三步走:
第一步,由于参数dim=1,它的意思是,我们可以先确定所有dim=0和dim=2的下标值,然后用mask来填充dim=1处的下标值。所以首先生成:

[
    [[(?,0),(?,1),(?,2),(?,3)],
    [(?,0),(?,1),(?,2),(?,3)]],
    
    [[(?,0),(?,1),(?,2),(?,3)],
    [(?,0),(?,1),(?,2),(?,3)]],
    
    [[(?,0),(?,1),(?,2),(?,3)],
    [(?,0),(?,1),(?,2),(?,3)]]
]

第二步,使用传入的mask矩阵,填充上面的矩阵:

[
    [[(0,0),(0,1),(0,2),(0,3)],
    [(1,0),(1,1),(1,2),(1,3)]],
    
    [[(1,0),(1,1),(1,2),(1,3)],
    [(2,0),(2,1),(2,2),(2,3)]],
    
    [[(0,0),(0,1),(0,2),(0,3)],
    [(2,0),(2,1),(2,2),(2,3)]]
]

第三步,利用该下标矩阵,去output矩阵中,按下标取出所有元素,得:

output2 = torch.gather(output, 1, mask)

在这里插入图片描述

参考

https://github.com/wmathor/nlp-tutorial/blob/master/5-2.BERT/BERT-Torch.ipynb

猜你喜欢

转载自blog.csdn.net/jokerxsy/article/details/116950445