SWin および CSWin Transformer を超える新モデル
ViT はさまざまなビジョン タスクで成功しますが、その計算コストはトークン シーケンスの長さに応じて二次関数的に増加するため、大規模な特徴マップを処理する際のパフォーマンスが大幅に制限されます。計算コストを軽減するために、これまでの研究は、小さな局所領域に限定されたきめの細かいセルフ アテンション、またはシーケンス長が短いグローバル セルフ アテンションのいずれかに依存しており、その結果、粗い問題が発生していました。
Self-Guided Transformer (SG-Former) は、適応性のある細粒度化により効果的なグローバル Self-Attendance を実現します。この方法の中心となるアイデアは、混合スケールのセルフアテンションによって推定され、トレーニング プロセス中に自己進化した顕著性マップを使用して、各領域の重要性に応じてトークンを再配布することです。直観的には、より多くのトークンが顕著な領域に割り当てられ、きめ細かい注目が得られますが、効率性とグローバルな認識フィールドと引き換えに、より少ないトークンが二次領域に割り当てられます。SG-Former は、分類、検出、セグメンテーションのタスクにおいて最先端のモデルを上回るパフォーマンスを発揮します。
高解像度の特徴に対するセルフ アテンションを計算するために、いくつかの方法では、セルフ アテンション領域を特徴マップ全体ではなくローカル ウィンドウに制限することが提案されています (すなわち、きめの細かいローカル セルフ アテンション)。たとえば、Swin Transformer はウィンドウを重視して設計されていますが、CSWin は十字形状を重視して設計されています。したがって、これらの方法では、各セルフアテンション層でグローバル情報をモデル化する機能が犠牲になります。別の一連のメソッドは、キーと値の特徴マップ全体にわたってトークンを集約して、グローバル シーケンスの長さを削減することを目的としています (つまり、粗粒度のグローバル アテンション)。たとえば、Pyramid Vision Transformer (PVT) は、大きなストライドを持つ大きなコアを使用して、特徴マップ全体にわたってトークンを均一に集約し、その結果、特徴マップ全体にわたって均一な大まかな情報が得られます。
この記事の Self-Guided Transformer (SG-Former) は、進化する Self-Attention 設計を通じて、適応型のきめ細かいグローバル アテンションを実現します。SG-Former の中心的なアイデアは、画像領域の重要性に応じてトークンを再配布しながら、特徴マップ全体の長距離依存関係を保持することです。
つまり、モデルは、各トークンが細かい粒度で顕著な領域と対話できるように、より多くのトークンを顕著な領域に割り当てますが、効率を高めるために二次領域にはより少ないトークンを割り当てます。SG-Former は、顕著な領域のきめの細かい情報に適応的に焦点を当てながら、効率的なグローバル知覚フィールドを使用してセルフ アテンションを計算します。
図 2 に示すように、SG-Former は、自身から取得したアテンション マップに基づいて、犬などの顕著な領域にはより多くのトークンを割り当て、壁などの二次的な領域にはより少ないトークンを割り当てるなど、トークンを再配布します。PVT は、トークンを均等に集約するために事前定義された戦略を採用します。
具体的には、クエリ トークンは保持されますが、キーと値のトークンは効率的なグローバル セルフ アテンションを実現するために再割り当てされます。画像領域の重要性は、スコア マップの形式で、混合スケールのセルフ アテンション自体によって推定され、トークンの再割り当てのガイドとしてさらに使用されます。
つまり、入力画像が与えられると、トークンの再割り当ては Self-Guided を通じて行われます。つまり、各画像はそれ自体にのみ適用される一意のトークンの再割り当てを受けることになります。したがって、再配布されたトークンは人間による事前の影響をあまり受けません。
さらに、この Self-Guided は、トレーニング中のアテンション マップの予測の精度がさらに高まることで進化し続けます。アテンションマップは再割り当ての有効性に大きく影響するため、Swinと同じコストで同じ層にさまざまな粒度の情報を持つ混合スケールのセルフアテンションが提案されています。混合スケールのセルフアテンションは、ヘッドをグループ化し、異なる注意の粒度に合わせて各グループを多様化することで、さまざまな粒度の情報を実現します。混合スケールのセルフアテンションは、混合スケールの情報を Transformer 全体に提供します。
この記事には次の寄稿があります。
1. ローカルおよびグローバルの粒度の細かい情報は、セルフアテンション層内の統合された混合スケール情報を通じて抽出されます。アテンション マップを予測して、統一されたローカルとグローバルのハイブリッド スケール情報を使用して地域の重要性を特定します。
2. アテンション マップを使用すると、セルフガイド アテンションをシミュレートし、顕著な領域を自動的に特定し、顕著な領域で詳細な情報を抽出できるようにしながら、二次領域で粗い情報を抽出して計算コストを削減できます。
3. 最先端のモデルと比較して、分類、ターゲット検出、セグメンテーションのタスクが大幅に向上しています。
以下、sgformer_sの構成をもとにコード部分を分解し、コードに基づいて論文を説明します。
特定のパラメータ構成は、以下の表で確認できます。
1.1、概要
SG-Former の全体的なプロセスを図 3 に示します。SG-Former は、以前の CNN および Transformer モデルと同じパッチ埋め込み層と 4 次ピラミッド アーキテクチャを共有します。
まず、画像は入力レベルのパッチ埋め込み層を通じて 4 分の 1 にダウンサンプリングされます。2 つのステージの間には 2x レートのダウンサンプリング層があります。したがって、第 3 段階の特徴マップのサイズは です。最後のステージを除く各ステージには、2 種類のブロックの繰り返しで構成される Transformer ブロックがあります。
-
混合スケール変圧器ブロック
-
自己誘導変圧器ブロック。
混合スケールのセルフ アテンションは、混合スケールのオブジェクトと複数粒度の情報を抽出して、地域の重要性をガイドします。混合スケール変圧器ブロックの重要度情報に従って顕著な領域の粒度を維持しながら、セルフガイドセルフアテンションモデルのグローバル情報。
まず、SG フォーマーの全体的なコード構造を見てみましょう (コード例では一部のパラメーターの初期化操作が省略されています)。
class SGFormer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.num_stages = num_stages
self.num_patches = img_size//4
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
for i in range(num_stages):
if i == 0:
patch_embed = Head(embed_dims[0]) #
else:
patch_embed = PatchMerging(dim=embed_dims[i - 1],
out_dim=embed_dims[i])
block = nn.ModuleList([Block(
dim=embed_dims[i], mask=True if (j%2==1 and i<num_stages-1) else False, num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer,
sr_ratio=sr_ratios[i], linear=linear)
for j in range(depths[i])])
norm = norm_layer(embed_dims[i])
cur += depths[i]
setattr(self, f"patch_embed{i + 1}", patch_embed)
setattr(self, f"block{i + 1}", block)
setattr(self, f"norm{i + 1}", norm)
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches*self.num_patches, embed_dims[0])) # fixed sin-cos embedding
# classification head
self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def forward_features(self, x):
B = x.shape[0]
mask=None
for i in range(self.num_stages):
patch_embed = getattr(self, f"patch_embed{i + 1}")
block = getattr(self, f"block{i + 1}")
norm = getattr(self, f"norm{i + 1}")
x, H, W = patch_embed(x) # [N 3136 64] #[N 784 128] #[N 196 256] #[N 49 512]
if i==0:
x+=self.pos_embed # [1 3136 64]
for blk in block:
x, mask = blk(x, H, W, mask)
x = norm(x) # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
if i != self.num_stages - 1:
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() # [N 64 56 56] [N 128 28 28] [N 256 14 14]
return x.mean(dim=1)
def forward(self, x):
x = self.forward_features(x) # entry
x = self.head(x)
return x
画像は最初にパッチ埋め込みを通過します(コードの先頭)
class Head(nn.Module):
def __init__(self, n):
super(Head, self).__init__()
self.conv = nn.Sequential(
Conv2d_BN(3, n, 3, 2, 1),
nn.GELU(),
Conv2d_BN(n, n, 3, 1, 1),
nn.GELU(),
Conv2d_BN(n, n, 3, 2, 1),
)
self.norm = nn.LayerNorm(n)
self.apply(self._init_weights)
def forward(self, x):
x = self.conv(x)
_, _, H, W = x.shape # [N 64 56 56]
x = x.flatten(2)#.transpose(1, 2) # [N 64 3136]
x = x.transpose(1, 2)
x = self.norm(x) # [N 3136 64]
return x, H,W
実際, 3 つの 2d 畳み込みと BN の後, 2 つの GELU 活性化関数が中央に挿入されます. 2 つの連続した 2d 畳み込みのダウンサンプリングのストライドは 2 です. 入力 X=[N,3,224,224] は 4 回ダウンサンプリングされます. [N ,64,56,56] を実行し、次元変換を実行して LayerNorm を追加します。
最初の層は、パッチ埋め込みの出力に pos 埋め込み (形状は [1,3136,64]) を追加します。
中央のダウン サンプリングの埋め込みは、PatchMerging によって完了します。
class PatchMerging(nn.Module):
def __init__(self, dim, out_dim):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.act = nn.GELU()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, 2, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
# x B C H W
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
return x, H, W
主にダウンサンプリング機能として機能します。
次は、この記事の重要なモジュールである Transformer ブロックです。これは 2 種類のブロックの繰り返しで構成されます。
-
混合スケール変圧器ブロック
-
自己誘導変圧器ブロック。
基本的なトランス ブロックを含むこれら 2 つのモジュールは、ブロック関数によって構築されます。
class Block(nn.Module):
def __init__(self, dim, mask, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, mask,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def forward(self, x, H, W, mask):
x_, mask = self.attn(self.norm1(x), H, W, mask) # x[N 3136 64] mask[[N 3136],[N 3136]]
x = x + self.drop_path(x_)
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x, mask
ブロック機能で最も重要なことは注意です
1.2、 トランスブロック
2 つのセルフ アテンション メカニズムを備えた 2 種類の変圧器ブロックは、それに応じて設計されています。これら 2 つの Transformer ブロックはアテンション レイヤーのみが異なり、その他はすべて同じです。
図 3 に示すように、最初の 3 つのステージは、私たちが提案するハイブリッド スケールまたはセルフガイド変圧器ブロックを使用してカスタマイズされます。一方、最後のステージでは、以前の変圧器に基づく標準の変圧器ブロックが使用されます。最初の 3 ステージ (N1、N2、N3) の変圧器ブロックの数は偶数ですが、最後のステージ (N4) は偶数でも奇数でもよいことに注意してください。
1.3、ハイブリッドスケールの注意
混合スケールの注意には 2 つの目的があります。
-
Swin Transformer でのウィンドウ処理よりも多くの計算コストを費やすことなく、混合スケールでグローバルで詳細な情報を抽出します。
-
自発的な注意を重視する
図 5 に示すように、入力特徴 X はクエリ (Q)、キー (K)、および値 (V) に投影されます。次に、マルチヘッドセルフアテンションは、H 個の独立したヘッドを使用します。通常、これらの H 個の独立したヘッドは同じローカル エリア内で動作を実行するため、ヘッドの多様性が欠けています。
対照的に、この記事では、H の頭部を h のグループに均等に分割し、これらの h のグループに混合スケールと複数の受容野の注意を注入します。頭は全体的な注意を行い、頭の半分は局所的な注意を行います)。番目のグループに属する 番目のヘッドで、scale (ここで) を使用して、{K, V} の各トークンを 1 つのトークンにマージします。次に、{Q,K,V} をウィンドウに分割します (swin と同じ)。{K,V} のウィンドウ サイズは M に設定され、すべてのグループにわたって一定のままです。{Q} と {K,V} のウィンドウ サイズを {K,V} のトークンに合わせるために、 {Q} のウィンドウ サイズが {K,V } のウィンドウ サイズの倍数になるように選択されます。
これは、各 トークンを 1 つのトークンにマージすることを意味し、これはステップ畳み込みによって実現されます。特殊なケースは、これが 1 に等しい場合、トークンのマージは実行されず、{Q, K, V} のウィンドウが同じサイズになることです。
ここで、 はウィンドウ サイズによるウィンドウ分割を表します。アテンションマップです。特殊な場合があります。 に等しい場合、ウィンドウ セグメンテーションは必要ありません。{K, V} のすべてのトークンの後には {Q} が続き、それによってグローバルな情報抽出が実現されます。
トークンの重要性は、すべてのトークンと現在のトークンの積の合計として考慮されます。
ここで、 S は、すべてを合計することによって得られる最終的なアテンション マップであり、大域的で詳細な情報を提供するための混合スケールのガイダンスに使用されます。コードのこの部分に対応します
# global
q1 = self.q1(x).reshape(B, N, self.num_heads//2, C // self.num_heads).permute(0, 2, 1, 3) # [N 1 3136 32] # [N 2 784 32] # [N 4 196 32]
x_ = x.permute(0, 2, 1).reshape(B, C, H, W) # [N 64 56 56] # [N 128 28 28] # [N 256 14 14]
x_1 = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) # [N 49 64] sr --> conv2d(64, 64) # [N 49 128] sr --> conv2d(128, 128) # [N 49 256] sr --> conv2d(256, 256)
x_1 = self.act(self.norm(x_1)) # [N 49 64] # [N 49 128] # [N 49 256]
kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4) # [2 N 1 49 32] # [2 N 2 49 32] # [2 N 4 49 32]
k1, v1 = kv1[0], kv1[1] #B head N C [N 1 49 32] [N 2 49 32] [N 4 49 32]
attn1 = (q1 @ k1.transpose(-2, -1)) * self.scale #B head Nq Nkv [N 1 3136 49] [N 2 784 49] [N 4 196 49]
attn1 = attn1.softmax(dim=-1)
attn1 = self.attn_drop(attn1)
x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2) # [N 3136 32] [N 784 64] [N 196 128]
global_mask_value = torch.mean(attn1.detach().mean(1), dim=1) # B Nk #max ? mean ? # [N 49]
global_mask_value = F.interpolate(global_mask_value.view(B,1,H//self.sr_ratio,W//self.sr_ratio),
(H, W), mode='nearest')[:, 0] # [N 56 56] [N 28 28] [N 14 14]
# local
q2 = self.q2(x).reshape(B, N, self.num_heads // 2, C // self.num_heads).permute(0, 2, 1, 3) #B head N C # [N 1 3136 32] [N 2 784 32] [N 4 196 32]
kv2 = self.kv2(x_.reshape(B, C, -1).permute(0, 2, 1)).reshape(B, -1, 2, self.num_heads // 2,
C // self.num_heads).permute(2, 0, 3, 1, 4)# [2 N 1 3136 32] [2 N 2 784 32] [2 N 4 196 32]
k2, v2 = kv2[0], kv2[1] # [N 1 3136 32] [N 2 784 32] [N 4 196 32]
q_window = 7
window_size= 7
q2, k2, v2 = window_partition(q2, q_window, H, W), window_partition(k2, window_size, H, W), \
window_partition(v2, window_size, H, W) # [N*64,49,32] [N*32,49,32] [N*16 49 32]
attn2 = (q2 @ k2.transpose(-2, -1)) * self.scale # [N*64 49 49] [N*32 49 49] [N*16 49 49]
# (B*numheads*num_windows, window_size*window_size, window_size*window_size)
attn2 = attn2.softmax(dim=-1)
attn2 = self.attn_drop(attn2)
x2 = (attn2 @ v2) # B*numheads*num_windows, window_size*window_size, C .transpose(1, 2).reshape(B, N, C) # [N*64 49 32] [N*32 49 32] [N*16 49 32]
x2 = window_reverse(x2, q_window, H, W, self.num_heads // 2) # [N 3136 32] [N 784 64] [N 196 128]
local_mask_value = torch.mean(attn2.detach().view(B, self.num_heads//2, H//window_size*W//window_size, window_size*window_size, window_size*window_size).mean(1), dim=2) #[N 64 49]
local_mask_value = local_mask_value.view(B, H // window_size, W // window_size, window_size, window_size) # [N 8 8 7 7]
local_mask_value=local_mask_value.permute(0, 1, 3, 2, 4).contiguous().view(B, H, W) # [N 56 56] [N 28 28] [N 14 14]
# mask B H W
x = torch.cat([x1, x2], dim=-1) # [N 3136 64] [N 784 128] [N 196 256]
x = self.proj(x+lepe) # linear(64,64) # linear(128,128) # linear(256,256)
x = self.proj_drop(x)
# cal mask
mask = local_mask_value+global_mask_value # [N 56 56] [N 28 28] [N 14 14]
mask_1 = mask.view(B, H * W) # [N 3136] [N 784] [N 196]
mask_2 = mask.permute(0, 2, 1).reshape(B, H * W) # [N 3136] [N 784] [N 196]
mask = [mask_1, mask_2]
1.4、自発的な注意
セルフ アテンション モデルは広範囲の情報をモデル化できますが、その高い計算コストとメモリ消費量はシーケンスの長さの 2 乗に比例するため、セグメンテーションやセグメンテーションなどのさまざまなコンピュータ ビジョン タスクにおける大きなサイズの機能への使用は制限されます。検出、図の適用性。最近の研究では、複数のトークンを 1 つにマージすることでシーケンスの長さを短縮することが示唆されています。ただし、この集約では、異なるトークン間の固有の重要性の違いを無視して、各トークンを同等に扱います。この集約には 2 つの問題があります。
-
顕著な領域では、情報が失われたり、無関係な情報が混在したりする可能性があります
-
セカンダリ領域またはバックグラウンド領域では、単純なセマンティクスには多数のトークンが冗長ですが、多くの計算が必要です
この観察に触発されて、私たちはトークンを集約するためのガイドとして重要度を使用する自己誘導型注意を提案します。言い換えれば、顕著な領域では、きめの細かい情報を取得するためにより多くのトークンが保持されますが、二次領域では、自己注意の全体的なビューを維持し、同時に計算コストを削減するために、より少ないトークンが保持されます。
図 4 に示すように、「自己誘導型」とは、Swin のウィンドウ アテンション、CSWin のクロスシェイプ アテンション、CSWin の静的空間削減などの人為的に導入された事前知識ではなく、Transformer 自体がトレーニング中に計算コスト削減戦略を決定することを意味します。 PVT。
入力特徴マップは、まずクエリ (Q)、キー (K)、および値 (V) に射影されます。次に、H 個の独立したセルフ アテンション ヘッドが並行してセルフ アテンションを計算します。Self-Attention 後の特徴マップのサイズを変えずに計算コストを削減するために、Q の長さは固定されていますが、K と V のトークンを集約するために重要度に基づく集約モジュール (IAM) が使用されます。
IAM の目標は、顕著な領域ではより少ないトークンを 1 つに集約し (つまり、より多くのトークンを保持し)、背景領域ではより多くのトークンを 1 つに集約する (つまり、より少ないトークンを保持する) ことです。式 (1) では、アテンション マップには複数の粒度の領域重要度情報が含まれています。
アテンション マップの値を昇順に並べ替え、S を n 個のサブ領域に均等に分割します。したがって、と はそれぞれ最も重要な領域と二次的な領域です。同時に、すべてのトークンを にグループ化します。式(1)では、重要度の異なるエリアの集約率を、各サブエリアごとに集約率を持たせて表現しており、重要なサブエリアほど集約率が小さくなるようにしている。さまざまな段階の具体的な値を表 1 に示します。したがって、IAM は、各グループの異なる集約率を接続することによって、各グループのグループ化された入力特徴のトークンを再配布します。
ここで、 は集約関数であり、入力次元 r と出力次元 1 の全結合層を通じて実装します。のトークンの数は、 のトークンの数を のトークンの数で割ったものに等しくなります。コードのこの部分に対応します
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # [N 2 3136 32] [N 4 784 32] [N 8 196 32]
# mask [local_mask global_mask] local_mask [value index] value [B, H, W]
# use mask to fuse
mask_1, mask_2 = mask # [[N 3136],[N 3136]] [[N 784],[N 784]] [[N 196],[N 196]]
mask_sort1, mask_sort_index1 = torch.sort(mask_1, dim=1)
mask_sort2, mask_sort_index2 = torch.sort(mask_2, dim=1)
if self.sr_ratio == 8:
token1, token2, token3 = H * W // (14 * 14), H * W // 56, H * W // 28 # [16 56 112]
token1, token2, token3 = token1 // 4, token2 // 2, token3 // 4 # [4 28 28]
elif self.sr_ratio == 4:
token1, token2, token3 = H * W // 49, H * W // 14, H * W // 7 # [16 56 112]
token1, token2, token3 = token1 // 4, token2 // 2, token3 // 4 # [4 28 28]
elif self.sr_ratio == 2:
token1, token2 = H * W // 2, H * W // 1 # [98 196]
token1, token2 = token1 // 2, token2 // 2 # [49 98]
if self.sr_ratio==4 or self.sr_ratio==8:
p1 = torch.gather(x, 1, mask_sort_index1[:, :H * W // 4].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C # [N 784 64] 根据mask中的index对x[:, :H * W // 4(784)]进行重新排序 [N 196 128]
p2 = torch.gather(x, 1, mask_sort_index1[:, H * W // 4:H * W // 4 * 3].unsqueeze(-1).repeat(1, 1, C)) # [N 1568 64] [N 392 128]
p3 = torch.gather(x, 1, mask_sort_index1[:, H * W // 4 * 3:].unsqueeze(-1).repeat(1, 1, C)) # [N 784 64] [N 196 128]
seq1 = torch.cat([self.f1(p1.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1), # [N 64 4 196] # linear(196,1) # [N 128 4 49] # linear(49,1) 次要
self.f2(p2.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1), # [N 64 28 56] # linear(56,1) # [N 128 28 14] # linear(14,1)
self.f3(p3.permute(0, 2, 1).reshape(B, C, token3, -1)).squeeze(-1)], dim=-1).permute(0,2,1) # B N C # [N 64 28 28] # linear(28,1) # [N 128 28 7] # linear(7,1) 最重要
# seq1 [N 60 64] # seq1 [N 60 128]
x_ = x.view(B, H, W, C).permute(0, 2, 1, 3).reshape(B, H * W, C) # [N 3136 64] [N 784 128]
p1_ = torch.gather(x_, 1, mask_sort_index2[:, :H * W // 4].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C [N 784 64] [N 196 128]
p2_ = torch.gather(x_, 1, mask_sort_index2[:, H * W // 4:H * W // 4 * 3].unsqueeze(-1).repeat(1, 1, C)) # [N 1568 64] [N 392 128]
p3_ = torch.gather(x_, 1, mask_sort_index2[:, H * W // 4 * 3:].unsqueeze(-1).repeat(1, 1, C)) # [N 784 64] [N 196 128]
seq2 = torch.cat([self.f1(p1_.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1),
self.f2(p2_.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1),
self.f3(p3_.permute(0, 2, 1).reshape(B, C, token3, -1)).squeeze(-1)], dim=-1).permute(0,2,1) # B N C # seq2 [N 60 64] seq2 [N 60 128]
elif self.sr_ratio==2:
p1 = torch.gather(x, 1, mask_sort_index1[:, :H * W // 2].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C [N 98 256]
p2 = torch.gather(x, 1, mask_sort_index1[:, H * W // 2:].unsqueeze(-1).repeat(1, 1, C)) # [N 98 256]
seq1 = torch.cat([self.f1(p1.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1), # [N 256 49 2] # linear(2,1)
self.f2(p2.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1)], dim=-1).permute(0, 2, 1) # B N C # [N 256 98 1] # linear(1,1)
# seq1 [N 147 256]
x_ = x.view(B, H, W, C).permute(0, 2, 1, 3).reshape(B, H * W, C)
p1_ = torch.gather(x_, 1, mask_sort_index2[:, :H * W // 2].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C [N 98 256]
p2_ = torch.gather(x_, 1, mask_sort_index2[:, H * W // 2:].unsqueeze(-1).repeat(1, 1, C)) # [N 98 256]
seq2 = torch.cat([self.f1(p1_.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1),
self.f2(p2_.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1)], dim=-1).permute(0, 2, 1) # B N C
# seq2 [N 147 256]
kv1 = self.kv1(seq1).reshape(B, -1, 2, self.num_heads // 2, C // self.num_heads).permute(2, 0, 3, 1, 4) # kv B heads N C # [2 N 1 60 32] # [2 N 2 60 32] # [2 N 4 147 32]
kv2 = self.kv2(seq2).reshape(B, -1, 2, self.num_heads // 2, C // self.num_heads).permute(2, 0, 3, 1, 4) # [2 N 1 60 32] # [2 N 2 60 32] # [2 N 4 147 32]
kv = torch.cat([kv1, kv2], dim=2) # [2 N 2 60 32] # [2 N 4 60 32] # [2 N 8 147 32]
k, v = kv[0], kv[1] # [N 2 60 32] # [N 4 60 32] # [N 8 147 32]
attn = (q @ k.transpose(-2, -1)) * self.scale # [N 2 3136 60] # [N 4 784 60] [N 8 196 147]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # [N 3136 64] # [N 784 128] [N 196 256]
x = self.proj(x+lepe)
x = self.proj_drop(x)
mask=None
2 つの部分を結合します。
class Attention(nn.Module):
def __init__(self, dim, mask, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.sr_ratio=sr_ratio
if sr_ratio>1:
if mask:
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv1 = nn.Linear(dim, dim, bias=qkv_bias)
self.kv2 = nn.Linear(dim, dim, bias=qkv_bias)
if self.sr_ratio==8:
f1, f2, f3 = 14*14, 56, 28
elif self.sr_ratio==4:
f1, f2, f3 = 49, 14, 7
elif self.sr_ratio==2:
f1, f2, f3 = 2, 1, None
self.f1 = nn.Linear(f1, 1)
self.f2 = nn.Linear(f2, 1)
if f3 is not None:
self.f3 = nn.Linear(f3, 1)
else:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
self.act = nn.GELU()
self.q1 = nn.Linear(dim, dim//2, bias=qkv_bias)
self.kv1 = nn.Linear(dim, dim, bias=qkv_bias)
self.q2 = nn.Linear(dim, dim // 2, bias=qkv_bias)
self.kv2 = nn.Linear(dim, dim, bias=qkv_bias)
else:
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.lepe_linear = nn.Linear(dim, dim)
self.lepe_conv = local_conv(dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.linear = linear
self.apply(self._init_weights)
def forward(self, x, H, W, mask):
B, N, C = x.shape
lepe = self.lepe_conv(
self.lepe_linear(x).transpose(1, 2).view(B, C, H, W)).view(B, C, -1).transpose(-1, -2) # [N 3136 64] #[N 784 128] #[N 196 256] #[N 49 512]
if self.sr_ratio > 1:
if mask is None:
# global
q1 = self.q1(x).reshape(B, N, self.num_heads//2, C // self.num_heads).permute(0, 2, 1, 3) # [N 1 3136 32] # [N 2 784 32] # [N 4 196 32]
x_ = x.permute(0, 2, 1).reshape(B, C, H, W) # [N 64 56 56] # [N 128 28 28] # [N 256 14 14]
x_1 = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) # [N 49 64] sr --> conv2d(64, 64) # [N 49 128] sr --> conv2d(128, 128) # [N 49 256] sr --> conv2d(256, 256)
x_1 = self.act(self.norm(x_1)) # [N 49 64] # [N 49 128] # [N 49 256]
kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4) # [2 N 1 49 32] # [2 N 2 49 32] # [2 N 4 49 32]
k1, v1 = kv1[0], kv1[1] #B head N C [N 1 49 32] [N 2 49 32] [N 4 49 32]
attn1 = (q1 @ k1.transpose(-2, -1)) * self.scale #B head Nq Nkv [N 1 3136 49] [N 2 784 49] [N 4 196 49]
attn1 = attn1.softmax(dim=-1)
attn1 = self.attn_drop(attn1)
x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2) # [N 3136 32] [N 784 64] [N 196 128]
global_mask_value = torch.mean(attn1.detach().mean(1), dim=1) # B Nk #max ? mean ? # [N 49]
global_mask_value = F.interpolate(global_mask_value.view(B,1,H//self.sr_ratio,W//self.sr_ratio),
(H, W), mode='nearest')[:, 0] # [N 56 56] [N 28 28] [N 14 14]
# local
q2 = self.q2(x).reshape(B, N, self.num_heads // 2, C // self.num_heads).permute(0, 2, 1, 3) #B head N C # [N 1 3136 32] [N 2 784 32] [N 4 196 32]
kv2 = self.kv2(x_.reshape(B, C, -1).permute(0, 2, 1)).reshape(B, -1, 2, self.num_heads // 2,
C // self.num_heads).permute(2, 0, 3, 1, 4)# [2 N 1 3136 32] [2 N 2 784 32] [2 N 4 196 32]
k2, v2 = kv2[0], kv2[1] # [N 1 3136 32] [N 2 784 32] [N 4 196 32]
q_window = 7
window_size= 7
q2, k2, v2 = window_partition(q2, q_window, H, W), window_partition(k2, window_size, H, W), \
window_partition(v2, window_size, H, W) # [N*64,49,32] [N*32,49,32] [N*16 49 32]
attn2 = (q2 @ k2.transpose(-2, -1)) * self.scale # [N*64 49 49] [N*32 49 49] [N*16 49 49]
# (B*numheads*num_windows, window_size*window_size, window_size*window_size)
attn2 = attn2.softmax(dim=-1)
attn2 = self.attn_drop(attn2)
x2 = (attn2 @ v2) # B*numheads*num_windows, window_size*window_size, C .transpose(1, 2).reshape(B, N, C) # [N*64 49 32] [N*32 49 32] [N*16 49 32]
x2 = window_reverse(x2, q_window, H, W, self.num_heads // 2) # [N 3136 32] [N 784 64] [N 196 128]
local_mask_value = torch.mean(attn2.detach().view(B, self.num_heads//2, H//window_size*W//window_size, window_size*window_size, window_size*window_size).mean(1), dim=2) #[N 64 49]
local_mask_value = local_mask_value.view(B, H // window_size, W // window_size, window_size, window_size) # [N 8 8 7 7]
local_mask_value=local_mask_value.permute(0, 1, 3, 2, 4).contiguous().view(B, H, W) # [N 56 56] [N 28 28] [N 14 14]
# mask B H W
x = torch.cat([x1, x2], dim=-1) # [N 3136 64] [N 784 128] [N 196 256]
x = self.proj(x+lepe) # linear(64,64) # linear(128,128) # linear(256,256)
x = self.proj_drop(x)
# cal mask
mask = local_mask_value+global_mask_value # [N 56 56] [N 28 28] [N 14 14]
mask_1 = mask.view(B, H * W) # [N 3136] [N 784] [N 196]
mask_2 = mask.permute(0, 2, 1).reshape(B, H * W) # [N 3136] [N 784] [N 196]
mask = [mask_1, mask_2]
else:
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # [N 2 3136 32] [N 4 784 32] [N 8 196 32]
# mask [local_mask global_mask] local_mask [value index] value [B, H, W]
# use mask to fuse
mask_1, mask_2 = mask # [[N 3136],[N 3136]] [[N 784],[N 784]] [[N 196],[N 196]]
mask_sort1, mask_sort_index1 = torch.sort(mask_1, dim=1)
mask_sort2, mask_sort_index2 = torch.sort(mask_2, dim=1)
if self.sr_ratio == 8:
token1, token2, token3 = H * W // (14 * 14), H * W // 56, H * W // 28 # [16 56 112]
token1, token2, token3 = token1 // 4, token2 // 2, token3 // 4 # [4 28 28]
elif self.sr_ratio == 4:
token1, token2, token3 = H * W // 49, H * W // 14, H * W // 7 # [16 56 112]
token1, token2, token3 = token1 // 4, token2 // 2, token3 // 4 # [4 28 28]
elif self.sr_ratio == 2:
token1, token2 = H * W // 2, H * W // 1 # [98 196]
token1, token2 = token1 // 2, token2 // 2 # [49 98]
if self.sr_ratio==4 or self.sr_ratio==8:
p1 = torch.gather(x, 1, mask_sort_index1[:, :H * W // 4].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C # [N 784 64] 根据mask中的index对x[:, :H * W // 4(784)]进行重新排序 [N 196 128]
p2 = torch.gather(x, 1, mask_sort_index1[:, H * W // 4:H * W // 4 * 3].unsqueeze(-1).repeat(1, 1, C)) # [N 1568 64] [N 392 128]
p3 = torch.gather(x, 1, mask_sort_index1[:, H * W // 4 * 3:].unsqueeze(-1).repeat(1, 1, C)) # [N 784 64] [N 196 128]
seq1 = torch.cat([self.f1(p1.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1), # [N 64 4 196] # linear(196,1) # [N 128 4 49] # linear(49,1) 次要
self.f2(p2.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1), # [N 64 28 56] # linear(56,1) # [N 128 28 14] # linear(14,1)
self.f3(p3.permute(0, 2, 1).reshape(B, C, token3, -1)).squeeze(-1)], dim=-1).permute(0,2,1) # B N C # [N 64 28 28] # linear(28,1) # [N 128 28 7] # linear(7,1) 最重要
# seq1 [N 60 64] # seq1 [N 60 128]
x_ = x.view(B, H, W, C).permute(0, 2, 1, 3).reshape(B, H * W, C) # [N 3136 64] [N 784 128]
p1_ = torch.gather(x_, 1, mask_sort_index2[:, :H * W // 4].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C [N 784 64] [N 196 128]
p2_ = torch.gather(x_, 1, mask_sort_index2[:, H * W // 4:H * W // 4 * 3].unsqueeze(-1).repeat(1, 1, C)) # [N 1568 64] [N 392 128]
p3_ = torch.gather(x_, 1, mask_sort_index2[:, H * W // 4 * 3:].unsqueeze(-1).repeat(1, 1, C)) # [N 784 64] [N 196 128]
seq2 = torch.cat([self.f1(p1_.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1),
self.f2(p2_.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1),
self.f3(p3_.permute(0, 2, 1).reshape(B, C, token3, -1)).squeeze(-1)], dim=-1).permute(0,2,1) # B N C # seq2 [N 60 64] seq2 [N 60 128]
elif self.sr_ratio==2:
p1 = torch.gather(x, 1, mask_sort_index1[:, :H * W // 2].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C [N 98 256]
p2 = torch.gather(x, 1, mask_sort_index1[:, H * W // 2:].unsqueeze(-1).repeat(1, 1, C)) # [N 98 256]
seq1 = torch.cat([self.f1(p1.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1), # [N 256 49 2] # linear(2,1)
self.f2(p2.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1)], dim=-1).permute(0, 2, 1) # B N C # [N 256 98 1] # linear(1,1)
# seq1 [N 147 256]
x_ = x.view(B, H, W, C).permute(0, 2, 1, 3).reshape(B, H * W, C)
p1_ = torch.gather(x_, 1, mask_sort_index2[:, :H * W // 2].unsqueeze(-1).repeat(1, 1, C)) # B, N//4, C [N 98 256]
p2_ = torch.gather(x_, 1, mask_sort_index2[:, H * W // 2:].unsqueeze(-1).repeat(1, 1, C)) # [N 98 256]
seq2 = torch.cat([self.f1(p1_.permute(0, 2, 1).reshape(B, C, token1, -1)).squeeze(-1),
self.f2(p2_.permute(0, 2, 1).reshape(B, C, token2, -1)).squeeze(-1)], dim=-1).permute(0, 2, 1) # B N C
# seq2 [N 147 256]
kv1 = self.kv1(seq1).reshape(B, -1, 2, self.num_heads // 2, C // self.num_heads).permute(2, 0, 3, 1, 4) # kv B heads N C # [2 N 1 60 32] # [2 N 2 60 32] # [2 N 4 147 32]
kv2 = self.kv2(seq2).reshape(B, -1, 2, self.num_heads // 2, C // self.num_heads).permute(2, 0, 3, 1, 4) # [2 N 1 60 32] # [2 N 2 60 32] # [2 N 4 147 32]
kv = torch.cat([kv1, kv2], dim=2) # [2 N 2 60 32] # [2 N 4 60 32] # [2 N 8 147 32]
k, v = kv[0], kv[1] # [N 2 60 32] # [N 4 60 32] # [N 8 147 32]
attn = (q @ k.transpose(-2, -1)) * self.scale # [N 2 3136 60] # [N 4 784 60] [N 8 196 147]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # [N 3136 64] # [N 784 128] [N 196 256]
x = self.proj(x+lepe)
x = self.proj_drop(x)
mask=None
else:
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # [N 16 49 32]
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) #[2 N 16 49 32]
k, v = kv[0], kv[1] # [N 16 49 32]
attn = (q @ k.transpose(-2, -1)) * self.scale # [N 16 49 49]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # [N 49 512]
x = self.proj(x+lepe) # linear(512,512)
x = self.proj_drop(x)
mask=None
return x, mask
スプリットウィンドウとSwinトランスは同じものです。
分類におけるパフォーマンス:
同様のパラメータの下で、SG-Former は競合他社よりもパフォーマンスが大幅に優れています。具体的には、ベースモデルは Swin-B 1.6 よりも優れています。以前の最先端の CSWin と比較して、SG-Former-S/M/B はそれぞれ +0.4、+0.3、+0.4 のパフォーマンス向上を達成しました。
筆者もさまざまなタスクでそのパフォーマンスを測定しましたが、すべてが現在の最先端モデルよりも優れていました。