Part1 Vision Transformer
1 Network structure
The Transformer was originally designed for NLP; the ViT (Vision Transformer) model shows that it can also achieve good results in the CV field.
In the original paper, the authors compared three models: ViT, the "pure" Transformer model; the ResNet network; and the Hybrid model, which combines a traditional CNN with a Transformer. They found that when the number of training iterations is large enough, the accuracy of the ViT model exceeds that of the Hybrid model.
The ViT (Vision Transformer) model architecture is as follows:
The model first divides the image into patches of size 16*16. Each patch is fed to the Embedding layer, which turns it into a vector called a token. A [class] token used for classification is then added to the sequence, and a Position Embedding that marks the position is added to each token. Finally, the tokens with position information are fed into the Transformer Encoder, and the classification result is obtained through the MLPHead.
1.1 Linear Projection of Flattened Patches (Embedding layer)
This layer can be implemented directly with a convolutional layer. The Transformer Encoder expects a token sequence, that is, a two-dimensional matrix [num_token, token_dim]. After the patches are embedded, the [class] token is concatenated to the sequence and the Position Embedding is added. The concatenation is a cat operation, and the Position Embedding is simply added element-wise.
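The shape bookkeeping of this layer can be sketched in plain Python (no PyTorch). The function name `vit_token_shape` is hypothetical and not part of the source code; the default values are the ones used by the code in section 2 (224*224 input, 16*16 patches, embedding dimension 768).

```python
def vit_token_shape(img_size=224, patch_size=16, embed_dim=768, num_extra_tokens=1):
    """Return (num_token, token_dim) of the sequence fed to the Transformer Encoder."""
    grid = img_size // patch_size                # patches per side
    num_patches = grid * grid                    # one token per patch
    num_token = num_patches + num_extra_tokens   # cat the [class] token
    # Adding the Position Embedding is element-wise, so the shape is unchanged.
    return num_token, embed_dim

print(vit_token_shape())  # (197, 768)
```

This matches the `[B,196,768]` and `[B,197,768]` shape comments in `forward_features` further down.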
Experiments show that accuracy drops significantly without Position Embedding, but the type of Position Embedding used has little effect on accuracy. The source code therefore uses the default one-dimensional position embedding, which has fewer parameters.
The similarity between the learned position embeddings is shown below; each embedding is highly similar to the others in the same row and the same column:
1.2 Transformer Encoder
The layer structure and MLP structure are as follows:
In each Encoder Block, the Embedded Patches first pass through Layer Norm and multi-head attention, followed by Dropout (DropPath in the code below) and a residual connection. The result then passes through a second Layer Norm and the MLP, again followed by Dropout and a residual connection. The Encoder Block is stacked L times.
1.3 MLPHead (final layer structure for classification)
When training on ImageNet21K or larger datasets, the head consists of Linear + tanh activation + Linear. When transferring to ImageNet1K or your own dataset, it is a single Linear layer.
1.4 Various types of ViT
There are three types, namely Base, Large and Huge, and the specifications are as follows:
- Layers: the number of times the Encoder Block is stacked in the Transformer Encoder
- Hidden Size: the vector length dim of each token after passing through the Embedding layer
- MLP Size: the number of nodes in the first fully connected layer of the MLP module, which is 4 times the Hidden Size
- Heads: the number of heads in the multi-head attention
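As a rough summary, the values below are copied from the factory functions in section 2 (`vit_base_*`, `vit_large_*`, `vit_huge_*`). The dictionary name `VIT_VARIANTS` and the helper `mlp_size` are hypothetical names introduced only for this sketch.

```python
# Layers / Hidden Size / Heads as passed to VisionTransformer(depth, embed_dim, num_heads)
VIT_VARIANTS = {
    "Base":  {"layers": 12, "hidden_size": 768,  "heads": 12},
    "Large": {"layers": 24, "hidden_size": 1024, "heads": 16},
    "Huge":  {"layers": 32, "hidden_size": 1280, "heads": 16},
}

def mlp_size(hidden_size, mlp_ratio=4):
    """MLP Size is mlp_ratio (4 by default) times the Hidden Size."""
    return int(hidden_size * mlp_ratio)

for name, cfg in VIT_VARIANTS.items():
    head_dim = cfg["hidden_size"] // cfg["heads"]  # per-head dimension in the attention
    print(name, cfg["layers"], cfg["hidden_size"], mlp_size(cfg["hidden_size"]),
          cfg["heads"], head_dim)
```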
2 Building a network based on PyTorch
The code is adapted from rwightman's pytorch-image-models implementation, as noted in its docstring
Learning link: ViT
Code link: (colab) ViT
# Vision Transformer
"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
# Stochastic depth (drop path)
def drop_path(x,drop_prob:float=0.,training:bool=False):
    if drop_prob==0. or not training:
        return x
    keep_prob = 1-drop_prob
    shape = (x.shape[0],)+(1,)*(x.ndim-1) # work with diff dim tensors
    random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
    random_tensor.floor_() # binarize
    output = x.div(keep_prob)*random_tensor
    return output

class DropPath(nn.Module):
    def __init__(self,drop_prob=None):
        super(DropPath,self).__init__()
        self.drop_prob = drop_prob

    def forward(self,x):
        return drop_path(x,self.drop_prob,self.training)
# Patch Embedding
class PatchEmbed(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
        super().__init__()
        img_size = (img_size,img_size)
        patch_size = (patch_size,patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0]//patch_size[0],img_size[1]//patch_size[1])
        self.num_patches = self.grid_size[0]*self.grid_size[1]
        self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B,C,H,W = x.shape
        assert H==self.img_size[0] and W==self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # flatten:[B,C,H,W]->[B,C,HW]
        # transpose:[B,C,HW]->[B,HW,C]
        x = self.proj(x).flatten(2).transpose(1,2)
        x = self.norm(x)
        return x
class Attention(nn.Module):
    def __init__(self,dim,num_heads=8,qkv_bias=False,qk_scale=None,attn_drop_ratio=0.,proj_drop_ratio=0.):
        super(Attention,self).__init__()
        self.num_heads = num_heads
        head_dim = dim//num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim,dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        B,N,C = x.shape
        # reorder the dimensions for easier computation:
        # [B,N,3*C]->[B,N,3,num_heads,head_dim]->[3,B,num_heads,N,head_dim]
        qkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
        q,k,v = qkv[0],qkv[1],qkv[2]
        # matrix multiplication, scaled by 1/sqrt(head_dim)
        attn = (q@k.transpose(-2,-1))*self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn@v).transpose(1,2).reshape(B,N,C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Mlp(nn.Module):
    def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features,hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features,out_features)
        self.drop = nn.Dropout(drop)

    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block,self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio>0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim*mlp_ratio)
        self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop_ratio)

    def forward(self,x):
        x = x+self.drop_path(self.attn(self.norm1(x)))
        x = x+self.drop_path(self.mlp(self.norm2(x)))
        return x
class VisionTransformer(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_c=3,num_classes=1000,
                 embed_dim=768,depth=12,num_heads=12,mlp_ratio=4.0,qkv_bias=True,
                 qk_scale=None,representation_size=None,distilled=False,drop_ratio=0.,
                 attn_drop_ratio=0.,drop_path_ratio=0.,embed_layer=PatchEmbed,
                 norm_layer=None,act_layer=None):
        super(VisionTransformer,self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
        dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,
                  qk_scale=qk_scale,drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,
                  drop_path_ratio=dpr[i],norm_layer=norm_layer,act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc",nn.Linear(embed_dim,representation_size)),
                ("act",nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()
        self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim,self.num_classes) if num_classes>0 else nn.Identity()
        nn.init.trunc_normal_(self.pos_embed,std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token,std=0.02)
        nn.init.trunc_normal_(self.cls_token,std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B,C,H,W]->[B,num_patches,embed_dim]
        x = self.patch_embed(x) # [B,196,768]
        # [1,1,768]->[B,1,768]
        cls_token = self.cls_token.expand(x.shape[0],-1,-1)
        if self.dist_token is None:
            x = torch.cat((cls_token,x),dim=1) # [B,197,768]
        else:
            x = torch.cat((cls_token,self.dist_token.expand(x.shape[0],-1,-1),x),dim=1)
        x = self.pos_drop(x+self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:,0])
        else:
            return x[:,0], x[:,1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]),self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                return x,x_dist
            else:
                return (x+x_dist)/2
        else:
            x = self.head(x)
        return x
def _init_vit_weights(m):
    if isinstance(m,nn.Linear):
        nn.init.trunc_normal_(m.weight,std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m,nn.Conv2d):
        nn.init.kaiming_normal_(m.weight,mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m,nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
def vit_base_patch16_224(num_classes:int=1000):
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model

def vit_base_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model

def vit_base_patch32_224(num_classes:int=1000):
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model

def vit_base_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model

def vit_large_patch16_224(num_classes:int=1000):
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=None,
                              num_classes=num_classes)
    return model

def vit_large_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model

def vit_large_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model

def vit_huge_patch14_224_in21k(num_classes:int=21843,has_logits:bool=True):
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes)
    return model
Part2 Swin Transformer
1 Network structure
1.1 Overall Framework
Compared with ViT, Swin Transformer is hierarchical: as the network deepens, the downsampling factor keeps increasing. It also splits the feature map into non-overlapping windows and performs multi-head self-attention within each window, which greatly reduces the amount of computation.
The overall network framework of Swin Transformer is as follows:
For a three-channel image, the Patch Partition operation is performed first, and the result then passes through 4 stages that downsample it step by step. Each subsequent stage downsamples by a further factor of 2, and each time it does so the number of channels doubles. Stage 1 starts with a Linear Embedding layer, while every other stage starts with Patch Merging. The Patch Partition operation divides the image into 4*4 windows and flattens each one; the Linear Embedding layer adjusts the channel dimension and applies Layer Norm along the channel dimension. Both structures can be implemented together with a single convolutional layer.
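The resulting feature-map sizes can be traced with a small plain-Python sketch. The function name `swin_stage_shapes` is hypothetical; the defaults (224*224 input, patch size 4, embedding dimension 96, 4 stages) are the Swin-T settings used by the code in section 2, and odd sizes are rounded up in the same way as the `(H+1)//2` update in `BasicLayer.forward`.

```python
def swin_stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Return the (H, W, C) of the feature map processed by each stage."""
    h = w = img_size // patch_size    # Patch Partition + Linear Embedding
    c = embed_dim
    shapes = [(h, w, c)]
    for _ in range(num_stages - 1):
        # Patch Merging: halve height and width, double the channels
        h, w = (h + 1) // 2, (w + 1) // 2
        c *= 2
        shapes.append((h, w, c))
    return shapes

for i, shape in enumerate(swin_stage_shapes(), start=1):
    print(f"stage{i}: {shape}")
```

The channel count of the last stage equals `embed_dim*2**(num_layers-1)`, which is how `num_features` is computed in the `SwinTransformer` class.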
1.2 Patch Merging
The principle of Patch Merging is shown below. It is a downsampling operation that halves the height and width of the feature map and doubles the number of channels:
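The regrouping step can be illustrated on a tiny feature map stored as nested Python lists. This is only a sketch under simplifying assumptions: the function name `patch_merging` is hypothetical, H and W are assumed even, and the Layer Norm and the Linear layer that reduces 4*C channels to 2*C (both present in the `PatchMerging` class in section 2) are left out.

```python
def patch_merging(x):
    """x: nested list of shape [H][W][C] with even H and W.
    Returns a nested list of shape [H/2][W/2][4*C]."""
    H, W = len(x), len(x[0])
    out = []
    for i in range(0, H, 2):
        row = []
        for j in range(0, W, 2):
            # Concatenate the channels of the 2*2 neighbourhood, in the same
            # order as x0, x1, x2, x3 in PatchMerging.forward.
            row.append(x[i][j] + x[i + 1][j] + x[i][j + 1] + x[i + 1][j + 1])
        out.append(row)
    return out

# 4*4 feature map with a single channel holding the values 0..15
fmap = [[[4 * r + c] for c in range(4)] for r in range(4)]
merged = patch_merging(fmap)
print(len(merged), len(merged[0]), len(merged[0][0]))  # 2 2 4
print(merged[0][0])                                    # [0, 4, 1, 5]
```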
1.3 W-MSA
W-MSA is Window Multi-head Self-Attention. Compared with the standard multi-head self-attention module, it splits the feature map into non-overlapping windows and performs multi-head self-attention within each window, which reduces the amount of computation. At the same time, it prevents information exchange between windows, which makes the receptive field smaller.
The computational cost of the two is as follows, where h and w are the height and width of the feature map, c is the depth of the feature map, and m is the size of each window.
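For reference, the Swin Transformer paper gives the two costs as 4hwc^2 + 2(hw)^2c for MSA and 4hwc^2 + 2m^2hwc for W-MSA. The sketch below simply evaluates these two expressions; the function names are hypothetical, and the example values (a 56*56 feature map with 96 channels and 7*7 windows) are taken from stage 1 of the Swin-T configuration in section 2.

```python
def msa_flops(h, w, c):
    """Cost of global multi-head self-attention: 4hwc^2 + 2(hw)^2c."""
    return 4 * h * w * c ** 2 + 2 * (h * w) ** 2 * c

def wmsa_flops(h, w, c, m):
    """Cost of window attention with m*m windows: 4hwc^2 + 2m^2hwc."""
    return 4 * h * w * c ** 2 + 2 * m ** 2 * h * w * c

h, w, c, m = 56, 56, 96, 7
print("MSA  :", msa_flops(h, w, c))
print("W-MSA:", wmsa_flops(h, w, c, m))
print("ratio:", msa_flops(h, w, c) / wmsa_flops(h, w, c, m))
```

The first term is shared; only the second term changes, from quadratic in hw to linear in hw for a fixed window size m.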
1.4 SW-MSA
SW-MSA is Shifted Window Multi-head Self-Attention. As the schematic diagram below shows, it shifts the windows of W-MSA by a fixed offset, which allows information to be exchanged between different windows:
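In the code in section 2 the offset is applied as a cyclic shift with `torch.roll` and undone after the attention. A minimal plain-Python sketch of that shift on a 2-D grid is shown here; `roll2d` is a hypothetical helper written to mimic `torch.roll` on two dimensions, and the 4*4 grid with a shift of 1 (half of a 2*2 window) is an illustrative choice.

```python
def roll2d(grid, shift_h, shift_w):
    """Cyclically shift a 2-D list, like torch.roll(x, (shift_h, shift_w), dims=(0, 1))."""
    H, W = len(grid), len(grid[0])
    return [[grid[(i - shift_h) % H][(j - shift_w) % W] for j in range(W)]
            for i in range(H)]

window_size = 2
shift_size = window_size // 2          # same rule as BasicLayer: window_size//2
grid = [[4 * r + c for c in range(4)] for r in range(4)]

shifted = roll2d(grid, -shift_size, -shift_size)    # before window partition
restored = roll2d(shifted, shift_size, shift_size)  # after window attention

for row in shifted:
    print(row)
```

After the shift, a regular window partition groups together pixels that belonged to different windows in the previous block; the attention mask built in `create_mask` keeps pixels that were wrapped around from attending to each other.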
1.5 Relative Position Bias (relative position offset)
The formula involved is Attention(Q, K, V) = SoftMax(Q K^T / sqrt(d) + B) V, where B is the relative position offset:
The schematic diagram of relative position offset is as follows:
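The index that selects an entry of the bias table for every pair of positions in a window can be sketched in plain Python. The function below is a hypothetical re-implementation of the index arithmetic performed with tensors in `WindowAttention.__init__` in section 2, written with loops for readability.

```python
def relative_position_index(window_size):
    """Return an (M*M) x (M*M) table of indices into a bias table of
    (2M-1)*(2M-1) entries, where M is the window size."""
    M = window_size
    coords = [(r, c) for r in range(M) for c in range(M)]  # row-major positions
    index = []
    for (ri, ci) in coords:
        row = []
        for (rj, cj) in coords:
            dr = ri - rj + (M - 1)   # shift row offset to start from 0
            dc = ci - cj + (M - 1)   # shift column offset to start from 0
            row.append(dr * (2 * M - 1) + dc)  # flatten the 2-D offset to 1-D
        index.append(row)
    return index

for row in relative_position_index(2):
    print(row)
```

For a 2*2 window the table has 4*4 entries with values between 0 and 8, so the learnable bias table needs only (2*2-1)^2 = 9 rows per head.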
1.6 Specific configuration parameters
2 Building a network based on PyTorch
# Swin Transformer
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
- https://arxiv.org/pdf/2103.14030
Code/weights from https://github.com/microsoft/Swin-Transformer
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional
def drop_path_f(x,drop_prob:float=0.,training:bool=False):
    if drop_prob==0. or not training:
        return x
    keep_prob = 1-drop_prob
    shape = (x.shape[0],)+(1,)*(x.ndim-1) # work with diff dim
    random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
    random_tensor.floor_() # binarize
    output = x.div(keep_prob)*random_tensor
    return output

class DropPath(nn.Module):
    def __init__(self,drop_prob=None):
        super(DropPath,self).__init__()
        self.drop_prob = drop_prob

    def forward(self,x):
        return drop_path_f(x,self.drop_prob,self.training)
def window_partition(x,window_size:int):
    # split the feature map into non-overlapping windows of size window_size
    B,H,W,C = x.shape
    x = x.view(B,H//window_size,window_size,W//window_size,window_size,C)
    windows = x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)
    return windows

def window_reverse(windows,window_size:int,H:int,W:int):
    # merge the windows back into a single feature map
    B = int(windows.shape[0]/(H*W/window_size/window_size))
    x = windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)
    x = x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)
    return x
class PatchEmbed(nn.Module):
    def __init__(self,patch_size=4,in_c=3,embed_dim=96,norm_layer=None):
        super().__init__()
        patch_size = (patch_size,patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _,_,H,W = x.shape
        # padding
        pad_input = (H%self.patch_size[0]!=0) or (W%self.patch_size[1]!=0)
        if pad_input:
            x = F.pad(x,(0,self.patch_size[1]-W%self.patch_size[1],
                         0,self.patch_size[0]-H%self.patch_size[0],0,0))
        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1,2)
        x = self.norm(x)
        return x,H,W
class PatchMerging(nn.Module):
    def __init__(self,dim,norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4*dim,2*dim,bias=False)
        self.norm = norm_layer(4*dim)

    def forward(self,x,H,W):
        B,L,C = x.shape
        assert L==H*W,"input feature has wrong size"
        x = x.view(B,H,W,C)
        # padding
        pad_input = (H%2==1) or (W%2==1)
        if pad_input:
            x = F.pad(x,(0,0,0,W%2,0,H%2))
        x0 = x[:,0::2,0::2,:] # [B,H/2,W/2,C]
        x1 = x[:,1::2,0::2,:] # [B,H/2,W/2,C]
        x2 = x[:,0::2,1::2,:] # [B,H/2,W/2,C]
        x3 = x[:,1::2,1::2,:] # [B,H/2,W/2,C]
        x = torch.cat([x0,x1,x2,x3],-1) # [B,H/2,W/2,4*C]
        x = x.view(B,-1,4*C) # [B,H/2*W/2,4*C]
        x = self.norm(x)
        x = self.reduction(x) # [B,H/2*W/2,2*C]
        return x

class Mlp(nn.Module):
    def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features,hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features,out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class WindowAttention(nn.Module):
    def __init__(self,dim,window_size,num_heads,qkv_bias=True,attn_drop=0.,proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim//num_heads
        self.scale = head_dim**-0.5
        # learnable bias table: one row per possible relative offset, one column per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1),num_heads))
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h,coords_w],indexing="ij"))
        coords_flatten = torch.flatten(coords,1)
        relative_coords = coords_flatten[:,:,None]-coords_flatten[:,None,:]
        relative_coords = relative_coords.permute(1,2,0).contiguous()
        relative_coords[:,:,0] += self.window_size[0]-1
        relative_coords[:,:,1] += self.window_size[1]-1
        relative_coords[:,:,0] *= 2*self.window_size[1]-1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index",relative_position_index)
        self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim,dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table,std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,x,mask:Optional[torch.Tensor]=None):
        B_,N,C = x.shape
        qkv = self.qkv(x).reshape(B_,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
        q,k,v = qkv.unbind(0)
        q = q*self.scale
        attn = (q@k.transpose(-2,-1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0]*self.window_size[1],self.window_size[0]*self.window_size[1],-1)
        relative_position_bias = relative_position_bias.permute(2,0,1).contiguous()
        attn = attn+relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_//nW,nW,self.num_heads,N,N)+mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1,self.num_heads,N,N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn@v).transpose(1,2).reshape(B_,N,C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class SwinTransformerBlock(nn.Module):
    def __init__(self,dim,num_heads,window_size=7,shift_size=0,
                 mlp_ratio=4.,qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
                 act_layer=nn.GELU,norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0<=self.shift_size<self.window_size,"shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,window_size=(self.window_size,self.window_size),num_heads=num_heads,
            qkv_bias=qkv_bias,attn_drop=attn_drop,proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim*mlp_ratio)
        self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop)

    def forward(self,x,attn_mask):
        H,W = self.H, self.W
        B,L,C = x.shape
        assert L==H*W,"input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B,H,W,C)
        # pad the feature map to a multiple of the window size
        pad_l = pad_t = 0
        pad_r = (self.window_size-W%self.window_size)%self.window_size
        pad_b = (self.window_size-H%self.window_size)%self.window_size
        x = F.pad(x,(0,0,pad_l,pad_r,pad_t,pad_b))
        _,Hp,Wp,_ = x.shape
        if self.shift_size>0:
            shifted_x = torch.roll(x,shifts=(-self.shift_size,-self.shift_size),dims=(1,2))
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = window_partition(shifted_x,self.window_size)
        x_windows = x_windows.view(-1,self.window_size*self.window_size,C)
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows,mask=attn_mask)
        attn_windows = attn_windows.view(-1,self.window_size,self.window_size,C)
        shifted_x = window_reverse(attn_windows,self.window_size,Hp,Wp)
        if self.shift_size>0:
            x = torch.roll(shifted_x,shifts=(self.shift_size,self.shift_size),dims=(1,2))
        else:
            x = shifted_x
        if pad_r>0 or pad_b>0:
            # remove the padding added earlier
            x = x[:,:H,:W,:].contiguous()
        x = x.view(B,H*W,C)
        # FFN
        x = shortcut+self.drop_path(x)
        x = x+self.drop_path(self.mlp(self.norm2(x)))
        return x
class BasicLayer(nn.Module):
    def __init__(self,dim,depth,num_heads,window_size,mlp_ratio=4.,
                 qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
                 norm_layer=nn.LayerNorm,downsample=None,use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size//2
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i%2==0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path,list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])
        if downsample is not None:
            self.downsample = downsample(dim=dim,norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self,x,H,W):
        # make sure Hp and Wp are multiples of window_size
        Hp = int(np.ceil(H/self.window_size))*self.window_size
        Wp = int(np.ceil(W/self.window_size))*self.window_size
        # same channel layout as the feature map, for the later window_partition
        img_mask = torch.zeros((1,Hp,Wp,1),device=x.device)
        h_slices = (slice(0,-self.window_size),
                    slice(-self.window_size,-self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0,-self.window_size),
                    slice(-self.window_size,-self.shift_size),
                    slice(-self.shift_size,None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:,h,w,:] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask,self.window_size)
        mask_windows = mask_windows.view(-1,self.window_size*self.window_size)
        attn_mask = mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))
        return attn_mask

    def forward(self,x,H,W):
        attn_mask = self.create_mask(x,H,W)
        for blk in self.blocks:
            blk.H,blk.W = H,W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk,x,attn_mask)
            else:
                x = blk(x,attn_mask)
        if self.downsample is not None:
            x = self.downsample(x,H,W)
            H,W = (H+1)//2,(W+1)//2
        return x,H,W
class SwinTransformer(nn.Module):
    def __init__(self,patch_size=4,in_chans=3,num_classes=1000,
                 embed_dim=96,depths=(2,2,6,2),num_heads=(3,6,12,24),
                 window_size=7,mlp_ratio=4.,qkv_bias=True,
                 drop_rate=0.,attn_drop_rate=0.,drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,patch_norm=True,
                 use_checkpoint=False,**kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels of the stage 4 output feature map
        self.num_features = int(embed_dim*2**(self.num_layers-1))
        self.mlp_ratio = mlp_ratio
        # split the image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,in_c=in_chans,embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0,drop_path_rate,sum(depths))]
        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layers = BasicLayer(dim=int(embed_dim*2**i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer+1])],
                                norm_layer=norm_layer,
                                downsample=PatchMerging if (i_layer<self.num_layers-1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self,m):
        if isinstance(m,nn.Linear):
            nn.init.trunc_normal_(m.weight,std=.02)
            if isinstance(m,nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias,0)
        elif isinstance(m,nn.LayerNorm):
            nn.init.constant_(m.bias,0)
            nn.init.constant_(m.weight,1.0)

    def forward(self,x):
        x,H,W = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x,H,W = layer(x,H,W)
        x = self.norm(x)
        x = self.avgpool(x.transpose(1,2))
        x = torch.flatten(x,1)
        x = self.head(x)
        return x
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,6,2),
                            num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
    return model

def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,18,2),
                            num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
    return model

def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
                            num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
    return model

def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
                            num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
    return model

def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
                            num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
    return model

def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
                            num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
    return model

def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=192,depths=(2,2,18,2),
                            num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
    return model

def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=192,depths=(2,2,18,2),
                            num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
    return model
Part3 ConvNeXt
1 Network structure
1.1 Structure used
Macro design
①stage ratio
The number of stacked blocks per stage in ResNet50 is changed from (3,4,6,3) to (3,3,9,3), which matches the 1:1:3:1 stage ratio of Swin-T, and accuracy improves noticeably.
②“patchify” stem
The stem (the original downsampling module) is replaced with a convolutional layer with a kernel size of 4 and a stride of 4. Accuracy improves slightly and FLOPs decrease slightly.
ResNeXt
Compared with ResNet, ResNeXt strikes a better balance between FLOPs and accuracy. Following this idea, the authors also use depthwise (DW) convolution. When the width of the network is then increased, accuracy improves considerably, and FLOPs increase as well.
Inverted bottleneck
The authors observe that the MLP module in the Transformer block closely resembles an inverted residual module, which is narrow at both ends and wide in the middle. The Bottleneck block is therefore replaced with an inverted residual module; accuracy improves slightly and FLOPs decrease.
Large kernel size
The DW convolution is moved up: the order changes from 1*1 convolution->DW convolution->1*1 convolution to DW convolution->1*1 convolution->1*1 convolution, and the DW convolution kernel size is changed from 3*3 to 7*7.
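To see why the large kernel stays cheap in this layout, the parameters of one block can be counted by hand. The sketch below is an assumption-laden illustration rather than the authors' analysis: the function name `convnext_block_params` is hypothetical, and the layer list (7*7 DW convolution, Layer Norm, two 1*1 convolutions implemented as Linear layers with a 4x expansion, and a per-channel scale `gamma`) follows the `Block` class in section 2.

```python
def convnext_block_params(dim, kernel_size=7, expansion=4, layer_scale=True):
    """Count the learnable parameters of one ConvNeXt block, layer by layer."""
    hidden = expansion * dim
    params = {
        # depthwise conv: one k*k filter per channel, plus bias
        "dwconv": dim * kernel_size * kernel_size + dim,
        # LayerNorm: weight and bias
        "norm": 2 * dim,
        # 1*1 conv as Linear(dim, 4*dim): weight plus bias
        "pwconv1": dim * hidden + hidden,
        # 1*1 conv as Linear(4*dim, dim): weight plus bias
        "pwconv2": hidden * dim + dim,
        # per-channel layer scale
        "gamma": dim if layer_scale else 0,
    }
    params["total"] = sum(params.values())
    return params

p7 = convnext_block_params(96, kernel_size=7)
p3 = convnext_block_params(96, kernel_size=3)
print(p7)
print("share of the 7*7 DW conv:", p7["dwconv"] / p7["total"])
print("extra parameters vs 3*3 :", p7["total"] - p3["total"])
```

Because the DW convolution acts on the narrow dim-channel input instead of the 4*dim-channel hidden features, enlarging its kernel adds comparatively few parameters; the two 1*1 convolutions dominate the count.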
Various layer-wise Micro designs
ReLU is replaced with GELU, fewer activation functions are used, fewer normalization layers are used, and BN is replaced with LN, with the aim of speeding up convergence and reducing overfitting. Finally, a separate downsampling layer is used.
1.2 Network Effects
Compared with a Swin Transformer of the same scale, ConvNeXt achieves higher accuracy and processes about 40% more images per second.
1.3 Various versions
Here C is the number of channels of the input feature map of each stage, and B is the number of times the block is repeated in each stage
2 Building a network based on PyTorch
# ConvNeXt
"""
original code from facebook research:
https://github.com/facebookresearch/ConvNeXt
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def drop_path(x,drop_prob:float=0.,training:bool=False):
    if drop_prob==0. or not training:
        return x
    keep_prob = 1-drop_prob
    shape = (x.shape[0],)+(1,)*(x.ndim-1) # work with diff dim
    random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
    random_tensor.floor_()
    output = x.div(keep_prob)*random_tensor
    return output

class DropPath(nn.Module):
    def __init__(self,drop_prob=None):
        super(DropPath,self).__init__()
        self.drop_prob = drop_prob

    def forward(self,x):
        return drop_path(x,self.drop_prob,self.training)

class LayerNorm(nn.Module):
    def __init__(self,normalized_shape,eps=1e-6,data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape),requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(normalized_shape),requires_grad=True)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last","channels_first"]:
            raise ValueError(f"not support data format '{self.data_format}'")
        self.normalized_shape = (normalized_shape,)

    def forward(self,x:torch.Tensor)->torch.Tensor:
        if self.data_format=="channels_last":
            return F.layer_norm(x,self.normalized_shape,self.weight,self.bias,self.eps)
        elif self.data_format=="channels_first":
            # [batch_size,channels,height,width]
            mean = x.mean(1,keepdim=True)
            var = (x-mean).pow(2).mean(1,keepdim=True)
            x = (x-mean)/torch.sqrt(var+self.eps)
            x = self.weight[:,None,None]*x+self.bias[:,None,None]
            return x
class Block(nn.Module):
    def __init__(self,dim,drop_rate=0.,layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim,dim,kernel_size=7,padding=3,groups=dim) # depthwise (DW) convolution
        self.norm = LayerNorm(dim,eps=1e-6,data_format="channels_last")
        self.pwconv1 = nn.Linear(dim,4*dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4*dim,dim)
        self.gamma = nn.Parameter(layer_scale_init_value*torch.ones((dim,)),
                                  requires_grad=True) if layer_scale_init_value>0 else None
        self.drop_path = DropPath(drop_rate) if drop_rate>0. else nn.Identity()
    def forward(self,x:torch.Tensor)->torch.Tensor:
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0,2,3,1) # [N,C,H,W]->[N,H,W,C]
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma*x
        x = x.permute(0,3,1,2) # [N,H,W,C]->[N,C,H,W]
        x = shortcut+self.drop_path(x)
        return x
class ConvNeXt(nn.Module):
    def __init__(self,in_chans:int=3,num_classes:int=1000,depths:list=None,
                 dims:list=None,drop_path_rate:float=0.,
                 layer_scale_init_value:float=1e-6,head_init_scale:float=1.):
        super().__init__()
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(nn.Conv2d(in_chans,dims[0],kernel_size=4,stride=4),
                             LayerNorm(dims[0],eps=1e-6,data_format="channels_first"))
        self.downsample_layers.append(stem)
        # the 3 downsample layers placed before stage2-stage4
        for i in range(3):
            downsample_layer = nn.Sequential(LayerNorm(dims[i],eps=1e-6,data_format="channels_first"),
                                             nn.Conv2d(dims[i],dims[i+1],kernel_size=2,stride=2))
            self.downsample_layers.append(downsample_layer)
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0,drop_path_rate,sum(depths))]
        cur = 0
        # build the blocks stacked in each stage
        for i in range(4):
            stage = nn.Sequential(*[Block(dim=dims[i],drop_rate=dp_rates[cur+j],layer_scale_init_value=layer_scale_init_value)
                                    for j in range(depths[i])]
                                  )
            self.stages.append(stage)
            cur += depths[i]
        self.norm = nn.LayerNorm(dims[-1],eps=1e-6)
        self.head = nn.Linear(dims[-1],num_classes)
        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)
    def _init_weights(self,m):
        if isinstance(m,(nn.Conv2d,nn.Linear)):
            nn.init.trunc_normal_(m.weight,std=0.02) # std=0.02 as in the referenced facebookresearch code
            nn.init.constant_(m.bias,0)
    def forward_features(self,x:torch.Tensor)->torch.Tensor:
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2,-1])) # global average pooling, [N,C,H,W]->[N,C]
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.head(x)
        return x
def convnext_tiny(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
    model = ConvNeXt(depths=[3,3,9,3],dims=[96,192,384,768],num_classes=num_classes)
    return model
def convnext_small(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
    model = ConvNeXt(depths=[3,3,27,3],dims=[96,192,384,768],num_classes=num_classes)
    return model
def convnext_base(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
    model = ConvNeXt(depths=[3,3,27,3],dims=[128,256,512,1024],num_classes=num_classes)
    return model
def convnext_large(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
    model = ConvNeXt(depths=[3,3,27,3],dims=[192,384,768,1536],num_classes=num_classes)
    return model
def convnext_xlarge(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
    model = ConvNeXt(depths=[3,3,27,3],dims=[256,512,1024,2048],num_classes=num_classes)
    return model
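The drop_path function above builds its mask as floor(keep_prob + u), with u drawn uniformly from [0, 1), so each sample's residual branch is kept with probability keep_prob, and the kept branches are divided by keep_prob so that the expected output is unchanged. The following is a minimal pure-Python sketch of that idea only, not the tensor version used in the network; the function name simulate_drop_path, the sample count and the seed are assumptions for illustration.

```python
import math
import random


def simulate_drop_path(drop_prob: float, num_samples: int = 100_000, seed: int = 0):
    """Estimate how often a path is kept and the mean of the rescaled mask."""
    rng = random.Random(seed)
    keep_prob = 1.0 - drop_prob
    kept = 0
    scaled_sum = 0.0
    for _ in range(num_samples):
        # keep_prob + u lies in [keep_prob, keep_prob + 1), so its floor is 0 or 1
        mask = math.floor(keep_prob + rng.random())
        kept += mask
        # kept paths are scaled by 1/keep_prob, dropped paths contribute 0
        scaled_sum += mask / keep_prob
    return kept / num_samples, scaled_sum / num_samples


keep_rate, mean_scaled = simulate_drop_path(drop_prob=0.2)
print(f"fraction of paths kept: {keep_rate:.3f}")
print(f"mean of rescaled mask:  {mean_scaled:.3f}")
```

With drop_prob=0.2 the fraction kept should come out close to 0.8 and the mean of the rescaled mask close to 1, which is why the block output keeps the same expected value during training.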
Part4 personal experience
Natural language processing and computer vision share some similarities, so applying the Transformer to vision is currently a popular research direction, and there is still plenty of room to optimize and modify it. CNNs, however, are more mature and come with well-established companion structures; where necessary, several structures can be combined to achieve better results.