transformer t5代码解读2

先整体看relative_position的调用说明

The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. (如果双向=False,正相对位置是无效的。)

		We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. 
		我们对于小的相对位置使用小的buckets,对于大的相对位置使用大的buckets。
		All relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
		所有相对位置大于max_distance的映射到相同的bucket,所有相对位置小于等于max_distance的映射到相同的bucket。
        This should allow for more graceful generalization to longer sequences than the model has been trained on

对于相对位置编码的综合的解释

这个设计的思路其实也很直观,就是比较邻近的位置(0~7),我们需要比较得精细一些,所以给它们都分配一个独立的位置编码,至于稍远的位置(比如8~11),我们不用区分得太清楚,所以它们可以共用一个位置编码,距离越远,共用的范围就可以越大,直到达到指定范围再clip。使用类似于nezha的相对位置编码,考虑i-j的相对位置内容信息。

T5Attention类别的调用

首先查看一下初始化参数的值

self.has_relative_attention_bias = False
self.relative_attention_num_buckets = 32
self.d_model = 512
self.key_value_proj_dim = 64
self.inner_dim = 512

接着进入forward调用程序部分

batch_size,seq_length = hidden_states.shape[:2]
real_seq_length = seq_length

得到参数

batch_size = 1,seq_length = 15,real_seq_length = 15

然后经历三个dense网络层

query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

这里对应的形状参数为

query_states = 
torch.Size([1, 8, 15, 64])
......

接下来本质上是跟nezha模型一样,有个position_bias参数,然后softmax(Q*K+bias)…
这里T5中的selfattention公式与之前的selfattention公式有所区别,T5之中的selfattention公式内容为
s o f t m a x ( Q ∗ K T + p o s i t i o n b i a s ) ∗ V softmax(Q*K^{T}+position_{bias})*V softmax(QKT+positionbias)V
这里不需要除以 d k d_{k} dk
接下来计算scores的内容

scores = torch.matmul(
	query_states,key_states.transpose(3,2)
)

得到的scores的内容是

scores = (1,8,15,15)

接下来计算位置便移的position_bias内容

position_bias = self.compute_bias(real_seq_length,key_length)

进入的到compute_bias函数之中,查看它的对应计算过程

t5模型compute_bias调用过程

(感觉这里t5模型的compute_bias与nezha中的compute_bias有点类似???)
注意!!!t5的相对位置编码和nezha的相对位置编码还是不一样的!!!

def compute_bias(self, query_length, key_length):
    """Compute binned relative position bias"""
    #query_length = 15,key_length = 15
    context_position = torch.arange(
        query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
    )[:, None]
    memory_position = torch.arange(
        key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
    )[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=(not self.is_decoder),
        num_buckets=self.relative_attention_num_buckets,
    )
    values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values

这里首先通过语句调用context_position和memory_position
调用语句

context_position = torch.arange(
    query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[:, None]
memory_position = torch.arange(
    key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[None, :]

得到相应的参数

context_position = 
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14]])
memory_position = 
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]])

接下来计算relative_position的位置内容

relative_position = memory_position-context_position

得到对应的相对位置内容

relative_position = relative_position = 
[[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
   14],
 [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
   13],
 [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,
   12],
 [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,
   11],
 [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
   10],
 [ -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,
    9],
 [ -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,
    8],
 [ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,
    7],
 [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,
    6],
 [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,
    5],
 [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,
    4],
 [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
    3],
 [-12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,
    2],
 [-13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,
    1],
 [-14, -13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,
    0]])

然后计算relative_position_bucket的向量内容

relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)

进入到_relative_position_bucket函数之中,首先进行调用

if bidirectional:
    num_buckets //= 2
    relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
    relative_position = torch.abs(relative_position)
else:
    relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

这里的bidirectional代表双向的含义是左右位置都处于相同的状态(举个例子,左边0~7的位置和右边0~7的位置的编码相同)
这里调用bidirectional的部分,首先调用num_bucket的参数

num_bucket = 32

然后将num_bucket除以2

num_buckets //= 2

求得num_buckets的对应值

num_buckets = 16

接下来调用上面的relative_position并对非零部分乘上num_buckets

relative_buckets += (relative_position > 0).to(torch.long)*num_buckets

得到relative_buckets的对应值

relative_buckets = 
tensor([[ 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

接下来的relative_position还是接着上面的relative_position进行操作,与上面的relative_buckets的内容暂时没有关系

relative_position = torch.abs(relative_position)

获得relative_position的对应值

relative_position = 
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [ 6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6],
        [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5],
        [10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4],
        [11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3],
        [12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2],
        [13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1],
        [14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0]])

然后定义is_small的内容

is_small = relative_position < max_exact

得到对应的is_small的内容
is_small的对应内容可以看出,这里的is_small是一种流线型的bool判定矩阵内容
接下来是一部分较为复杂的计算过程(这里牵扯到t5相对位置编码的计算)

relative_postion_if_large = max_exact + (
    torch.log(relative_position.float() / max_exact)
    / math.log(max_distance / max_exact)
    * (num_buckets - max_exact)
).to(torch.long)

一点一点解析这里公式的内容
首先这里的max_exact = 8,将全部的max_exact用8替换

relative_position_if_large = 8 + log(relative_position/8)/log(max_distance/8)

这里对应的相对位置矩阵

relative_position = 
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [ 6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5,  6],
        [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4,  5],
        [10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3,  4],
        [11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2,  3],
        [12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1,  2],
        [13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  1],
        [14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0]])

对应的最大距离

max_distance = 
128

所以最终的计算公式为
8 + ( l o g ( [ [ 0 , 1 , 2 , . . . 14 ] ,     [ 1 , 0 , 1 , . . . 13 ]     . . . . . . . . . . . . . . . . . . .     [ 14 , 13 , . . 1 , 0 ] ] ) / 8 ) / ( 128 / 8 ) 8 + (log([[0,1,2,...14], \\ \qquad \quad \ \ \ [1,0,1,...13] \\ \qquad \quad \ \ \ ...................\\ \qquad \quad \ \ \ [14,13,..1,0]])/8)/(128/8) 8+(log([[0,1,2,...14],   [1,0,1,...13]   ...................   [14,13,..1,0]])/8)/(128/8)
化简之后的结果为
8 + ( l o g ( [ [ 0 , 1 , 2 , . . . 14 ] ,     [ 1 , 0 , 1 , . . . 13 ]     . . . . . . . . . . . . . . . . . . .     [ 14 , 13 , . . 1 , 0 ] ] ) / 128 ) 8 + (log([[0,1,2,...14], \\ \qquad \quad \ \ \ [1,0,1,...13] \\ \qquad \quad \ \ \ ...................\\ \qquad \quad \ \ \ [14,13,..1,0]])/128) 8+(log([[0,1,2,...14],   [1,0,1,...13]   ...................   [14,13,..1,0]])/128)
这里我们首先查看log的计算结果

torch.log = 
tensor([[  -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979, 2.4849, 2.5649, 2.6391],
        [0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979, 2.4849, 2.5649],
        [0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979, 2.4849],
        [1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979],
        [1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026],
        [1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972],
        [1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794],
        [1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459],
        [2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918],
        [2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094],
        [2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986, 1.3863],
        [2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931, 1.0986],
        [2.4849, 2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000, 0.6931],
        [2.5649, 2.4849, 2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf, 0.0000],
        [2.6391, 2.5649, 2.4849, 2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000,   -inf]])

接下来查看

torch.log(relative_position.float()/max_exact)

的对应的内容

torch.log = 
tensor([[  -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997, 0.3106, 0.3206, 0.3299],
        [0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997, 0.3106, 0.3206],
        [0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997, 0.3106],
        [0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997],
        [0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878],
        [0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747],
        [0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599],
        [0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432],
        [0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240],
        [0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012],
        [0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373, 0.1733],
        [0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866, 0.1373],
        [0.3106, 0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000, 0.0866],
        [0.3206, 0.3106, 0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf, 0.0000],
        [0.3299, 0.3206, 0.3106, 0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000,   -inf]])

接着查看除数中的max_distance/max_exact

math.log(max_distance/max_exact)

这里的max_distance = 128,max_exact = 8,所以这里计算出来的结果

math.log(128/8) = math.log(16) = 2.0794415416798357

最后部分调用relative_position_if_large部分

relative_position_if_large = torch.min(
	relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)

疑问??
这里一直有个疑问,就是在这个位置输出relative_position_if_large为什么一直报错

relative_position_bucket = self._relative_position_bucket(
    410             relative_position,  # shape (query_length, key_length)
    411             bidirectional=(not self.is_decoder),

~/.local/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py in _relative_position_bucket(relative_position, bidirectional, num_buckets, max_distance)
    393         )
    394         print('relative_position_if_large = ')
--> 395         print(relative_position_if_large)
    396         relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)

t5对应的relative_position公式

这里为什么输出不了先不管他(后来发现还是自己的名称写错了),我们直接换一个变量进行输出内容

current_data = max_exact + (
    torch.log(relative_position.float() / max_exact)
    / math.log(max_distance / max_exact)
    * (num_buckets - max_exact)
)

最终的经过公式之后输出的内容为
经过公式输出的内容仔细观察可疑看出,这里近距离的内容,比如周围1,2部分的内容,比如数值2,变化的很快,而远一些的位置,比如数值8就会延续很长一段数值,变化的很慢

这里的num_buckets = 16,max_exact = 8,因此16-8 = 8.
接下来调用的过程

relative_position_if_large = torch.min(
    relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)

relative_position_if_large相对位置

relative_position_if_large = torch.min(
	relative_position_if_large,torch.full_like(relative_position_if_large,num_buckets-1)
)

这里先看一下torch.full_like的调用

torch.full_like(relative_position_if_large.num_buckets-1)

得到的结果

tensor([[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
        [15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15]])

接下来调用torch.min的函数内容

relative_position_if_large = torch.min(
	relative_position_if_large,torch.full_like(relative_position_if_large,num_buckets-1)
)

这波操作在这里没有任何的变动,relative_position_if_large的对应值还是上文的对应值
relative_position_if_large的对应值原因在于relative_position_if_large的所有值都要比15小,所以这里本质上等于没变
接下来调用最后一波操作:

relative_buckets += torch.where(is_small,relative_position,relative_position_if_large)

注释:
这里首先挂出来原版的relative_buckets(上文调用的过程)

relative_buckets += (relative_position > 0).to(torch.long)*num_buckets

注释完毕
原版的relative_buckets的内容

relative_buckets = 
tensor([[ 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

现在需要调用的函数内容

torch.where(is_small,relative_position,relative_position_if_large)

调用结束的内容为

torch.where = 
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9, 9, 9],
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9, 9],
        [2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9],
        [3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8],
        [4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8],
        [5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8],
        [6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8],
        [7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7],
        [8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6],
        [8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5],
        [8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4],
        [8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3],
        [9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2],
        [9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1],
        [9, 9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0]])

这里解析一下内容
is_small函数选取的是较小的内容,

is_small = relative_position < max_exact

也就是说这里如果relative_position < max_exact的时候,选取relative_position的对应值内容即可,否则选取relative_if_large的内容,得到的torch.where矩阵内容为

torch.where = 
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9, 9, 9],
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9, 9],
        [2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9],
        [3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8],
        [4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8],
        [5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8],
        [6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8],
        [7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7],
        [8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6],
        [8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5],
        [8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4],
        [8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3],
        [9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2],
        [9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1],
        [9, 9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0]])

最后加上得到的内容

relative_buckets += torch.where(is_small,relative_position,relative_position_if_large)

得到最终的内容

relative_buckets = 
tensor([[ 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25, 25, 25],
        [ 1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25, 25],
        [ 2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25],
        [ 3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24],
        [ 4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24],
        [ 5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24],
        [ 6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24],
        [ 7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22],
        [ 8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21],
        [ 8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20],
        [ 8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19],
        [ 9,  8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18],
        [ 9,  9,  8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17],
        [ 9,  9,  9,  8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0]])

bert4keras模型源码

class RelativePositionEmbeddingT5(RelativePositionEmbedding):
    """Google T5的相对位置编码
    来自论文:https://arxiv.org/abs/1910.10683
    """
    def __init__(
        self,
        input_dim,
        output_dim,
        max_distance=128,
        bidirectional=True,
        embeddings_initializer='zeros',
        **kwargs
    ):
        super(RelativePositionEmbeddingT5,
              self).__init__(input_dim, output_dim, **kwargs)
        self.max_distance = max_distance
        self.bidirectional = bidirectional

    def compute_position_ids(self, inputs):
        """T5的相对位置分桶(直接翻译自官方T5源码)
        """
        q, v = inputs
        # 计算位置差
        q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
        q_idxs = K.expand_dims(q_idxs, 1)
        v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
        v_idxs = K.expand_dims(v_idxs, 0)
        pos_ids = v_idxs - q_idxs
        # 后处理操作
        num_buckets, max_distance = self.input_dim, self.max_distance
        ret = 0
        n = -pos_ids
        if self.bidirectional:
            num_buckets //= 2
            ret += K.cast(K.less(n, 0), 'int32') * num_buckets
            n = K.abs(n)
        else:
            n = K.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = K.less(n, max_exact)
        val_if_large = max_exact + K.cast(
            K.log(K.cast(n, K.floatx()) / max_exact) /
            np.log(max_distance / max_exact) * (num_buckets - max_exact),
            'int32',
        )
        val_if_large = K.minimum(val_if_large, num_buckets - 1)
        ret += K.switch(is_small, n, val_if_large)
        return ret

猜你喜欢

转载自blog.csdn.net/znevegiveup1/article/details/121344398
今日推荐