pytorch コードはアテンション メカニズムの GE を実装します

GE アテンションのメカニズム

GE アテンション メカニズム (正式名は Gather-Excite アテンション) は、「Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks」に由来しています。

要約: 畳み込みニューラル ネットワーク (CNN) でボトムアップのローカル演算子を使用すると、自然画像の統計の一部と厳密に一致しますが、そのようなモデルがコンテキスト上の長距離特徴相互作用をキャプチャできなくなる可能性もあります。この研究では、CNN でコンテキストをより適切に利用するためのシンプルで軽量な方法を提案します。これは、大規模な空間範囲からのフィーチャ応答を効率的に集約するcollectと、プールされた情報をローカル フィーチャに再分配するexci​​teのペアの演算子を導入することで実現します。この演算子は、追加パラメーターの数と計算の複雑さの両方の点で安価であり、既存のアーキテクチャに直接統合してパフォーマンスを向上させることができます。複数のデータセットに対する実験では、クラスター化された励起が、わずかなコストで CNN の深度の増加に匹敵する利点をもたらすことができることを示しています。たとえば、Gather-Excite オペレーターを備えた ResNet-50 は、学習可能なパラメーターを追加しなくても、101 層の対応する ImageNet を上回るパフォーマンスを発揮できることがわかりました。また、さらなるパフォーマンス向上をもたらすパラメーター収集とインセンティブ演算子のペアを提案し、それを最近導入されたスクイーズおよびインセンティブ ネットワークにリンクし、これらの変更が CNN 機能アクティベーション統計に及ぼす影響を分析します。

論文アドレス: Gather-Excite: 畳み込みニューラル ネットワークにおける機能コンテキストの活用

GE処理構造図

コード:

import math, torch
from torch import nn as nn
import torch.nn.functional as F

from timm.models.layers.create_act import create_act_layer, get_act_layer
from timm.models.layers.create_conv2d import create_conv2d
from timm.models.layers import make_divisible
from timm.models.layers.mlp import ConvMlp


class GatherExcite(nn.Module):
    def __init__(
            self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
            rd_ratio=1./16, rd_channels=None,  rd_divisor=1, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
        super(GatherExcite, self).__init__()
        self.add_maxpool = add_maxpool
        act_layer = get_act_layer(act_layer)
        self.extent = extent
        if extra_params:
            self.gather = nn.Sequential()
            if extent == 0:
                assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
                self.gather.add_module(
                    'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
                if norm_layer:
                    self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))
            else:
                assert extent % 2 == 0
                num_conv = int(math.log2(extent))
                for i in range(num_conv):
                    self.gather.add_module(
                        f'conv{
    
    i + 1}',
                        create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
                    if norm_layer:
                        self.gather.add_module(f'norm{
    
    i + 1}', nn.BatchNorm2d(channels))
                    if i != num_conv - 1:
                        self.gather.add_module(f'act{
    
    i + 1}', act_layer(inplace=True))
        else:
            self.gather = None
            if self.extent == 0:
                self.gk = 0
                self.gs = 0
            else:
                assert extent % 2 == 0
                self.gk = self.extent * 2 - 1
                self.gs = self.extent

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        size = x.shape[-2:]
        if self.gather is not None:
            x_ge = self.gather(x)
        else:
            if self.extent == 0:
                # global extent
                x_ge = x.mean(dim=(2, 3), keepdims=True)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
            else:
                x_ge = F.avg_pool2d(
                    x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
        x_ge = self.mlp(x_ge)
        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
            x_ge = F.interpolate(x_ge, size=size)
        return x * self.gate(x_ge)

if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    GE = GatherExcite(512)
    output=GE(input)
    print(output.shape)

おすすめ

転載: blog.csdn.net/DM_zx/article/details/132732058