Paper: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Previous article: ViT model - PyTorch implementation
Features of Swin Transformer:
Compared with ViT:
① It uses progressively larger downsampling factors to build hierarchical feature maps, which suits detection and segmentation tasks;
② It introduces W-MSA (Window Multi-Head Self-Attention) and SW-MSA (Shifted Window Multi-Head Self-Attention), which reduce the amount of computation.
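The saving from window attention can be put in numbers. The paper gives the cost of global MSA on an h x w feature map with C channels as 4hwC² + 2(hw)²C, and the cost of W-MSA with M x M windows as 4hwC² + 2M²hwC. The sketch below evaluates both; the function names are made up for this example, and the values h = w = 56, C = 96, M = 7 are assumed to be the stage-1 settings of Swin-T for a 224x224 input.

```python
def msa_flops(h: int, w: int, c: int) -> int:
    """Approximate cost of global multi-head self-attention on an h x w map."""
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c


def wmsa_flops(h: int, w: int, c: int, m: int) -> int:
    """Approximate cost of window attention with m x m windows."""
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c


# Assumed stage-1 values of Swin-T on a 224x224 image.
h = w = 56
c, m = 96, 7
global_cost = msa_flops(h, w, c)
window_cost = wmsa_flops(h, w, c, m)
print(global_cost, window_cost, round(global_cost / window_cost, 1))
```

The global cost grows quadratically with the number of patches hw, while the window cost grows linearly once M is fixed.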
W-MSA and SW-MSA:
W-MSA divides the feature map into windows and computes attention only among the patches inside each window. This reduces the amount of computation, but windows can no longer exchange information with each other.
The authors therefore also propose SW-MSA, i.e. shifted W-MSA. The grid that originally divided the windows is shifted right and down by window_size // 2, and the resulting fragments are cyclically moved and stitched together into full-size windows, so the number of windows stays the same as in W-MSA. A mask is generated at the same time: for pairs of tokens (patches) that come from regions not adjacent in the original feature map the mask value is -100, and the mask is added to the attention scores ($QK^T$) so that those entries approach 0 after softmax and non-adjacent regions do not interfere. After attention, the stitched windows are split again and the shift is reversed to restore the original feature map.
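To make the shift and the mask concrete, here is a toy sketch in plain Python on a 4x4 feature map with window size 2 and shift 1 (illustrative values, not the 7 and 3 used by the model). The helper names roll, region_labels and window_masks are made up for this example; roll is written to follow the index convention of torch.roll on a square grid.

```python
WINDOW, SHIFT, SIZE = 2, 1, 4  # toy values chosen for readability


def roll(grid, shift):
    """Cyclically move rows and columns by `shift` (negative means up / left)."""
    n = len(grid)
    return [[grid[(r - shift) % n][(c - shift) % n] for c in range(n)] for r in range(n)]


def region_labels(size, window, shift):
    """Give every position of the shifted map the id of the region it came from."""
    bands = [range(0, size - window), range(size - window, size - shift), range(size - shift, size)]
    labels = [[0] * size for _ in range(size)]
    count = 0
    for rows in bands:
        for cols in bands:
            for r in rows:
                for c in cols:
                    labels[r][c] = count
            count += 1
    return labels


def window_masks(labels, window):
    """One (window*window) x (window*window) mask per window: 0 for same region, -100 otherwise."""
    size = len(labels)
    masks = []
    for top in range(0, size, window):
        for left in range(0, size, window):
            ids = [labels[top + r][left + c] for r in range(window) for c in range(window)]
            masks.append([[0.0 if a == b else -100.0 for b in ids] for a in ids])
    return masks


grid = [[r * SIZE + c for c in range(SIZE)] for r in range(SIZE)]
shifted = roll(grid, -SHIFT)    # cyclic shift applied before window attention
restored = roll(shifted, SHIFT)  # reverse shift applied after attention
labels = region_labels(SIZE, WINDOW, SHIFT)
masks = window_masks(labels, WINDOW)
for row in labels:
    print(row)
```

The top-left window contains a single region, so its mask is all zeros. The bottom-right window is stitched from four different regions, so each token there can only attend to itself.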
Relative Position Bias:
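In the attention computation a bias term B is added to the scaled scores: $\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}(QK^T/\sqrt{d} + B)V$. B is looked up from a trainable matrix (the relative position bias matrix), and the lookup index is the relative position index of each pair of tokens (patches) within a window (after some transformations).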
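The "transformations" behind the index can be followed on a small case. The sketch below is a plain-Python illustration (the function name is made up for this example): for every query/key pair in a window it takes the row and column offsets, shifts both to start at 0 by adding window - 1, and then combines them into one number by multiplying the row offset by 2 * window - 1.

```python
def relative_position_index(window: int) -> list:
    """Index into a table of (2 * window - 1) ** 2 biases for every (query, key) pair."""
    coords = [(r, c) for r in range(window) for c in range(window)]
    span = 2 * window - 1  # number of distinct offsets along one axis
    return [
        [(qr - kr + window - 1) * span + (qc - kc + window - 1) for (kr, kc) in coords]
        for (qr, qc) in coords
    ]


index = relative_position_index(2)
for row in index:
    print(row)
# [4, 3, 1, 0]
# [5, 4, 2, 1]
# [7, 6, 4, 3]
# [8, 7, 5, 4]
```

For a 2x2 window there are 3 x 3 = 9 possible offsets, so the indices run from 0 to 8, and every token has the same index (4) for its offset to itself.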
The structure of Swin Transformer:
① Patch Partition: Split the input image into 4x4 patches and stack the pixels of each patch along the channel dimension. The code implements this with Conv2d;
② Linear Embedding: Project the features of each patch to C channels; in the code the H and W dimensions are then flattened into a single token dimension;
(in the code, Patch Partition and Linear Embedding are both implemented by a single PatchEmbedding module)
③ Swin Transformer Block: Blocks appear in pairs. The overall structure is the same as the Transformer Block in ViT, except that MSA is replaced by W-MSA or SW-MSA: odd-numbered blocks use W-MSA and even-numbered blocks use SW-MSA (the two alternate);
④ Patch Merging: A downsampling step, similar to Focus. Each Patch Merging first halves the height H and width W and quadruples the channels C, then a Linear layer halves C, so that C ends up twice the original;
Stage: Linear Embedding (stage 1) or Patch Merging (later stages) + L * Swin Transformer Block
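The feature-map size after each stage follows directly from the steps above. As a small sketch (the function name and default arguments are assumptions for illustration: a 224x224 input, 4x4 patches, C = 96 and four stages):

```python
def stage_shapes(image_size: int = 224, patch_size: int = 4, embed_dim: int = 96, num_stages: int = 4):
    """Return (H, W, C) of the feature map produced by each stage."""
    side = image_size // patch_size  # Patch Partition: one token per patch
    dim = embed_dim                  # Linear Embedding: C channels
    shapes = []
    for stage in range(num_stages):
        if stage > 0:                # Patch Merging: halve H and W, double C
            side //= 2
            dim *= 2
        shapes.append((side, side, dim))
    return shapes


print(stage_shapes())
# [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```

With a window size of 7, every one of these resolutions divides evenly into windows, and the last stage consists of a single 7x7 window.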
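Swin Transformer models of different sizes (C is the stage-1 channel count; values as used in the code below):
Swin-T: C = 96, blocks per stage = (2, 2, 6, 2), heads per stage = (3, 6, 12, 24)
Swin-S: C = 96, blocks per stage = (2, 2, 18, 2), heads per stage = (3, 6, 12, 24)
Swin-B: C = 128, blocks per stage = (2, 2, 18, 2), heads per stage = (4, 8, 16, 32)
Swin-L: C = 192, blocks per stage = (2, 2, 18, 2), heads per stage = (6, 12, 24, 48)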
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)
class PatchEmbedding(nn.Module):  # Patch Partition + Linear Embedding
    def __init__(self, patch_size=4, in_channels=3, emb_dim=96):
        super(PatchEmbedding, self).__init__()
        # A 4x4 convolution with stride 4 implements Patch Partition and the channel projection
        self.conv = nn.Conv2d(in_channels, emb_dim, patch_size, patch_size)

    def forward(self, x):
        # (B,C,H,W)
        x = self.conv(x)
        _, _, H, W = x.shape
        x = rearrange(x, "B C H W -> B (H W) C")  # flatten H and W into one token dimension
        return x, H, W


class MLP(nn.Module):  # MLP
    def __init__(self, in_dim, hidden_dim=None, drop_ratio=0.):
        super(MLP, self).__init__()
        if hidden_dim is None:
            hidden_dim = in_dim * 4  # hidden_dim defaults to 4x in_dim
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, in_dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(drop_ratio)

    def forward(self, x):
        # Linear + GELU + Dropout + Linear + Dropout
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
class WindowMultiHeadSelfAttention(nn.Module):  # W-MSA / SW-MSA
    def __init__(self, dim, window_size, num_heads,
                 attn_drop_ratio=0., proj_drop_ratio=0.):
        super(WindowMultiHeadSelfAttention, self).__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)
        # Trainable relative position bias matrix; B is looked up from it with the relative position index
        self.relative_position_bias_matrix = nn.Parameter(torch.zeros((2 * window_size - 1) ** 2, num_heads))
        # register_buffer saves relative_position_index with model.state_dict()
        # and moves it to the GPU together with model.cuda()
        self.register_buffer("relative_position_index", self._get_relative_position_index())

    def _get_relative_position_index(self):  # build the relative position index
        coords = torch.flatten(
            torch.stack(
                torch.meshgrid([torch.arange(self.window_size), torch.arange(self.window_size)], indexing="ij"), dim=0
            ), 1
        )
        relative_coords = coords[:, :, None] - coords[:, None, :]
        relative_coords += self.window_size - 1
        relative_coords[0, :, :] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(0)
        return relative_position_index.view(-1)

    def forward(self, x, mask=None):
        qkv = self.qkv(x)
        qkv = rearrange(qkv, "B P (C H d) -> C B H P d", C=3, H=self.num_heads, d=self.head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        k = rearrange(k, "B H P d -> B H d P")
        # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k) + B) V
        attn = torch.matmul(q, k) * self.head_dim ** -0.5  # Q K^T / sqrt(d_k)
        bias = self.relative_position_bias_matrix[self.relative_position_index]
        bias = rearrange(bias, "(P1 P2) H -> 1 H P1 P2", P1=self.window_size ** 2, P2=self.window_size ** 2)
        attn += bias  # Q K^T / sqrt(d_k) + B
        if mask is not None:
            # Add the mask so that scores between tokens that are not adjacent in the
            # original feature map get -100 and approach 0 after softmax
            attn = rearrange(attn, "(B NW) H P1 P2 -> B NW H P1 P2", NW=mask.shape[0])
            mask = rearrange(mask, "NW P1 P2 -> 1 NW 1 P1 P2")
            attn += mask
            attn = rearrange(attn, "B NW H P1 P2 -> (B NW) H P1 P2")
        # Normalize over the key dimension; without dim, softmax would pick the head dimension here
        attn = F.softmax(attn, dim=-1)  # softmax(Q K^T / sqrt(d_k) + B)
        attn = self.attn_drop(attn)
        x = torch.matmul(attn, v)  # softmax(Q K^T / sqrt(d_k) + B) V
        x = rearrange(x, "B H P d -> B P (H d)")
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class SwinTransformerBlock(nn.Module):  # Swin Transformer Block
    def __init__(self, dim, num_heads, window_size=7, shift=True,
                 attn_drop_ratio=0., proj_drop_ratio=0., drop_path_ratio=0.):
        super(SwinTransformerBlock, self).__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = window_size // 2 if shift else 0  # shift_size is 0 when no shift is applied
        self.layernorm1 = nn.LayerNorm(dim)
        self.attn = WindowMultiHeadSelfAttention(dim, self.window_size, self.num_heads,
                                                 attn_drop_ratio=attn_drop_ratio,
                                                 proj_drop_ratio=proj_drop_ratio)
        self.droppath = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.layernorm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def _create_mask(self, H, W, device):  # build the attention mask for SW-MSA
        mask = torch.zeros((1, 1, H, W), device=device)
        slices = (slice(0, -self.window_size),
                  slice(-self.window_size, -self.shift_size),
                  slice(-self.shift_size, None))
        # Label each region of the shifted feature map with its own id
        count = 0
        for h in slices:
            for w in slices:
                mask[:, :, h, w] = count
                count += 1
        mask = rearrange(mask, "1 1 (H Hs) (W Ws) -> (H W) (Hs Ws)", Hs=self.window_size, Ws=self.window_size)
        attn_mask = mask.unsqueeze(1) - mask.unsqueeze(2)
        # Token pairs from different regions (not adjacent in the original map) get -100
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.))
        # Token pairs from the same region get 0
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.))
        return attn_mask

    def forward(self, input: tuple):
        x, H, W = input
        shortcut = x
        x = self.layernorm1(x)
        x = rearrange(x, "B (H W) C -> B C H W", H=H, W=W)
        if self.shift_size > 0:  # shift x cyclically and build the matching mask
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
            mask = self._create_mask(H, W, device=x.device)
        else:
            mask = None
        num_windows = (x.shape[2] // self.window_size, x.shape[3] // self.window_size)
        # Split into windows: every window becomes one sequence of Hs * Ws tokens
        x = rearrange(x, "B C (H Hs) (W Ws) -> (B H W) (Hs Ws) C", Hs=self.window_size, Ws=self.window_size)
        x = self.attn(x, mask)
        # Merge the windows back into a feature map
        x = rearrange(x, "(B H W) (Hs Ws) C -> B C (H Hs) (W Ws)", H=num_windows[0], W=num_windows[1],
                      Hs=self.window_size, Ws=self.window_size)
        if self.shift_size > 0:  # move the shifted x back to its original position
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))
        x = rearrange(x, "B C H W -> B (H W) C", H=H, W=W)
        x = shortcut + self.droppath(x)  # residual connection
        shortcut = x
        x = self.layernorm2(x)
        x = self.mlp(x)
        x = shortcut + self.droppath(x)  # residual connection
        return x, H, W
class PatchMerging(nn.Module):  # Patch Merging
    def __init__(self, dim):
        super(PatchMerging, self).__init__()
        self.layernorm = nn.LayerNorm(4 * dim)
        self.linear = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, input: tuple):
        # (B,L,C) --> (B,C,H,W) --> (B,4*C,H/2,W/2) --> (B,L/4,4*C) --> (B,L/4,2*C)
        x, H, W = input
        x = rearrange(x, "B (H W) C -> B C H W", H=H, W=W)
        x = torch.cat([x[:, :, 0::2, 0::2], x[:, :, 1::2, 0::2], x[:, :, 0::2, 1::2], x[:, :, 1::2, 1::2]], dim=1)
        _, _, H, W = x.shape
        x = rearrange(x, "B C H W -> B (H W) C")
        x = self.layernorm(x)
        x = self.linear(x)
        return x, H, W


class SwinHead(nn.Module):  # Swin Head, the head for classification
    def __init__(self, dim, num_classes):
        super(SwinHead, self).__init__()
        self.layernorm = nn.LayerNorm(dim)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.mlphead = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.layernorm(x)
        x = rearrange(x, "B L C -> B C L")
        x = self.avgpool(x)
        # Squeeze only the pooled dimension so that a batch of size 1 keeps its batch axis
        return self.mlphead(x.squeeze(-1))
class SwinTransformer(nn.Module):  # Swin Transformer
    def __init__(self, dims=(96, 192, 384, 768), num_blocks=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 num_classes=1000,
                 pos_drop_ratio=0., attn_drop_ratio=0., proj_drop_ratio=0., drop_path_ratio_max=0.1):
        super(SwinTransformer, self).__init__()
        self.patchembedding = PatchEmbedding(emb_dim=dims[0])
        self.pos_drop = nn.Dropout(pos_drop_ratio)
        # The drop path ratio increases linearly from 0 to drop_path_ratio_max over all blocks
        drop_path_ratio = [i.item() for i in torch.linspace(0, drop_path_ratio_max, sum(num_blocks))]
        self.blocks1 = nn.Sequential(
            *[SwinTransformerBlock(dims[0], num_heads[0], shift=(i % 2 != 0),
                                   attn_drop_ratio=attn_drop_ratio,
                                   proj_drop_ratio=proj_drop_ratio,
                                   drop_path_ratio=drop_path_ratio[i + sum(num_blocks[:0])])
              for i in range(num_blocks[0])]
        )
        self.patchmerging2 = PatchMerging(dims[0])
        self.blocks2 = nn.Sequential(
            *[SwinTransformerBlock(dims[1], num_heads[1], shift=(i % 2 != 0),
                                   attn_drop_ratio=attn_drop_ratio,
                                   proj_drop_ratio=proj_drop_ratio,
                                   drop_path_ratio=drop_path_ratio[i + sum(num_blocks[:1])])
              for i in range(num_blocks[1])]
        )
        self.patchmerging3 = PatchMerging(dims[1])
        self.blocks3 = nn.Sequential(
            *[SwinTransformerBlock(dims[2], num_heads[2], shift=(i % 2 != 0),
                                   attn_drop_ratio=attn_drop_ratio,
                                   proj_drop_ratio=proj_drop_ratio,
                                   drop_path_ratio=drop_path_ratio[i + sum(num_blocks[:2])])
              for i in range(num_blocks[2])]
        )
        self.patchmerging4 = PatchMerging(dims[2])
        self.blocks4 = nn.Sequential(
            *[SwinTransformerBlock(dims[3], num_heads[3], shift=(i % 2 != 0),
                                   attn_drop_ratio=attn_drop_ratio,
                                   proj_drop_ratio=proj_drop_ratio,
                                   drop_path_ratio=drop_path_ratio[i + sum(num_blocks[:3])])
              for i in range(num_blocks[3])]
        )
        self.head = SwinHead(dims[-1], num_classes)

    def forward(self, x):
        # Patch Partition + Stage1
        x, H, W = self.patchembedding(x)
        x = self.pos_drop(x)
        x, H, W = self.blocks1((x, H, W))
        # Stage2
        x, H, W = self.patchmerging2((x, H, W))
        x, H, W = self.blocks2((x, H, W))
        # Stage3
        x, H, W = self.patchmerging3((x, H, W))
        x, H, W = self.blocks3((x, H, W))
        # Stage4
        x, H, W = self.patchmerging4((x, H, W))
        x, H, W = self.blocks4((x, H, W))
        return self.head(x)
def Swin_T(num_classes=1000):  # Swin Tiny
    return SwinTransformer(dims=(96, 192, 384, 768),
                           num_blocks=(2, 2, 6, 2),
                           num_heads=(3, 6, 12, 24),
                           num_classes=num_classes)


def Swin_S(num_classes=1000):  # Swin Small
    return SwinTransformer(dims=(96, 192, 384, 768),
                           num_blocks=(2, 2, 18, 2),
                           num_heads=(3, 6, 12, 24),
                           num_classes=num_classes)


def Swin_B(num_classes=1000):  # Swin Base
    return SwinTransformer(dims=(128, 256, 512, 1024),
                           num_blocks=(2, 2, 18, 2),
                           num_heads=(4, 8, 16, 32),
                           num_classes=num_classes)


def Swin_L(num_classes=1000):  # Swin Large
    return SwinTransformer(dims=(192, 384, 768, 1536),
                           num_blocks=(2, 2, 18, 2),
                           num_heads=(6, 12, 24, 48),
                           num_classes=num_classes)
if __name__ == "__main__":
    cuda = torch.cuda.is_available()
    images = torch.randn(8, 3, 224, 224)
    swin_t = Swin_T()
    swin_s = Swin_S()
    swin_b = Swin_B()
    swin_l = Swin_L()
    if cuda:
        images = images.cuda()
        swin_t.cuda()
        swin_s.cuda()
        swin_b.cuda()
        swin_l.cuda()
    print(swin_t(images).shape)
    print(swin_s(images).shape)
    print(swin_b(images).shape)
    print(swin_l(images).shape)