Transformer T5 Model Code Walkthrough (Part 3)


Returning to T5Attention: the call to self._relative_position_bucket discussed above happens inside the compute_bias function. Let us walk through the body of T5Attention.compute_bias again.

context_position = torch.arange(
    query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[:, None]
memory_position = torch.arange(
    key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[None, :]
relative_position = memory_position - context_position  # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)

Here context_position is computed as:

context_position = torch.arange(
    query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[:, None]

The resulting context_position (with query_length = 15) is:

context_position = 
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14]])

and memory_position is:

memory_position = 
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]])

Next, the two are subtracted (with broadcasting):

relative_position = memory_position - context_position

which gives the following relative_position matrix:

relative_position = 
tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14],
        [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13],
        [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12],
        [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11],
        [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10],
        [ -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [ -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8],
        [ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7],
        [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6],
        [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5],
        [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4],
        [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3],
        [-12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2],
        [-13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1],
        [-14, -13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0]])
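
As a quick standalone check of this broadcasting step, here is a minimal sketch with query_length = key_length = 4 instead of 15 (and the device argument dropped), just to show the pattern:

import torch

query_length = key_length = 4   # shorter than 15, only to illustrate the broadcast

context_position = torch.arange(query_length, dtype=torch.long)[:, None]   # shape (4, 1)
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]      # shape (1, 4)

# (1, 4) - (4, 1) broadcasts to (4, 4): entry (i, j) = j - i
relative_position = memory_position - context_position
print(relative_position)
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])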

Calling _relative_position_bucket on this matrix:

relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)

produces the following output:

relative_position_bucket = 
tensor([[ 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25, 25, 25],
        [ 1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25, 25],
        [ 2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25],
        [ 3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24],
        [ 4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24],
        [ 5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24, 24],
        [ 6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23, 24],
        [ 7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22, 23],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21, 22],
        [ 8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20, 21],
        [ 8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19, 20],
        [ 8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18, 19],
        [ 9,  8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17, 18],
        [ 9,  9,  8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0, 17],
        [ 9,  9,  9,  8,  8,  8,  8,  7,  6,  5,  4,  3,  2,  1,  0]])
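
Since _relative_position_bucket itself was already walked through above, here is only a standalone sketch that mirrors that bucketing logic, to make the table easier to spot-check. It is a simplified re-implementation under the same default parameters (num_buckets=32, max_distance=128), not the library code verbatim:

import math
import torch

def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    # Small offsets get their own bucket; larger offsets are binned logarithmically.
    # In the bidirectional (encoder) case, half of the buckets are reserved for
    # positive offsets (keys to the right of the query).
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2                                    # 16 buckets per direction
        relative_buckets += (relative_position > 0).long() * num_buckets
        relative_position = relative_position.abs()
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    max_exact = num_buckets // 2                             # offsets < 8 keep their exact value
    is_small = relative_position < max_exact
    large = max_exact + (
        torch.log(relative_position.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = torch.min(large, torch.full_like(large, num_buckets - 1))
    relative_buckets += torch.where(is_small, relative_position, large)
    return relative_buckets

# Spot-check a few entries of the matrix above (encoder, so bidirectional=True):
print(relative_position_bucket(torch.tensor([0, 1, 7, 8, 12, 14, -1, -8, -14])))
# tensor([ 0, 17, 23, 24, 25, 25,  1,  8,  9])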

Next, self.relative_attention_bias is applied to the bucket matrix:

values = self.relative_attention_bias(relative_position_bucket)

Let us first look at how self.relative_attention_bias is defined:

self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

with the parameters

self.relative_attention_num_buckets = 32
self.n_heads = 8

Here relative_attention_num_buckets plays the role of a vocabulary size (compare vocab_size in a token embedding): every bucket index lies in the range 0–31, so the (15, 15) bucket matrix can be passed straight through the embedding layer. After the lookup:

values.shape = (15, 15, 8)   # (query_length, key_length, n_heads)

Next, values is permuted and a batch dimension is added:

values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)

so its shape goes from torch.Size([15, 15, 8]) to

values = (1, 8, 15, 15)

where the trailing 15 × 15 is the (query_length, key_length) grid over the same length-15 sequence.
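
Putting the embedding lookup and the permute together, a minimal standalone sketch (a random bucket matrix stands in for the real one above, and the bias table is freshly initialised, so only the shapes are meaningful):

import torch
import torch.nn as nn

relative_attention_num_buckets, n_heads = 32, 8
query_length = key_length = 15

relative_attention_bias = nn.Embedding(relative_attention_num_buckets, n_heads)
relative_position_bucket = torch.randint(0, relative_attention_num_buckets,
                                         (query_length, key_length))   # stand-in bucket matrix

values = relative_attention_bias(relative_position_bucket)
print(values.shape)   # torch.Size([15, 15, 8]): one 8-dim bias vector per (query, key) pair

values = values.permute([2, 0, 1]).unsqueeze(0)
print(values.shape)   # torch.Size([1, 8, 15, 15]) = (1, num_heads, query_length, key_length)
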
The next branch, past_key_value is not None, is only taken in the decoder; it is never executed in the encoder.

if past_key_value is not None:
    print('past_key_value is not None')
    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
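
For completeness, here is a shape-level sketch of what this slicing does during incremental decoding (the numbers are hypothetical: a 15-position key/value cache and a single new query token):

import torch

n_heads, d_model, key_length, new_tokens = 8, 512, 15, 1

# position_bias covers the full (key_length, key_length) grid, but hidden_states
# only holds the newest token, so only the last query row of the bias is kept
position_bias = torch.randn(1, n_heads, key_length, key_length)
hidden_states = torch.randn(1, new_tokens, d_model)

position_bias = position_bias[:, :, -hidden_states.size(1):, :]
print(position_bias.shape)   # torch.Size([1, 8, 1, 15])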

The mask is then added on top of position_bias:

if mask is not None:
	position_bias = position_bias + mask

This yields the new position_bias (which shows that the mask is simply folded into position_bias rather than applied separately to the scores). The bias is then added to the attention scores and a softmax is applied; note that T5's attention does not divide the scores by sqrt(d_k):

scores += position_bias
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
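
As a shape-level sketch of this step (random tensors stand in for the real query/key projections and for position_bias plus mask):

import torch
import torch.nn as nn

n_heads, q_len, k_len, d_kv = 8, 15, 15, 64
query_states = torch.randn(1, n_heads, q_len, d_kv)
key_states = torch.randn(1, n_heads, k_len, d_kv)
position_bias = torch.randn(1, n_heads, q_len, k_len)   # already includes the mask

# raw dot product -- note there is no division by sqrt(d_kv)
scores = torch.matmul(query_states, key_states.transpose(3, 2))
scores += position_bias
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
print(attn_weights.shape)   # torch.Size([1, 8, 15, 15]); each row sums to 1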

Dropout is then applied to the attention weights:

attn_weights = nn.functional.dropout(
    attn_weights, p=self.dropout, training=self.training
)

The dropout parameters in this run are:

self.dropout = 0.1
self.training = False

Next comes the layer_head_mask step; layer_head_mask is None in this run, so the block is skipped:

if layer_head_mask is not None:
    attn_weights = attn_weights * layer_head_mask

Then the attention weights are multiplied by value_states:

attn_output = unshape(torch.matmul(attn_weights, value_states))

The matrix multiplication of the two tensors gives the shape

(1, 8, 15, 15) @ (1, 8, 15, 64) = (1, 8, 15, 64)

unshape then reshapes the result:

(1, 8, 15, 64) -> (1, 15, 512)
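
unshape is a small helper defined inside T5Attention.forward; the following sketch mirrors its shape logic (the real helper is a closure over the batch size and inner_dim):

import torch

batch_size, n_heads, seq_len, d_kv = 1, 8, 15, 64
inner_dim = n_heads * d_kv   # 512

def unshape(states):
    # (batch, n_heads, seq_len, d_kv) -> (batch, seq_len, n_heads * d_kv)
    return states.transpose(1, 2).contiguous().view(batch_size, -1, inner_dim)

attn_weights = torch.randn(batch_size, n_heads, seq_len, seq_len)   # (1, 8, 15, 15)
value_states = torch.randn(batch_size, n_heads, seq_len, d_kv)      # (1, 8, 15, 64)

attn_output = unshape(torch.matmul(attn_weights, value_states))
print(attn_output.shape)   # torch.Size([1, 15, 512])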

Finally, the output projection is applied:

attn_output = self.o(attn_output)

which maps the (1, 15, 512) input to a (1, 15, 512) output.
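
self.o is the output projection of the attention block; with the t5-small-sized hyper-parameters used in this walkthrough it amounts to a bias-free linear layer from inner_dim = n_heads * d_kv = 512 back to d_model = 512, so the shape is unchanged. A minimal sketch:

import torch
import torch.nn as nn

d_model, inner_dim = 512, 512   # t5-small: inner_dim = 8 * 64 = d_model

o = nn.Linear(inner_dim, d_model, bias=False)   # mirrors self.o in T5Attention

attn_output = torch.randn(1, 15, inner_dim)
print(o(attn_output).shape)   # torch.Size([1, 15, 512])
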
The final outputs tuple is then assembled:

outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

This forms a tuple: attn_output is the final output of this pass, of shape (1, 15, 512); present_key_value_state is None here (no cache in the encoder); and position_bias has shape (1, 8, 15, 15). The three are packed together into the returned tuple.
Finally, if output_attentions is set, the attention weights are appended as well:

if output_attentions:
    outputs = outputs + (attn_weights,)

Summary

T5Attention differs from a standard attention module in three main ways:
1. The attention scores are not divided by sqrt(d_k).
2. The relative position encoding is computed differently (the mask is added on top of it in the same way).
3. After the attention computation is finished, the result still passes through a linear output layer (self.o).


Source: blog.csdn.net/znevegiveup1/article/details/121411678