PyTorch implementation of the ParNet_Attention attention mechanism

ParNet attention mechanism

ParNet attention comes from the ParNet block introduced in the 2021 paper "Non-deep Networks" (arXiv 2110.07641) by researchers from Princeton University and Intel Labs. It is a building block for convolutional vision models, designed to let a network stay very shallow while remaining competitive by doing more work in parallel inside each block.

Instead of stacking more depth, the ParNet attention block runs three parallel branches over the input feature map: a 1x1 convolution followed by batch normalization, a 3x3 convolution followed by batch normalization, and a Skip-Squeeze-Excitation (SSE) branch that applies global average pooling, a 1x1 convolution, and a sigmoid to produce per-channel weights that rescale the input. The three branch outputs are summed and passed through a SiLU activation, so the module adds channel attention without changing the feature map's shape, as the code below shows.

Paper address: https://arxiv.org/pdf/2110.07641.pdf

Structural schematic (figure omitted)

The code is shown below:

import torch
from torch import nn


class ParNetAttention(nn.Module):
    """ParNet attention block: three parallel branches fused with a SiLU activation."""

    def __init__(self, channel=512):
        super().__init__()
        # Skip-Squeeze-Excitation (SSE) branch: global average pooling, a 1x1
        # convolution, and a sigmoid produce per-channel attention weights.
        self.sse = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.Sigmoid()
        )

        # 1x1 convolution branch.
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.BatchNorm2d(channel)
        )
        # 3x3 convolution branch (padding keeps the spatial size unchanged).
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel)
        )
        self.silu = nn.SiLU()

    def forward(self, x):
        # x: (batch, channel, height, width)
        x1 = self.conv1x1(x)
        x2 = self.conv3x3(x)
        x3 = self.sse(x) * x          # channel-wise reweighting of the input
        y = self.silu(x1 + x2 + x3)   # fuse the three branches
        return y


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    pna = ParNetAttention(channel=512)
    output = pna(input)
    print(output.shape)  # torch.Size([50, 512, 7, 7])
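
As a usage sketch (the wrapper below, including the name SimpleConvBlock, is an illustrative assumption and not part of the original post), the module can be dropped after any convolutional stage, since it preserves both the channel count and the spatial size:

import torch
from torch import nn

class SimpleConvBlock(nn.Module):
    # Hypothetical example: a convolution followed by the ParNetAttention module defined above.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()
        self.attn = ParNetAttention(channel=out_channels)

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        return self.attn(x)  # same (N, out_channels, H, W) shape as its input

block = SimpleConvBlock(64, 128)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)  # torch.Size([2, 128, 32, 32])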


Reprinted from: blog.csdn.net/DM_zx/article/details/132381800