Cswin は、上図で十字型の局所注意の使用を提案しました。VIT モデルにおける局所自己注意受容野の成長がさらに制限されるという問題を解決するために、Cswin は局所強化位置も提案しました。 Swin を超えるコーディングモジュール、複数のタスクに効果的SOTA (当時の SOTA は SG フォーマーに追い越されました。興味のある方は見てくださいSG 元)。
モデルの全体的な構造は上に示したとおりで、トークン埋め込みレイヤー と 4 で構成されます。 ステージブロックは積み重ねられ、各ステージブロックはで接続されます。 /a >暗さ 設計と同様に、各ダウンサンプリング後のR50 をダウンサンプリングするために使用されます。一般的なfeaturemap レイヤーは、conv
研究動機:
- ベース世界的な注目トランスフォーマーはうまく機能しますが、計算上特徴マップ サイズの複雑さと二乗(H==W ケース)比例。
- 地元の注目に基づくトランスフォーマーは、それぞれを制限します。 トークンの受容野の相互作用は受容野の成長を遅らせ、多数のブロックを積み重ねる必要があります< /span> グローバルな自己注意を達成します。
解決:
- 提案十字型ウィンドウのセルフ アテンションアテンション ヘッドをグループ化し、水平方向と垂直方向を並行して計算するメカニズム自己注意 を使用すると、より少ない計算量でより良い結果を得ることができます。
- ローカル位置情報をより適切に処理し、任意形状の入力をサポートできる、ローカル拡張位置エンコーディング (LePE) が提案されます。
1.1 畳み込みトークンの埋め込み
埋め込みに畳み込みを使用する: 計算量を減らすために、この記事では 7x7 畳み込みカーネルとストライド 4 の畳み込みを直接使用して入力を直接埋め込み、最後の次元を Layernorm します。
self.stage1_conv_embed = nn.Sequential(
nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
nn.LayerNorm(embed_dim)
)
1.2 十字型ウィンドウのセルフアテンション
具体的には、元の特徴マップが であると仮定すると、横方向の自己注意を計算するために、まず に分割されます。 >水平バーのデータ (実際のコードは最初に垂直方向に処理されます)。 は水平バーの幅です。これら 4 つの異なるステージで異なる値を採用した実験結果は、値のセット [1、2、7、7] が速度と精度のより良いバランスを達成することを示しています。
各ストリップ フィーチャについて、Transformer を使用してそのフィーチャを取得し、最後にこれらを変換します このヘッドの形状は、これらのフィーチャをつなぎ合わせることで得られます。それが番目の頭に属すると仮定すると、横方向の自己注意の計算方法は次のようになります。
垂直方向のセルフ アテンションの計算方法は、 の幅の垂直バーを使用することを除けば、V-Attendance と H-Attendance の計算方法は似ています。
最終的に、このブロックの出力は次のように表されます。
CSWin セルフアテンションの計算複雑性分析:
ハイレゾ入力の場合、初期ではHとWがCより大きく、後期ではCより小さいため、swは初期では小さく、後期では大きくなります。つまり、swを調整することで、後の段階で各トークンの注目領域を効果的に拡大することができます。 224×224入力の中間特徴マップのサイズをswで割り切れるようにするため、4段階のswはデフォルトで1、2、7、7に設定されています。
def img2windows(img, H_sp, W_sp):
"""
img: B C H W
"""
B, C, H, W = img.shape
img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) # [N*56*1 56 32] [N*56*1 56 32] / [N*14*1 56 64] [N*14*1 56 64] / [N*2*1 98 128] [N*2*1 98 128] / [N*1*1 49 512]
return img_perm
def windows2img(img_splits_hw, H_sp, W_sp, H, W):
"""
img_splits_hw: B' H W C
"""
B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) # [N*56*1 56 32]->[N 1 56 56 1 32] [N*56*1 56 32]->[N 56 1 1 56 32] / [N*14*1 56 64]->[N 1 14 28 2 64] [N*14*1 56 64]->[N 14 1 2 28 64] / [N*2*1 98 128]->[N 1 2 14 7 128] [N*2*1 98 128]->[N 2 1 7 14 128] / [N*1*1 49 512]->[N 1 1 7 7 512]
img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # [N 56 56 32] [N 28 28 64] [N 14 14 128] [N 7 7 512]
return img
class LePEAttention(nn.Module):
def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0.,
qk_scale=None):
super().__init__()
self.dim = dim
self.dim_out = dim_out or dim
self.resolution = resolution
self.split_size = split_size
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
if idx == -1:
H_sp, W_sp = self.resolution, self.resolution
elif idx == 0:
H_sp, W_sp = self.resolution, self.split_size
elif idx == 1:
W_sp, H_sp = self.resolution, self.split_size
else:
print("ERROR MODE", idx)
exit(0)
self.H_sp = H_sp
self.W_sp = W_sp
stride = 1
self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
self.attn_drop = nn.Dropout(attn_drop)
def im2cswin(self, x):
B, N, C = x.shape
H = W = int(np.sqrt(N))
x = x.transpose(-2, -1).contiguous().view(B, C, H, W) # [B, N, C] -> [B, C, N] -> [B, C, H, W]
x = img2windows(x, self.H_sp, self.W_sp) # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]
x = x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1,
3).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
return x
def get_lepe(self, x, func):
B, N, C = x.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
H = W = int(np.sqrt(N))
x = x.transpose(-2, -1).contiguous().view(B, C, H, W) # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]
H_sp, W_sp = self.H_sp, self.W_sp
x = x.view(B, C, H // H_sp, H_sp, W // W_sp,
W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,
W_sp) ### B', C, H', W' # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = func(
x) ### B', C, H', W' # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
return x, lepe
def forward(self, qkv):
"""
x: B L C
"""
q, k, v = qkv[0], qkv[1], qkv[2] # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
### Img2Window
H = W = self.resolution # 56 28 14 7
B, L, C = q.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
assert L == H * W, "flatten img_tokens has wrong size"
q = self.im2cswin(q) # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
k = self.im2cswin(k) # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
v, lepe = self.get_lepe(v, self.get_v)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N
attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
attn = self.attn_drop(attn)
x = (attn @ v) + lepe
x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp,
C) # B head N N @ B head N C # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]
### Window2Img
x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C) # B H' W' C
return x # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
コード部分は実際には Swin に似ており、Swin のウィンドウの仕組みを理解し、ヘッド グループ化を追加すれば、基本的に論文の考え方をすぐに理解できます。
1.3 ローカル拡張位置エンコーディング(LePE)
Transformer は入力順序に依存しないため、位置エンコーディングを追加する必要があります。上図の左側はViTモデルのPEです。絶対位置エンコーディングまたは条件付き位置エンコーディングを使用します。埋め込み時にトークンとともにトランスフォーマーに入力されるだけです。中央はSwin、CrossFormerなどのモデルのPEです。相対位置エンコーディング偏差を使用する導入により、トークン グラフの重みがアテンションと一緒に計算され、APE よりも優れた柔軟性と効果をもたらします。
この記事で提案されている LePE は RPE より直接的です。位置情報を線形投影に適用します。また、RPE は頭部の形でバイアスを導入するのに対し、LepE はチャネルごとのバイアスであることにも注意してください。位置情報の埋め込みとして機能する可能性がさらに高まる可能性があります。つまり、位置コードは値ベクトルに直接追加されます。位置コードが であるとします。位置コード と位置コードを加算することによって追加されます。 は乗算によって完成します。次に、 追加された位置エンコーディングと、セルフ アテンションによって重み付けされた ユニットがショートカットを通じて加算されます。式は次のとおりです。
著者はここで、入力要素の場合、その近くの要素が最も重要な位置情報を提供するという仮定に基づいています。したがって、 V に対して深さの畳み込みを実行し、それをソフトマックスの後の結果に追加します。式は次のとおりです。
このようにして、LePE は、任意の入力解像度を入力として受け取るダウンストリーム タスクに適しています。
def get_lepe(self, x, func):
# func -> self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim)
B, N, C = x.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
H = W = int(np.sqrt(N))
x = x.transpose(-2, -1).contiguous().view(B, C, H, W) # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]
H_sp, W_sp = self.H_sp, self.W_sp
x = x.view(B, C, H // H_sp, H_sp, W // W_sp,
W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,
W_sp) ### B', C, H', W' # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = func(
x) ### B', C, H', W' # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
return x, lepe
1.4 CSWin トランスフォーマー ブロック
CSWin Transformer Block の構造は図のようになりますが、最大の特徴は 2 つのショートカットを追加し、LN を使用して正規化していることです。
ネットワーク構造の構成:
は、 番目の Transformer ブロックまたは各ステージの畳み込み層の出力です。
CSwinのブロックはLayerNormと十字ウィンドウ自注目を行ってショートカットを接続する部分と、LayerNormとMLPを行う部分の2つに分かれており、SwinやTwinsに比べてブロックの計算量が大幅に削減されています。 swin と Twins には 2 つのアテンション + 2 つの MLP が 1 つのブロックにスタックされています)。
class CSWinBlock(nn.Module):
def __init__(self, dim, reso, num_heads,
split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
last_stage=False):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.patches_resolution = reso
self.split_size = split_size
self.mlp_ratio = mlp_ratio
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm1 = norm_layer(dim)
if self.patches_resolution == split_size:
last_stage = True
if last_stage:
self.branch_num = 1
else:
self.branch_num = 2
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(drop)
if last_stage:
self.attns = nn.ModuleList([
LePEAttention(
dim, resolution=self.patches_resolution, idx = -1,
split_size=split_size, num_heads=num_heads, dim_out=dim,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
for i in range(self.branch_num)])
else:
self.attns = nn.ModuleList([
LePEAttention(
dim//2, resolution=self.patches_resolution, idx = i,
split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
for i in range(self.branch_num)])
mlp_hidden_dim = int(dim * mlp_ratio)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H = W = self.patches_resolution # 56
B, L, C = x.shape # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
assert L == H * W, "flatten img_tokens has wrong size"
img = self.norm1(x)
qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # [3 N 3136 64] [3 N 784 128] [3 N 196 256] [3 N 49 512]
if self.branch_num == 2:
x1 = self.attns[0](qkv[:,:,:,:C//2]) # qkv[3 N 3136 32]->x1[N 3136 32] qkv[3 N 784 128]->x1[N 784 64] qkv[3 N 196 256]->x1[N 196 128]
x2 = self.attns[1](qkv[:,:,:,C//2:]) # qkv[3 N 3136 32]->x2[N 3136 32] qkv[3 N 784 128]->x1[N 784 64] qkv[3 N 196 256]->x1[N 196 128]
attened_x = torch.cat([x1,x2], dim=2)
else:
attened_x = self.attns[0](qkv) # [3 N 49 512]->[N 49 512]
attened_x = self.proj(attened_x)
x = x + self.drop_path(attened_x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
cswin は、同様のネットワーク パラメーターと計算量を持つモデルにおいて、分類タスクやさまざまな下流タスクで SOTA を達成しました。
検出:
分割: