[OUC Deep Learning Introduction] Week 6 Learning Record: Vision Transformer & Swin Transformer & ConvNeXt

Part1 Vision Transformer

1 Network structure

The ViT model shows that the Transformer is not only suitable for the NLP field, but can also achieve good results in the CV field.

In the original paper, the authors compared three kinds of models: ViT, the "pure" Transformer model; the ResNet network; and the Hybrid model, which mixes a traditional CNN with a Transformer. They found that when the amount of training is large enough, the accuracy of the ViT model eventually exceeds that of the Hybrid model.

The ViT (Vision Transformer) model architecture is as follows: 

The model first divides the image into multiple patches, each of size 16*16; each patch is fed into the Embedding layer, which turns it into a vector called a token. A learnable token used for classification (the [class] token) is then prepended, a Position Embedding marking the position is added to every token, and the tokens with position information are fed into the Transformer Encoder; the final classification result is obtained through the MLP Head.

1.1 Linear Projection of Flattened Patches (Embedding layer)

The Embedding layer can be implemented directly with a convolutional layer. Its output is the token sequence, i.e. a two-dimensional matrix [num_token, token_dim]. The [class] token is then spliced onto the sequence with a cat operation, and the Position Embedding is superimposed by simple element-wise addition.
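
A minimal sketch of this Embedding layer (shapes assume a 224*224 input, 16*16 patches and embed_dim=768 as in ViT-Base; the variable names are only for illustration):

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)          # patch embedding via a convolution
cls_token = nn.Parameter(torch.zeros(1, 1, 768))             # learnable [class] token
pos_embed = nn.Parameter(torch.zeros(1, 14*14 + 1, 768))     # learnable 1-D position embedding

img = torch.randn(1, 3, 224, 224)
tokens = proj(img).flatten(2).transpose(1, 2)                # [1, 196, 768]
tokens = torch.cat((cls_token.expand(1, -1, -1), tokens), dim=1)  # splice the [class] token -> [1, 197, 768]
tokens = tokens + pos_embed                                  # superimpose position information by addition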

Experiments show that removing the Position Embedding entirely causes a significant drop in accuracy, while the specific type of Position Embedding used has little effect on accuracy, so the source code defaults to the one-dimensional position encoding, which has fewer parameters.

The similarity map of the finally learned position encodings is shown below; each position's encoding is highly similar to the encodings in the same row and the same column:

1.2 Transformer Encoder

The layer structure and MLP structure are as follows:

Inside an Encoder Block, the Embedded Patches first pass through Layer Norm and then Multi-Head Attention, followed by Dropout/DropPath and a residual connection; a second Layer Norm, the MLP block and another residual connection complete the Encoder Block, which is then stacked L times.

1.3 MLP Head (final layer structure for classification)

When pre-training on ImageNet-21k or larger datasets, the MLP Head is composed of Linear + tanh activation + Linear. When transferring to ImageNet-1k or your own dataset, only a single Linear layer is used.
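
As a small sketch of these two variants (dimensions assume ViT-Base with embed_dim=768; the class counts are only examples):

import torch.nn as nn

# pre-training head (ImageNet-21k and larger): Linear + tanh + Linear
pretrain_head = nn.Sequential(nn.Linear(768, 768), nn.Tanh(), nn.Linear(768, 21843))

# transfer / fine-tuning head (ImageNet-1k or your own dataset): a single Linear layer
finetune_head = nn.Linear(768, 1000)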

1.4 Various types of ViT

There are three sizes, namely Base, Large and Huge. Their specifications (consistent with the factory functions in section 2) are:

  • ViT-Base: Layers 12, Hidden Size 768, MLP Size 3072, Heads 12
  • ViT-Large: Layers 24, Hidden Size 1024, MLP Size 4096, Heads 16
  • ViT-Huge: Layers 32, Hidden Size 1280, MLP Size 5120, Heads 16

The meaning of each item:

  • Layers: the number of times the Encoder Block is repeatedly stacked in the Transformer Encoder
  • Hidden Size: the vector length dim of each token after passing through the Embedding layer
  • MLP Size: the number of nodes in the first fully-connected layer of the MLP module, which is 4 times the Hidden Size
  • Heads: the number of heads in the multi-head attention 

2 Building a network based on PyTorch

The code is adapted from rwightman's pytorch-image-models (timm) implementation

Learning link: ViT

Code link: (colab) ViT

# Vision Transformer

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""

from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


# stochastic depth (drop path) method
def drop_path(x,drop_prob:float=0.,training:bool=False):
  if drop_prob==0. or not training:
    return x
  keep_prob = 1-drop_prob
  shape = (x.shape[0],)+(1,)*(x.ndim-1)  # work with diff dim tensors
  random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
  random_tensor.floor_()  # binarize
  output = x.div(keep_prob)*random_tensor
  return output

class DropPath(nn.Module):
  def __init__(self,drop_prob=None):
    super(DropPath,self).__init__()
    self.drop_prob = drop_prob

  def forward(self,x):
    return drop_path(x,self.drop_prob,self.training)


# Patch Embedding
class PatchEmbed(nn.Module):
  def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
    super().__init__()
    img_size = (img_size,img_size)
    patch_size = (patch_size,patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    self.grid_size = (img_size[0]//patch_size[0],img_size[1]//patch_size[1])
    self.num_patches = self.grid_size[0]*self.grid_size[1]

    self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

  def forward(self, x):
    B,C,H,W = x.shape
    assert H==self.img_size[0] and W==self.img_size[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

    # flatten:[B,C,H,W]->[B,C,HW]
    # transpose:[B,C,HW]->[B,HW,C]
    x = self.proj(x).flatten(2).transpose(1,2)
    x = self.norm(x)
    return x

class Attention(nn.Module):
  def __init__(self,dim,num_heads=8,qkv_bias=False,qk_scale=None,attn_drop_ratio=0.,proj_drop_ratio=0.):
    super(Attention,self).__init__()
    self.num_heads = num_heads
    head_dim = dim//num_heads
    self.scale = qk_scale or head_dim**-0.5
    self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop_ratio)
    self.proj = nn.Linear(dim,dim)
    self.proj_drop = nn.Dropout(proj_drop_ratio)

  def forward(self, x):
    B,N,C = x.shape

    # reshape and permute so that q, k, v can be split per attention head
    qkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
    q,k,v = qkv[0],qkv[1],qkv[2]

    # attention scores: scaled dot-product of q and k
    attn = (q@k.transpose(-2,-1))*self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn@v).transpose(1,2).reshape(B,N,C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

class Mlp(nn.Module):
  def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
    super().__init__()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features
    self.fc1 = nn.Linear(in_features,hidden_features)
    self.act = act_layer()
    self.fc2 = nn.Linear(hidden_features,out_features)
    self.drop = nn.Dropout(drop)

  def forward(self,x):
    x = self.fc1(x)
    x = self.act(x)
    x = self.drop(x)
    x = self.fc2(x)
    x = self.drop(x)
    return x

class Block(nn.Module):
  def __init__(self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_ratio=0.,
        attn_drop_ratio=0.,
        drop_path_ratio=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm):
    super(Block,self).__init__()
    self.norm1 = norm_layer(dim)
    self.attn = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
                attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
    
    self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio>0. else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim*mlp_ratio)
    self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop_ratio)

  def forward(self,x):
    x = x+self.drop_path(self.attn(self.norm1(x)))
    x = x+self.drop_path(self.mlp(self.norm2(x)))
    return x

class VisionTransformer(nn.Module):
  def __init__(self,img_size=224,patch_size=16,in_c=3,num_classes=1000,
        embed_dim=768,depth=12,num_heads=12,mlp_ratio=4.0,qkv_bias=True,
        qk_scale=None,representation_size=None,distilled=False,drop_ratio=0.,
        attn_drop_ratio=0.,drop_path_ratio=0.,embed_layer=PatchEmbed,
        norm_layer=None,act_layer=None):
    super(VisionTransformer,self).__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim=embed_dim
    self.num_tokens = 2 if distilled else 1
    norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
    act_layer = act_layer or nn.GELU

    self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
    self.dist_token = nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None
    self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
    self.pos_drop = nn.Dropout(p=drop_ratio)

    dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
    self.blocks = nn.Sequential(*[
        Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,
          qk_scale=qk_scale,drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,
          drop_path_ratio=dpr[i],norm_layer=norm_layer,act_layer=act_layer)
        for i in range(depth)
    ])
    self.norm = norm_layer(embed_dim)

    # Representation layer
    if representation_size and not distilled:
      self.has_logits = True
      self.num_features = representation_size
      self.pre_logits = nn.Sequential(OrderedDict([
          ("fc",nn.Linear(embed_dim,representation_size)),
          ("act",nn.Tanh())
      ]))
    else:
      self.has_logits = False
      self.pre_logits = nn.Identity()

    
    self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()
    self.head_dist = None
    if distilled:
      self.head_dist = nn.Linear(self.embed_dim,self.num_classes) if num_classes>0 else nn.Identity()

    
    nn.init.trunc_normal_(self.pos_embed,std=0.02)
    if self.dist_token is not None:
      nn.init.trunc_normal_(self.dist_token,std=0.02)

    nn.init.trunc_normal_(self.cls_token,std=0.02)
    self.apply(_init_vit_weights)

  def forward_features(self, x):
    # [B,C,H,W]->[B,num_patches,embed_dim]
    x = self.patch_embed(x)  # [B,196,768]
    # [1,1,768]->[B,1,768]
    cls_token = self.cls_token.expand(x.shape[0],-1,-1)
    if self.dist_token is None:
      x = torch.cat((cls_token,x),dim=1)  # [B,197,768]
    else:
      x = torch.cat((cls_token,self.dist_token.expand(x.shape[0],-1,-1),x),dim=1)

    x = self.pos_drop(x+self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    if self.dist_token is None:
      return self.pre_logits(x[:,0])
    else:
      return x[:,0], x[:,1]

  def forward(self, x):
    x = self.forward_features(x)
    if self.head_dist is not None:
      x, x_dist = self.head(x[0]),self.head_dist(x[1])
      if self.training and not torch.jit.is_scripting():
        return x,x_dist
      else:
        return (x+x_dist)/2
    else:
      x = self.head(x)
    return x

def _init_vit_weights(m):
  if isinstance(m,nn.Linear):
    nn.init.trunc_normal_(m.weight,std=.01)
    if m.bias is not None:
      nn.init.zeros_(m.bias)
  elif isinstance(m,nn.Conv2d):
    nn.init.kaiming_normal_(m.weight,mode="fan_out")
    if m.bias is not None:
      nn.init.zeros_(m.bias)
  elif isinstance(m,nn.LayerNorm):
    nn.init.zeros_(m.bias)
    nn.init.ones_(m.weight)

def vit_base_patch16_224(num_classes:int=1000):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=None,
                num_classes=num_classes)
  return model

def vit_base_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=768 if has_logits else None,
                num_classes=num_classes)
  return model

def vit_base_patch32_224(num_classes:int=1000):
  model = VisionTransformer(img_size=224,
                patch_size=32,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=None,
                num_classes=num_classes)
  return model


def vit_base_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=32,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=768 if has_logits else None,
                num_classes=num_classes)
  return model

def vit_large_patch16_224(num_classes:int=1000):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                representation_size=None,
                num_classes=num_classes)
  return model


def vit_large_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                representation_size=1024 if has_logits else None,
                num_classes=num_classes)
  return model

def vit_large_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=32,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                representation_size=1024 if has_logits else None,
                num_classes=num_classes)
  return model


def vit_huge_patch14_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=14,
                embed_dim=1280,
                depth=32,
                num_heads=16,
                representation_size=1280 if has_logits else None,
                num_classes=num_classes)
  return model
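
A quick usage sketch for the factory functions above (a 5-class head is just an example; the input resolution must match img_size=224):

import torch

model = vit_base_patch16_224(num_classes=5)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 5])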

Part2 Swin Transformer

1 Network structure

1.1 Overall Framework

Compared with ViT, Swin Transformer has a hierarchical structure: as the layers get deeper, the downsampling factor keeps increasing. It also splits the feature map into non-overlapping windows and computes multi-head self-attention within each window, which greatly reduces the amount of computation.

The overall network framework of Swin Transformer is as follows:

For a three-channel image, a Patch Partition operation is performed first, and the result is then downsampled through 4 stages. The downsampling factor doubles from one stage to the next, and every time the resolution is halved the number of channels doubles correspondingly. Stage 1 begins with a Linear Embedding layer, while the other stages begin with Patch Merging. The Patch Partition operation divides the image with a 4*4 window and then flattens it; the Linear Embedding layer adjusts the channel dimension and applies Layer Norm; the two structures together can be implemented with a single convolutional layer.
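
As a small sketch of the resulting downsampling schedule (numbers assume a 224*224 input, patch_size=4 and embed_dim=96, i.e. Swin-T):

size, dim = 224 // 4, 96                          # after Patch Partition + Linear Embedding: 56x56x96
for stage in range(2, 5):                         # Patch Merging before stage 2, 3 and 4
    size, dim = size // 2, dim * 2
    print(f"stage{stage}: {size}x{size}x{dim}")   # 28x28x192, 14x14x384, 7x7x768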

1.2 Patch Merging

The principle of Patch Merging is illustrated below. It is a downsampling operation that halves the height and width of the feature map and doubles the number of channels:
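
A toy sketch of the 2x2 interleaved sampling that Patch Merging performs (the Linear layer that then reduces 4C to 2C is omitted here; see the PatchMerging module in section 2):

import torch

x = torch.arange(16.).view(1, 4, 4, 1)            # [B, H, W, C] toy feature map
x0 = x[:, 0::2, 0::2, :]                          # top-left pixel of every 2x2 patch
x1 = x[:, 1::2, 0::2, :]                          # bottom-left
x2 = x[:, 0::2, 1::2, :]                          # top-right
x3 = x[:, 1::2, 1::2, :]                          # bottom-right
merged = torch.cat([x0, x1, x2, x3], dim=-1)      # [1, 2, 2, 4]: H and W halved, channels x4
print(merged.shape)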

1.3 W-MSA

W-MSA stands for Window Multi-head Self-Attention. Compared with the ordinary multi-head self-attention module, it splits the feature map into non-overlapping windows and computes multi-head self-attention within each window, which reduces the amount of computation; at the same time, however, it prevents information interaction between windows and makes the receptive field smaller.

The computational complexity of the two is as follows, where h and w are the height and width of the feature map, C is its channel depth, and M is the window size:

Ω(MSA) = 4hwC^2 + 2(hw)^2C
Ω(W-MSA) = 4hwC^2 + 2M^2hwC
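
A small sketch that plugs example numbers into these formulas (assuming the stage-1 setting of Swin-T: a 56*56 feature map, C=96, M=7):

h, w, C, M = 56, 56, 96, 7

msa   = 4*h*w*C**2 + 2*(h*w)**2*C      # global multi-head self-attention
w_msa = 4*h*w*C**2 + 2*(M**2)*h*w*C    # window-based multi-head self-attention
print(f"MSA: {msa/1e9:.2f} GFLOPs, W-MSA: {w_msa/1e9:.2f} GFLOPs")  # roughly 2.00 vs 0.15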

1.4 SW-MSA

SW-MSA stands for Shifted Window Multi-head Self-Attention. As shown in the schematic diagram below, it shifts the windows by a certain offset on top of W-MSA, which enables information interaction between different windows:

1.5 Relative Position Bias

The attention calculation involved is as follows, where B is the relative position bias:

Attention(Q, K, V) = SoftMax(QK^T / √d + B) V

The schematic diagram of the relative position bias is as follows:
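
In code, the relative position index that selects B from the learnable bias table can be built as in the following sketch (window_size=2 for readability; the steps mirror the WindowAttention module in section 2):

import torch

window_size = (2, 2)
coords = torch.stack(torch.meshgrid([torch.arange(window_size[0]),
                                     torch.arange(window_size[1])], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                        # [2, 4]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [4, 4, 2]
relative_coords[:, :, 0] += window_size[0] - 1                   # shift row offsets to start from 0
relative_coords[:, :, 1] += window_size[1] - 1                   # shift column offsets to start from 0
relative_coords[:, :, 0] *= 2*window_size[1] - 1                 # merge the two axes into a single index
relative_position_index = relative_coords.sum(-1)                # [4, 4], values in [0, 8]
print(relative_position_index)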

1.6 Specific configuration parameters

The main configurations (consistent with the factory functions in section 2 below) are:

  • Swin-T: embed_dim 96, depths (2, 2, 6, 2), num_heads (3, 6, 12, 24)
  • Swin-S: embed_dim 96, depths (2, 2, 18, 2), num_heads (3, 6, 12, 24)
  • Swin-B: embed_dim 128, depths (2, 2, 18, 2), num_heads (4, 8, 16, 32)
  • Swin-L: embed_dim 192, depths (2, 2, 18, 2), num_heads (6, 12, 24, 48)

All of them use patch_size 4 and window_size 7 at a 224*224 input resolution (window_size 12 for the 384 variants).

2 Building a network based on PyTorch

# Swin Transformer

""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030

Code/weights from https://github.com/microsoft/Swin-Transformer

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional


def drop_path_f(x,drop_prob:float=0.,training:bool=False):
  if drop_prob==0. or not training:
    return x
  keep_prob = 1-drop_prob
  shape = (x.shape[0],)+(1,)*(x.ndim-1)  # work with diff dim
  random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
  random_tensor.floor_()  # binarize
  output = x.div(keep_prob)*random_tensor
  return output


class DropPath(nn.Module):
  def __init__(self,drop_prob=None):
    super(DropPath,self).__init__()
    self.drop_prob = drop_prob

  def forward(self,x):
    return drop_path_f(x,self.drop_prob,self.training)


def window_partition(x,window_size:int):
  # split the feature map into non-overlapping windows of size window_size
  B,H,W,C = x.shape
  x = x.view(B,H//window_size,window_size,W//window_size,window_size,C)
  windows = x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)
  return windows


def window_reverse(windows,window_size:int,H:int,W:int):
  # merge the windows back into a full feature map
  B = int(windows.shape[0]/(H*W/window_size/window_size))
  x = windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)
  x = x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)
  return x


class PatchEmbed(nn.Module):
  def __init__(self,patch_size=4,in_c=3,embed_dim=96,norm_layer=None):
    super().__init__()
    patch_size = (patch_size,patch_size)
    self.patch_size = patch_size
    self.in_chans = in_c
    self.embed_dim = embed_dim
    self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

  def forward(self, x):
    _,_,H,W = x.shape
    # padding
    pad_input = (H%self.patch_size[0]!=0) or (W%self.patch_size[1]!=0)
    if pad_input:
      x = F.pad(x,(0,self.patch_size[1]-W%self.patch_size[1],
            0,self.patch_size[0]-H%self.patch_size[0],0,0))

    # downsample by a factor of patch_size
    x = self.proj(x)
    _, _, H, W = x.shape
    x = x.flatten(2).transpose(1,2)
    x = self.norm(x)
    return x,H,W


class PatchMerging(nn.Module):
  def __init__(self,dim,norm_layer=nn.LayerNorm):
    super().__init__()
    self.dim = dim
    self.reduction = nn.Linear(4*dim,2*dim,bias=False)
    self.norm = norm_layer(4*dim)

  def forward(self,x,H,W):
    B,L,C = x.shape
    assert L==H*W,"input feature has wrong size"

    x = x.view(B,H,W,C)

    # padding
    pad_input = (H%2==1) or (W%2==1)
    if pad_input:
      x = F.pad(x,(0,0,0,W%2,0,H%2))

    x0 = x[:,0::2,0::2,:]  # [B,H/2,W/2,C]
    x1 = x[:,1::2,0::2,:]  # [B,H/2,W/2,C]
    x2 = x[:,0::2,1::2,:]  # [B,H/2,W/2,C]
    x3 = x[:,1::2,1::2,:]  # [B,H/2,W/2,C]
    x = torch.cat([x0,x1,x2,x3],-1)  # [B,H/2,W/2,4*C]
    x = x.view(B,-1,4*C)  # [B,H/2*W/2,4*C]

    x = self.norm(x)
    x = self.reduction(x)  # [B,H/2*W/2,2*C]

    return x


class Mlp(nn.Module):
  def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
    super().__init__()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features

    self.fc1 = nn.Linear(in_features,hidden_features)
    self.act = act_layer()
    self.drop1 = nn.Dropout(drop)
    self.fc2 = nn.Linear(hidden_features,out_features)
    self.drop2 = nn.Dropout(drop)

  def forward(self,x):
    x = self.fc1(x)
    x = self.act(x)
    x = self.drop1(x)
    x = self.fc2(x)
    x = self.drop2(x)
    return x


class WindowAttention(nn.Module):
  def __init__(self,dim,window_size,num_heads,qkv_bias=True,attn_drop=0.,proj_drop=0.):
    super().__init__()
    self.dim = dim
    self.window_size = window_size
    self.num_heads = num_heads
    head_dim = dim//num_heads
    self.scale = head_dim**-0.5

    self.relative_position_bias_table = nn.Parameter(
        torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1),num_heads))

    coords_h = torch.arange(self.window_size[0])
    coords_w = torch.arange(self.window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h,coords_w],indexing="ij"))
    coords_flatten = torch.flatten(coords,1)
    relative_coords = coords_flatten[:,:,None]-coords_flatten[:,None,:]
    relative_coords = relative_coords.permute(1,2,0).contiguous()
    relative_coords[:,:,0] += self.window_size[0]-1
    relative_coords[:,:,1] += self.window_size[1]-1
    relative_coords[:,:,0] *= 2*self.window_size[1]-1
    relative_position_index = relative_coords.sum(-1)
    self.register_buffer("relative_position_index",relative_position_index)

    self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim,dim)
    self.proj_drop = nn.Dropout(proj_drop)

    nn.init.trunc_normal_(self.relative_position_bias_table,std=.02)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self,x,mask:Optional[torch.Tensor]=None):
    B_,N,C = x.shape
    qkv = self.qkv(x).reshape(B_,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
    q,k,v = qkv.unbind(0)

    q = q*self.scale
    attn = (q@k.transpose(-2,-1))

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0]*self.window_size[1],self.window_size[0]*self.window_size[1],-1)
    relative_position_bias = relative_position_bias.permute(2,0,1).contiguous()
    attn = attn+relative_position_bias.unsqueeze(0)

    if mask is not None:
      nW = mask.shape[0]
      attn = attn.view(B_//nW,nW,self.num_heads,N,N)+mask.unsqueeze(1).unsqueeze(0)
      attn = attn.view(-1,self.num_heads,N,N)
      attn = self.softmax(attn)
    else:
      attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn@v).transpose(1,2).reshape(B_,N,C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x


class SwinTransformerBlock(nn.Module):
  def __init__(self,dim,num_heads,window_size=7,shift_size=0,
      mlp_ratio=4.,qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
      act_layer=nn.GELU,norm_layer=nn.LayerNorm):
    super().__init__()
    self.dim = dim
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size
    self.mlp_ratio = mlp_ratio
    assert 0<=self.shift_size<self.window_size,"shift_size must in 0-window_size"

    self.norm1 = norm_layer(dim)
    self.attn = WindowAttention(
        dim,window_size=(self.window_size,self.window_size),num_heads=num_heads,
        qkv_bias=qkv_bias,attn_drop=attn_drop,proj_drop=drop)

    self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim*mlp_ratio)
    self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop)

  def forward(self,x,attn_mask):
    H,W = self.H, self.W
    B,L,C = x.shape
    assert L==H*W,"input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B,H,W,C)

    # pad the feature map so that H and W are integer multiples of window_size
    pad_l = pad_t = 0
    pad_r = (self.window_size-W%self.window_size)%self.window_size
    pad_b = (self.window_size-H%self.window_size)%self.window_size
    x = F.pad(x,(0,0,pad_l,pad_r,pad_t,pad_b))
    _,Hp,Wp,_ = x.shape

    if self.shift_size>0:
      shifted_x = torch.roll(x,shifts=(-self.shift_size,-self.shift_size),dims=(1,2))
    else:
      shifted_x = x
      attn_mask = None

    # partition windows
    x_windows = window_partition(shifted_x,self.window_size)
    x_windows = x_windows.view(-1,self.window_size*self.window_size,C)

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows,mask=attn_mask)

    attn_windows = attn_windows.view(-1,self.window_size,self.window_size,C)
    shifted_x = window_reverse(attn_windows,self.window_size,Hp,Wp)

    if self.shift_size>0:
      x = torch.roll(shifted_x,shifts=(self.shift_size,self.shift_size),dims=(1,2))
    else:
      x = shifted_x

    if pad_r>0 or pad_b>0:
      # remove the padding added earlier
      x = x[:,:H,:W,:].contiguous()

    x = x.view(B,H*W,C)

    # FFN
    x = shortcut+self.drop_path(x)
    x = x+self.drop_path(self.mlp(self.norm2(x)))

    return x


class BasicLayer(nn.Module):
  def __init__(self,dim,depth,num_heads,window_size,mlp_ratio=4.,
      qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
      norm_layer=nn.LayerNorm,downsample=None,use_checkpoint=False):
    super().__init__()
    self.dim = dim
    self.depth = depth
    self.window_size = window_size
    self.use_checkpoint = use_checkpoint
    self.shift_size = window_size//2

    self.blocks = nn.ModuleList([
        SwinTransformerBlock(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=0 if (i%2==0) else self.shift_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path[i] if isinstance(drop_path,list) else drop_path,
            norm_layer=norm_layer)
        for i in range(depth)])

    if downsample is not None:
      self.downsample = downsample(dim=dim,norm_layer=norm_layer)
    else:
      self.downsample = None

  def create_mask(self,x,H,W):
    # make sure Hp and Wp are integer multiples of window_size
    Hp = int(np.ceil(H/self.window_size))*self.window_size
    Wp = int(np.ceil(W/self.window_size))*self.window_size
    # same channel layout as the feature map, so window_partition can be reused below
    img_mask = torch.zeros((1,Hp,Wp,1),device=x.device)
    h_slices = (slice(0,-self.window_size),
          slice(-self.window_size,-self.shift_size),
          slice(-self.shift_size, None))
    w_slices = (slice(0,-self.window_size),
          slice(-self.window_size,-self.shift_size),
          slice(-self.shift_size,None))
    cnt = 0
    for h in h_slices:
      for w in w_slices:
        img_mask[:,h,w,:] = cnt
        cnt += 1

    mask_windows = window_partition(img_mask,self.window_size)
    mask_windows = mask_windows.view(-1,self.window_size*self.window_size)
    attn_mask = mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))
    return attn_mask

  def forward(self,x,H,W):
    attn_mask = self.create_mask(x,H,W)
    for blk in self.blocks:
      blk.H,blk.W = H,W
      if not torch.jit.is_scripting() and self.use_checkpoint:
        x = checkpoint.checkpoint(blk,x,attn_mask)
      else:
        x = blk(x,attn_mask)
    if self.downsample is not None:
      x = self.downsample(x,H,W)
      H,W = (H+1)//2,(W+1)//2

    return x,H,W


class SwinTransformer(nn.Module):
  def __init__(self,patch_size=4,in_chans=3,num_classes=1000,
      embed_dim=96,depths=(2,2,6,2),num_heads=(3,6,12,24),
      window_size=7,mlp_ratio=4.,qkv_bias=True,
      drop_rate=0.,attn_drop_rate=0.,drop_path_rate=0.1,
      norm_layer=nn.LayerNorm,patch_norm=True,
      use_checkpoint=False,**kwargs):
    super().__init__()

    self.num_classes = num_classes
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.patch_norm = patch_norm
    # number of channels of the feature map output by stage 4
    self.num_features = int(embed_dim*2**(self.num_layers-1))
    self.mlp_ratio = mlp_ratio

    # split the image into non-overlapping patches
    self.patch_embed = PatchEmbed(
      patch_size=patch_size,in_c=in_chans,embed_dim=embed_dim,
      norm_layer=norm_layer if self.patch_norm else None)
    self.pos_drop = nn.Dropout(p=drop_rate)

    dpr = [x.item() for x in torch.linspace(0,drop_path_rate,sum(depths))]

    # build layers
    self.layers = nn.ModuleList()
    for i_layer in range(self.num_layers):
      layers = BasicLayer(dim=int(embed_dim*2**i_layer),
                    depth=depths[i_layer],
                    num_heads=num_heads[i_layer],
                    window_size=window_size,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer+1])],
                    norm_layer=norm_layer,
                    downsample=PatchMerging if (i_layer<self.num_layers-1) else None,
                    use_checkpoint=use_checkpoint)
      self.layers.append(layers)

    self.norm = norm_layer(self.num_features)
    self.avgpool = nn.AdaptiveAvgPool1d(1)
    self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()

    self.apply(self._init_weights)

  def _init_weights(self,m):
    if isinstance(m,nn.Linear):
      nn.init.trunc_normal_(m.weight,std=.02)
      if isinstance(m,nn.Linear) and m.bias is not None:
        nn.init.constant_(m.bias,0)
    elif isinstance(m,nn.LayerNorm):
      nn.init.constant_(m.bias,0)
      nn.init.constant_(m.weight,1.0)

  def forward(self,x):
    x,H,W = self.patch_embed(x)
    x = self.pos_drop(x)

    for layer in self.layers:
      x,H,W = layer(x,H,W)

    x = self.norm(x)
    x = self.avgpool(x.transpose(1,2))
    x = torch.flatten(x,1)
    x = self.head(x)
    return x


def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,6,2),
              num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
  return model


def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,18,2),
              num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=192,depths=(2,2,18,2),
              num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
  return model


def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=192,depths=(2,2,18,2),
              num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
  return model
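
A quick usage sketch for the factory functions above (a 224*224 input and a 10-class head are just example settings):

import torch

model = swin_tiny_patch4_window7_224(num_classes=10)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 10])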

Part3 ConvNeXt

1 Network structure

1.1 Structure used

Macro design

①stage ratio

Adjusting the stacking numbers of ResNet-50 from (3,4,6,3) to (3,3,9,3), which matches the stage ratio of Swin-T, brings a clear improvement in accuracy

②“patchify” stem

Replacing the ResNet stem (the original downsampling module) with a convolutional layer with kernel size 4 and stride 4 slightly improves accuracy and also slightly reduces FLOPs
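
A one-line sketch of this "patchify" stem (assuming a 3-channel input and 96 output channels as in ConvNeXt-T; the implementation in section 2 also adds a LayerNorm after it):

import torch.nn as nn

# non-overlapping 4x4 convolution with stride 4: each output position corresponds to one 4x4 patch
stem = nn.Conv2d(3, 96, kernel_size=4, stride=4)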

ResNeXt

Compared with ResNet, ResNeXt achieves a better trade-off between FLOPs and accuracy. Here the author also adopts DW (depthwise) convolution; when the network width is increased accordingly, accuracy improves notably while FLOPs also increase.

Inverted bottleneck

The author observes that the MLP module in a Transformer block is very similar to the inverted residual module, which is thin at both ends and thick in the middle, so the Bottleneck block is replaced with an inverted bottleneck; accuracy improves slightly and the overall FLOPs decrease noticeably.

Large kernel size

Move the DW convolution up: the block order changes from 1*1 conv -> DW conv -> 1*1 conv to DW conv -> 1*1 conv -> 1*1 conv, and the DW convolution kernel size is enlarged from 3*3 to 7*7

Various layer-wise Micro designs

Replace ReLU with GELU and use fewer activation functions, use fewer normalization layers and replace BN with LN (which speeds up convergence and reduces overfitting), and finally use a separate downsampling layer between stages

1.2 Network Effects

Compared with a Swin Transformer of the same scale, ConvNeXt achieves higher accuracy and roughly 40% higher inference throughput (images per second).

1.3 Various versions

Here C is the number of channels of each stage's input feature map and B is the number of blocks stacked in each stage. The versions implemented below use:

  • ConvNeXt-T: C = (96, 192, 384, 768), B = (3, 3, 9, 3)
  • ConvNeXt-S: C = (96, 192, 384, 768), B = (3, 3, 27, 3)
  • ConvNeXt-B: C = (128, 256, 512, 1024), B = (3, 3, 27, 3)
  • ConvNeXt-L: C = (192, 384, 768, 1536), B = (3, 3, 27, 3)
  • ConvNeXt-XL: C = (256, 512, 1024, 2048), B = (3, 3, 27, 3)

2 Building a network based on PyTorch

# ConvNeXt

"""
original code from facebook research:
https://github.com/facebookresearch/ConvNeXt
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_path(x,drop_prob:float=0.,training:bool=False):
  if drop_prob==0. or not training:
    return x
  keep_prob = 1-drop_prob
  shape = (x.shape[0],)+(1,)*(x.ndim-1)  # work with diff dim
  random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
  random_tensor.floor_()
  output = x.div(keep_prob)*random_tensor
  return output


class DropPath(nn.Module):
  def __init__(self,drop_prob=None):
    super(DropPath,self).__init__()
    self.drop_prob = drop_prob

  def forward(self,x):
    return drop_path(x,self.drop_prob,self.training)


class LayerNorm(nn.Module):
  def __init__(self,normalized_shape,eps=1e-6,data_format="channels_last"):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(normalized_shape),requires_grad=True)
    self.bias = nn.Parameter(torch.zeros(normalized_shape),requires_grad=True)
    self.eps = eps
    self.data_format = data_format
    if self.data_format not in ["channels_last","channels_first"]:
      raise ValueError(f"not support data format '{self.data_format}'")
    self.normalized_shape = (normalized_shape,)

  def forward(self,x:torch.Tensor)->torch.Tensor:
    if self.data_format=="channels_last":
      return F.layer_norm(x,self.normalized_shape,self.weight,self.bias,self.eps)
    elif self.data_format=="channels_first":
      # [batch_size,channels,height,width]
      mean = x.mean(1,keepdim=True)
      var = (x-mean).pow(2).mean(1,keepdim=True)
      x = (x-mean)/torch.sqrt(var+self.eps)
      x = self.weight[:,None,None]*x+self.bias[:,None,None]
      return x


class Block(nn.Module):
  def __init__(self,dim,drop_rate=0.,layer_scale_init_value=1e-6):
    super().__init__()
    self.dwconv = nn.Conv2d(dim,dim,kernel_size=7,padding=3,groups=dim)  # depthwise (DW) convolution
    self.norm = LayerNorm(dim,eps=1e-6,data_format="channels_last")
    self.pwconv1 = nn.Linear(dim,4*dim)
    self.act = nn.GELU()
    self.pwconv2 = nn.Linear(4*dim,dim)
    self.gamma = nn.Parameter(layer_scale_init_value*torch.ones((dim,)),
                  requires_grad=True) if layer_scale_init_value>0 else None
    self.drop_path = DropPath(drop_rate) if drop_rate>0. else nn.Identity()

  def forward(self,x:torch.Tensor)->torch.Tensor:
    shortcut = x
    x = self.dwconv(x)
    x = x.permute(0,2,3,1)  # [N,C,H,W]->[N,H,W,C]
    x = self.norm(x)
    x = self.pwconv1(x)
    x = self.act(x)
    x = self.pwconv2(x)
    if self.gamma is not None:
      x = self.gamma*x
    x = x.permute(0,3,1,2)  # [N,H,W,C]->[N,C,H,W]

    x = shortcut+self.drop_path(x)
    return x


class ConvNeXt(nn.Module):
  def __init__(self,in_chans:int=3,num_classes:int=1000,depths:list=None,
              dims:list=None,drop_path_rate:float=0.,
              layer_scale_init_value:float=1e-6,head_init_scale:float=1.):
    super().__init__()
    self.downsample_layers = nn.ModuleList()
    stem = nn.Sequential(nn.Conv2d(in_chans,dims[0],kernel_size=4,stride=4),
              LayerNorm(dims[0],eps=1e-6,data_format="channels_first"))
    self.downsample_layers.append(stem)

    # the 3 downsample layers before stage 2 - stage 4
    for i in range(3):
      downsample_layer = nn.Sequential(LayerNorm(dims[i],eps=1e-6,data_format="channels_first"),
                      nn.Conv2d(dims[i],dims[i+1],kernel_size=2,stride=2))
      self.downsample_layers.append(downsample_layer)

    self.stages = nn.ModuleList()
    dp_rates = [x.item() for x in torch.linspace(0,drop_path_rate,sum(depths))]
    cur = 0
    # build the blocks stacked in each stage
    for i in range(4):
      stage = nn.Sequential(*[Block(dim=dims[i],drop_rate=dp_rates[cur+j],layer_scale_init_value=layer_scale_init_value)
            for j in range(depths[i])]
      )
      self.stages.append(stage)
      cur += depths[i]

    self.norm = nn.LayerNorm(dims[-1],eps=1e-6)
    self.head = nn.Linear(dims[-1],num_classes)
    self.apply(self._init_weights)
    self.head.weight.data.mul_(head_init_scale)
    self.head.bias.data.mul_(head_init_scale)

  def _init_weights(self,m):
    if isinstance(m,(nn.Conv2d,nn.Linear)):
      nn.init.trunc_normal_(m.weight,std=0.02)
      nn.init.constant_(m.bias,0)

  def forward_features(self,x:torch.Tensor)->torch.Tensor:
    for i in range(4):
      x = self.downsample_layers[i](x)
      x = self.stages[i](x)

    return self.norm(x.mean([-2,-1]))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.forward_features(x)
    x = self.head(x)
    return x


def convnext_tiny(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
  model = ConvNeXt(depths=[3,3,9,3],dims=[96,192,384,768],num_classes=num_classes)
  return model


def convnext_small(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[96,192,384,768],num_classes=num_classes)
  return model


def convnext_base(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
  # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[128,256,512,1024],num_classes=num_classes)
  return model


def convnext_large(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
  # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[192,384,768,1536],num_classes=num_classes)
  return model


def convnext_xlarge(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[256,512,1024,2048],num_classes=num_classes)
  return model
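
A quick usage sketch for the factory functions above (a 5-class head is just an example):

import torch

model = convnext_tiny(num_classes=5)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 5])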

Part4 Personal experience

Natural language processing and computer vision share some similarities, so the Transformer is currently a popular research direction with many places that can still be optimized and adapted. CNNs, however, are more mature and have more well-established building blocks; when necessary, multiple structures can be combined to achieve better results.
