ViT Attention 实现 (Paddle 2.2)

# ViT Online Class
# Author: Dr. Zhu
# Project: PaddleViT (https://github.com/BR-IDL/PaddleViT)
# 2021.11
import paddle
import paddle.nn as nn

# Pin all tensors/ops in this script to the CPU. This runs at import time.
paddle.set_device(device='cpu')

class Attention(nn.Layer):
    """Multi-head self-attention layer (ViT style).

    Args:
        embed_dim: size of the last (feature) axis of the input.
        num_heads: number of attention heads; each head gets
            ``embed_dim // num_heads`` features.
        qkv_bias: if True, the fused q/k/v projection has a learnable bias;
            if False (default), it is created without a bias.
        qk_scale: factor multiplied into ``q @ k^T`` before the softmax.
            When None, ``head_size ** -0.5`` is used.
        dropout: dropout rate applied after the output projection.
        attention_dropout: dropout rate applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, qkv_bias=False, qk_scale=None, dropout=0., attention_dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        # Integer division: if embed_dim is not a multiple of num_heads, the
        # q/k/v projection works in all_head_size (< embed_dim) features and
        # self.proj maps the result back to embed_dim.
        self.attn_head_size = embed_dim // num_heads
        self.all_head_size = self.attn_head_size * self.num_heads
        # Paddle's Linear removes its bias with bias_attr=False; None keeps the
        # default bias. Previously qkv_bias was accepted but silently ignored.
        self.qkv = nn.Linear(embed_dim,
                             self.all_head_size * 3,
                             bias_attr=None if qkv_bias else False)
        if qk_scale is None:
            self.scales = self.attn_head_size ** -0.5
        else:
            self.scales = qk_scale
        self.proj = nn.Linear(self.all_head_size, embed_dim)
        self.attn_dropout = nn.Dropout(attention_dropout)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(axis=-1)

    def transpose_multihead(self, x):
        """Split the last axis into heads and move the head axis to position 1.

        (B, N, num_heads * head_size) -> (B, num_heads, N, head_size).
        Assumes x is rank 3 — the 4-axis transpose below requires it.
        """
        new_shape = x.shape[:-1] + [self.num_heads, self.attn_head_size]
        x = x.reshape(new_shape)
        x = x.transpose([0, 2, 1, 3])
        return x

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor whose last axis has size embed_dim; assumed to be
               (batch, num_tokens, embed_dim) — the transposes require rank 3.

        Returns:
            A tuple ``(out, attn_weights)``. ``out`` has the same shape as
            ``x``. ``attn_weights`` is (batch, num_heads, num_tokens,
            num_tokens), taken after the softmax and before attention dropout.
        """
        # One fused projection, then split into q, k, v along the feature axis.
        qkv = self.qkv(x).chunk(3, axis=-1)
        q, k, v = map(self.transpose_multihead, qkv)
        attn = paddle.matmul(q, k, transpose_y=True)
        attn = attn * self.scales
        attn = self.softmax(attn)
        # Keep the pre-dropout weights for inspection/visualisation.
        attn_weights = attn
        attn = self.attn_dropout(attn)
        z = paddle.matmul(attn, v)
        # (B, heads, N, head_size) -> (B, N, heads, head_size) -> (B, N, all_head_size)
        z = z.transpose([0, 2, 1, 3])
        new_shape = z.shape[:-2] + [self.all_head_size]
        z = z.reshape(new_shape)
        z = self.proj(z)
        z = self.proj_dropout(z)
        return z, attn_weights

def main():
    """Run Attention once on random input and print the input/output shapes."""
    inputs = paddle.randn([4, 16, 96])
    print('input shape = ', inputs.shape)

    config = dict(
        embed_dim=96,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        dropout=0.,
        attention_dropout=0.,
    )
    model = Attention(**config)
    print(model)

    output, weights = model(inputs)
    for shape in (output.shape, weights.shape):
        print(shape)


if __name__ == "__main__":
    main()

猜你喜欢

转载自：https://blog.csdn.net/lanmengyiyu/article/details/121640829
ViT