[OUC Introduction to Deep Learning] Week 6 Study Log: Vision Transformer, Swin Transformer, and ConvNeXt

Part 1: Vision Transformer

1 Network structure

The Transformer is not only well suited to NLP; the ViT model shows that it can also achieve good results in the CV field.

In the original paper, the authors compare three models: ViT, a "pure" Transformer model; the ResNet network; and a Hybrid model that combines a traditional CNN with a Transformer. They find that when the number of training iterations is large enough, the accuracy of the ViT model eventually exceeds that of the Hybrid model.

The architecture of the ViT (Vision Transformer) model is as follows:

The model first splits the image into patches of size 16×16. Each patch is fed through the embedding layer, producing a vector called a token. A [class] token used for classification is then prepended, a position embedding is added to mark the position of each token, and the tokens with this position information are passed through the Transformer Encoder; the final classification result is obtained through the MLP Head.

1.1 Linear projection of flattened patches (embedding layer)

This can be implemented directly with a convolutional layer. It produces the token sequence, a two-dimensional matrix [num_token, token_dim]; the [class] token is then concatenated with it, and the position embedding is added on top. The concatenation is done with a cat operation, and the position embedding is added element-wise.
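
A minimal sketch of these three steps (patch projection, [class] token concatenation, position embedding addition), assuming the Base configuration used later in this post (224×224 input, 16×16 patches, embed_dim = 768); the variable names are illustrative:

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)        # patch projection as a conv layer
cls_token = nn.Parameter(torch.zeros(1, 1, 768))           # learnable [class] token
pos_embed = nn.Parameter(torch.zeros(1, 196 + 1, 768))     # learnable 1-D position embedding

x = torch.randn(1, 3, 224, 224)                            # dummy image batch
x = proj(x).flatten(2).transpose(1, 2)                     # [1, 196, 768] token sequence
x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)  # concatenate -> [1, 197, 768]
x = x + pos_embed                                          # element-wise addition of position info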

Experiments show that omitting the position embedding causes a noticeable drop in accuracy, while the specific type of position embedding has little effect on accuracy, so the source code defaults to the 1-D learnable position embedding, which uses fewer parameters.

The similarity between the position embeddings that are finally learned is shown below; each row and each column exhibits high similarity:

1.2 Transformer Encoder

The Encoder layer structure and the MLP structure are as follows:

Here, the embedded patches first pass through Layer Norm and Multi-Head Attention, followed by Dropout (DropPath) and a residual connection; a second Layer Norm, the MLP block, and another Dropout (DropPath) with a residual connection complete the Encoder Block, which is then stacked L times.

1.3 MLP Head (final layer structure for classification)

When training on ImageNet-21k or larger datasets, it consists of Linear + tanh activation + Linear. When fine-tuning on ImageNet-1k or on a custom dataset, there is only a single Linear layer.
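
A minimal sketch of the two head variants (module names are illustrative; they mirror the pre_logits and head modules in the code below):

import torch.nn as nn

embed_dim, num_classes = 768, 1000

# pre-training on ImageNet-21k or larger: Linear + tanh + Linear
pretrain_head = nn.Sequential(
    nn.Linear(embed_dim, embed_dim),
    nn.Tanh(),
    nn.Linear(embed_dim, num_classes),
)

# fine-tuning on ImageNet-1k or a custom dataset: a single Linear layer
finetune_head = nn.Linear(embed_dim, num_classes)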

1.4 ViT variants

There are three variants, namely Base, Large, and Huge, with the following specifications (summarized after this list):

  • Layers: the number of times the Encoder Block is stacked in the Transformer Encoder
  • Hidden Size: the length (embedding dimension) of each token vector after it passes through the embedding layer
  • MLP Size: the number of nodes in the first fully connected layer of the MLP block, which is 4 times the Hidden Size
  • Heads: the number of heads in Multi-Head Attention
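
For reference, the three variants as they appear in the factory functions in the code below (MLP Size = 4 × Hidden Size):

  • ViT-Base: Layers 12, Hidden Size 768, MLP Size 3072, Heads 12
  • ViT-Large: Layers 24, Hidden Size 1024, MLP Size 4096, Heads 16
  • ViT-Huge: Layers 32, Hidden Size 1280, MLP Size 5120, Heads 16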

2 Building the network with PyTorch

The code comes from the official implementation.

Learning link: ViT

Code link: ViT (Colab)

# Vision Transformer

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""

from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


# stochastic depth (drop path)
def drop_path(x,drop_prob:float=0.,training:bool=False):
  if drop_prob==0. or not training:
    return x
  keep_prob = 1-drop_prob
  shape = (x.shape[0],)+(1,)*(x.ndim-1)  # work with diff dim tensors
  random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
  random_tensor.floor_()  # binarize
  output = x.div(keep_prob)*random_tensor
  return output

class DropPath(nn.Module):
  def __init__(self,drop_prob=None):
    super(DropPath,self).__init__()
    self.drop_prob = drop_prob

  def forward(self,x):
    return drop_path(x,self.drop_prob,self.training)


# Patch Embedding
class PatchEmbed(nn.Module):
  def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
    super().__init__()
    img_size = (img_size,img_size)
    patch_size = (patch_size,patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    self.grid_size = (img_size[0]//patch_size[0],img_size[1]//patch_size[1])
    self.num_patches = self.grid_size[0]*self.grid_size[1]

    self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

  def forward(self, x):
    B,C,H,W = x.shape
    assert H==self.img_size[0] and W==self.img_size[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

    # flatten:[B,C,H,W]->[B,C,HW]
    # transpose:[B,C,HW]->[B,HW,C]
    x = self.proj(x).flatten(2).transpose(1,2)
    x = self.norm(x)
    return x

class Attention(nn.Module):
  def __init__(self,dim,num_heads=8,qkv_bias=False,qk_scale=None,attn_drop_ratio=0.,proj_drop_ratio=0.):
    super(Attention,self).__init__()
    self.num_heads = num_heads
    head_dim = dim//num_heads
    self.scale = qk_scale or head_dim**-0.5
    self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop_ratio)
    self.proj = nn.Linear(dim,dim)
    self.proj_drop = nn.Dropout(proj_drop_ratio)

  def forward(self, x):
    B,N,C = x.shape

    # rearrange dimensions to make the computation easier
    qkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
    q,k,v = qkv[0],qkv[1],qkv[2]

    # matrix multiplication (scaled dot-product)
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn@v).transpose(1,2).reshape(B,N,C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

class Mlp(nn.Module):
  def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
    super().__init__()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features
    self.fc1 = nn.Linear(in_features,hidden_features)
    self.act = act_layer()
    self.fc2 = nn.Linear(hidden_features,out_features)
    self.drop = nn.Dropout(drop)

  def forward(self,x):
    x = self.fc1(x)
    x = self.act(x)
    x = self.drop(x)
    x = self.fc2(x)
    x = self.drop(x)
    return x

class Block(nn.Module):
  def __init__(self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_ratio=0.,
        attn_drop_ratio=0.,
        drop_path_ratio=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm):
    super(Block,self).__init__()
    self.norm1 = norm_layer(dim)
    self.attn = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
                attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
    
    self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio>0. else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim*mlp_ratio)
    self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop_ratio)

  def forward(self,x):
    x = x+self.drop_path(self.attn(self.norm1(x)))
    x = x+self.drop_path(self.mlp(self.norm2(x)))
    return x

class VisionTransformer(nn.Module):
  def __init__(self,img_size=224,patch_size=16,in_c=3,num_classes=1000,
        embed_dim=768,depth=12,num_heads=12,mlp_ratio=4.0,qkv_bias=True,
        qk_scale=None,representation_size=None,distilled=False,drop_ratio=0.,
        attn_drop_ratio=0.,drop_path_ratio=0.,embed_layer=PatchEmbed,
        norm_layer=None,act_layer=None):
    super(VisionTransformer,self).__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim=embed_dim
    self.num_tokens = 2 if distilled else 1
    norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
    act_layer = act_layer or nn.GELU

    self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
    self.dist_token = nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None
    self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
    self.pos_drop = nn.Dropout(p=drop_ratio)

    dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
    self.blocks = nn.Sequential(*[
        Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,
          qk_scale=qk_scale,drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,
          drop_path_ratio=dpr[i],norm_layer=norm_layer,act_layer=act_layer)
        for i in range(depth)
    ])
    self.norm = norm_layer(embed_dim)

    # Representation layer
    if representation_size and not distilled:
      self.has_logits = True
      self.num_features = representation_size
      self.pre_logits = nn.Sequential(OrderedDict([
          ("fc",nn.Linear(embed_dim,representation_size)),
          ("act",nn.Tanh())
      ]))
    else:
      self.has_logits = False
      self.pre_logits = nn.Identity()

    
    self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()
    self.head_dist = None
    if distilled:
      self.head_dist = nn.Linear(self.embed_dim,self.num_classes) if num_classes>0 else nn.Identity()

    
    nn.init.trunc_normal_(self.pos_embed,std=0.02)
    if self.dist_token is not None:
      nn.init.trunc_normal_(self.dist_token,std=0.02)

    nn.init.trunc_normal_(self.cls_token,std=0.02)
    self.apply(_init_vit_weights)

  def forward_features(self, x):
    # [B,C,H,W]->[B,num_patches,embed_dim]
    x = self.patch_embed(x)  # [B,196,768]
    # [1,1,768]->[B,1,768]
    cls_token = self.cls_token.expand(x.shape[0],-1,-1)
    if self.dist_token is None:
      x = torch.cat((cls_token,x),dim=1)  # [B,197,768]
    else:
      x = torch.cat((cls_token,self.dist_token.expand(x.shape[0],-1,-1),x),dim=1)

    x = self.pos_drop(x+self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    if self.dist_token is None:
      return self.pre_logits(x[:,0])
    else:
      return x[:,0], x[:,1]

  def forward(self, x):
    x = self.forward_features(x)
    if self.head_dist is not None:
      x, x_dist = self.head(x[0]),self.head_dist(x[1])
      if self.training and not torch.jit.is_scripting():
        return x,x_dist
      else:
        return (x+x_dist)/2
    else:
      x = self.head(x)
    return x

def _init_vit_weights(m):
  if isinstance(m,nn.Linear):
    nn.init.trunc_normal_(m.weight,std=.01)
    if m.bias is not None:
      nn.init.zeros_(m.bias)
  elif isinstance(m,nn.Conv2d):
    nn.init.kaiming_normal_(m.weight,mode="fan_out")
    if m.bias is not None:
      nn.init.zeros_(m.bias)
  elif isinstance(m,nn.LayerNorm):
    nn.init.zeros_(m.bias)
    nn.init.ones_(m.weight)

def vit_base_patch16_224(num_classes:int=1000):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=None,
                num_classes=num_classes)
  return model

def vit_base_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=768 if has_logits else None,
                num_classes=num_classes)
  return model

def vit_base_patch32_224(num_classes:int=1000):
  model = VisionTransformer(img_size=224,
                patch_size=32,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=None,
                num_classes=num_classes)
  return model


def vit_base_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=32,
                embed_dim=768,
                depth=12,
                num_heads=12,
                representation_size=768 if has_logits else None,
                num_classes=num_classes)
  return model

def vit_large_patch16_224(num_classes:int=1000):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                representation_size=None,
                num_classes=num_classes)
  return model


def vit_large_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=16,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                representation_size=1024 if has_logits else None,
                num_classes=num_classes)
  return model

def vit_large_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=32,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                representation_size=1024 if has_logits else None,
                num_classes=num_classes)
  return model


def vit_huge_patch14_224_in21k(num_classes:int=21843,has_logits:bool=True):
  model = VisionTransformer(img_size=224,
                patch_size=14,
                embed_dim=1280,
                depth=32,
                num_heads=16,
                representation_size=1280 if has_logits else None,
                num_classes=num_classes)
  return model
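
A minimal usage sketch, assuming the definitions above:

if __name__ == "__main__":
    model = vit_base_patch16_224(num_classes=5)
    dummy = torch.randn(1, 3, 224, 224)
    out = model(dummy)
    print(out.shape)  # torch.Size([1, 5])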

Part 2: Swin Transformer

1 Network structure

1.1 Overall framework

Compared with ViT, Swin Transformer is hierarchical: as the network deepens, the downsampling factor keeps increasing. It uses non-overlapping windows to partition the feature map and computes multi-head self-attention within each window, which greatly reduces the amount of computation.

The overall framework of the Swin Transformer network is as follows:

For a three-channel image, a Patch Partition operation is applied first (downsampling by a factor of 4), and the result then passes through 4 stages. Starting from Stage 2, each stage halves the height and width again and doubles the number of channels correspondingly. The head of Stage 1 is a Linear Embedding layer, while the heads of the other stages are Patch Merging layers. The Patch Partition operation splits the image with a 4×4 window and then flattens it; the Linear Embedding layer adjusts the channel dimension and applies Layer Norm. Both structures can be implemented with a single convolutional layer.

1.2 Patch Merging

The principle of Patch Merging is as follows: it performs a downsampling operation that halves the height and width of the feature map and doubles the number of channels:
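
A standalone sketch of the operation (it mirrors the PatchMerging module implemented in the code below); the shapes are illustrative:

import torch
import torch.nn as nn

x = torch.randn(1, 56, 56, 96)            # [B, H, W, C]
x0 = x[:, 0::2, 0::2, :]                  # one element from each 2x2 neighborhood
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3], dim=-1)   # [1, 28, 28, 384]: H and W halved, C x4
x = nn.LayerNorm(384)(x)
x = nn.Linear(384, 192, bias=False)(x)    # [1, 28, 28, 192]: overall, C is doubled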

1.3 W-MSA

W-MSA stands for Windowed Multi-head Self-Attention. Compared with the ordinary multi-head self-attention module, it uses non-overlapping windows to partition the feature map, and multi-head self-attention is computed inside each window, which reduces the amount of computation. At the same time, however, it prevents information exchange between windows, making the receptive field smaller.

The computational cost of the two is compared below, where h and w denote the height and width of the feature map, C denotes its depth (number of channels), and M denotes the size of each window.
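
For reference, the complexity formulas given in the Swin Transformer paper are:

Ω(MSA)   = 4hwC² + 2(hw)²C
Ω(W-MSA) = 4hwC² + 2M²hwC

so W-MSA scales linearly with h·w instead of quadratically.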

1.4 SW-MSA

SW-MSA stands for Shifted-Window Multi-head Self-Attention. The schematic is as follows. On top of W-MSA, the window partition is shifted by a certain offset, which enables information exchange between different windows:
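
In the code below the shift is implemented with torch.roll; a tiny standalone illustration of the cyclic shift on a 4×4 feature map (shift_size = 1):

import torch

x = torch.arange(16).view(1, 4, 4, 1)                   # [B, H, W, C] toy feature map
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))   # cyclic shift up and left by shift_size
print(x[0, :, :, 0])
print(shifted[0, :, :, 0])  # rows/columns wrap around, so new windows mix formerly separate windows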

1.5 Relative position bias

The formula involved is as follows, where B is the relative position bias:
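
From the paper, the attention computed within each window is

Attention(Q, K, V) = SoftMax(QKᵀ/√d + B)V

where Q, K, V ∈ ℝ^(M²×d), d is the query/key dimension, M² is the number of patches inside a window, and B is taken from a learnable bias table of size (2M−1)×(2M−1).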

The schematic of the relative position bias is as follows:

1.6 Specific configuration parameters
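
The main settings, as they appear in the factory functions in the code below:

  • Swin-T: embed_dim C = 96, depths (2, 2, 6, 2), heads (3, 6, 12, 24), window size 7
  • Swin-S: C = 96, depths (2, 2, 18, 2), heads (3, 6, 12, 24), window size 7
  • Swin-B: C = 128, depths (2, 2, 18, 2), heads (4, 8, 16, 32), window size 7 (224 input) or 12 (384 input)
  • Swin-L: C = 192, depths (2, 2, 18, 2), heads (6, 12, 24, 48), window size 7 (224 input) or 12 (384 input)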

2 Building the network with PyTorch

# Swin Transformer

""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030

Code/weights from https://github.com/microsoft/Swin-Transformer

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional


def drop_path_f(x,drop_prob:float=0.,training:bool=False):
  if drop_prob==0. or not training:
    return x
  keep_prob = 1-drop_prob
  shape = (x.shape[0],)+(1,)*(x.ndim-1)  # work with diff dim
  random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
  random_tensor.floor_()  # binarize
  output = x.div(keep_prob)*random_tensor
  return output


class DropPath(nn.Module):
  def __init__(self,drop_prob=None):
    super(DropPath,self).__init__()
    self.drop_prob = drop_prob

  def forward(self,x):
    return drop_path_f(x,self.drop_prob,self.training)


def window_partition(x,window_size:int):
  # partition the feature map into non-overlapping windows of size window_size x window_size
  B,H,W,C = x.shape
  x = x.view(B,H//window_size,window_size,W//window_size,window_size,C)
  windows = x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)
  return windows


def window_reverse(windows,window_size:int,H:int,W:int):
  # merge the individual windows back into a single feature map
  B = int(windows.shape[0]/(H*W/window_size/window_size))
  x = windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)
  x = x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)
  return x


class PatchEmbed(nn.Module):
  def __init__(self,patch_size=4,in_c=3,embed_dim=96,norm_layer=None):
    super().__init__()
    patch_size = (patch_size,patch_size)
    self.patch_size = patch_size
    self.in_chans = in_c
    self.embed_dim = embed_dim
    self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

  def forward(self, x):
    _,_,H,W = x.shape
    # padding
    pad_input = (H%self.patch_size[0]!=0) or (W%self.patch_size[1]!=0)
    if pad_input:
      x = F.pad(x,(0,self.patch_size[1]-W%self.patch_size[1],
            0,self.patch_size[0]-H%self.patch_size[0],0,0))

    # downsample by a factor of patch_size
    x = self.proj(x)
    _, _, H, W = x.shape
    x = x.flatten(2).transpose(1,2)
    x = self.norm(x)
    return x,H,W


class PatchMerging(nn.Module):
  def __init__(self,dim,norm_layer=nn.LayerNorm):
    super().__init__()
    self.dim = dim
    self.reduction = nn.Linear(4*dim,2*dim,bias=False)
    self.norm = norm_layer(4*dim)

  def forward(self,x,H,W):
    B,L,C = x.shape
    assert L==H*W,"input feature has wrong size"

    x = x.view(B,H,W,C)

    # padding
    pad_input = (H%2==1) or (W%2==1)
    if pad_input:
      x = F.pad(x,(0,0,0,W%2,0,H%2))

    x0 = x[:,0::2,0::2,:]  # [B,H/2,W/2,C]
    x1 = x[:,1::2,0::2,:]  # [B,H/2,W/2,C]
    x2 = x[:,0::2,1::2,:]  # [B,H/2,W/2,C]
    x3 = x[:,1::2,1::2,:]  # [B,H/2,W/2,C]
    x = torch.cat([x0,x1,x2,x3],-1)  # [B,H/2,W/2,4*C]
    x = x.view(B,-1,4*C)  # [B,H/2*W/2,4*C]

    x = self.norm(x)
    x = self.reduction(x)  # [B,H/2*W/2,2*C]

    return x


class Mlp(nn.Module):
  def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
    super().__init__()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features

    self.fc1 = nn.Linear(in_features,hidden_features)
    self.act = act_layer()
    self.drop1 = nn.Dropout(drop)
    self.fc2 = nn.Linear(hidden_features,out_features)
    self.drop2 = nn.Dropout(drop)

  def forward(self,x):
    x = self.fc1(x)
    x = self.act(x)
    x = self.drop1(x)
    x = self.fc2(x)
    x = self.drop2(x)
    return x


class WindowAttention(nn.Module):
  def __init__(self,dim,window_size,num_heads,qkv_bias=True,attn_drop=0.,proj_drop=0.):
    super().__init__()
    self.dim = dim
    self.window_size = window_size
    self.num_heads = num_heads
    head_dim = dim//num_heads
    self.scale = head_dim**-0.5

    self.relative_position_bias_table = nn.Parameter(
        torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1),num_heads))

    coords_h = torch.arange(self.window_size[0])
    coords_w = torch.arange(self.window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h,coords_w],indexing="ij"))
    coords_flatten = torch.flatten(coords,1)
    relative_coords = coords_flatten[:,:,None]-coords_flatten[:,None,:]
    relative_coords = relative_coords.permute(1,2,0).contiguous()
    relative_coords[:,:,0] += self.window_size[0]-1
    relative_coords[:,:,1] += self.window_size[1]-1
    relative_coords[:,:,0] *= 2*self.window_size[1]-1
    relative_position_index = relative_coords.sum(-1)
    self.register_buffer("relative_position_index",relative_position_index)

    self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim,dim)
    self.proj_drop = nn.Dropout(proj_drop)

    nn.init.trunc_normal_(self.relative_position_bias_table,std=.02)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self,x,mask:Optional[torch.Tensor]=None):
    B_,N,C = x.shape
    qkv = self.qkv(x).reshape(B_,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
    q,k,v = qkv.unbind(0)

    q = q*self.scale
    attn = (q @ k.transpose(-2, -1))

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0]*self.window_size[1],self.window_size[0]*self.window_size[1],-1)
    relative_position_bias = relative_position_bias.permute(2,0,1).contiguous()
    attn = attn+relative_position_bias.unsqueeze(0)

    if mask is not None:
      nW = mask.shape[0]
      attn = attn.view(B_//nW,nW,self.num_heads,N,N)+mask.unsqueeze(1).unsqueeze(0)
      attn = attn.view(-1,self.num_heads,N,N)
      attn = self.softmax(attn)
    else:
      attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn@v).transpose(1,2).reshape(B_,N,C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x


class SwinTransformerBlock(nn.Module):
  def __init__(self,dim,num_heads,window_size=7,shift_size=0,
      mlp_ratio=4.,qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
      act_layer=nn.GELU,norm_layer=nn.LayerNorm):
    super().__init__()
    self.dim = dim
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size
    self.mlp_ratio = mlp_ratio
    assert 0<=self.shift_size<self.window_size,"shift_size must in 0-window_size"

    self.norm1 = norm_layer(dim)
    self.attn = WindowAttention(
        dim,window_size=(self.window_size,self.window_size),num_heads=num_heads,
        qkv_bias=qkv_bias,attn_drop=attn_drop,proj_drop=drop)

    self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim*mlp_ratio)
    self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop)

  def forward(self,x,attn_mask):
    H,W = self.H, self.W
    B,L,C = x.shape
    assert L==H*W,"input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B,H,W,C)

    # pad the feature map to a multiple of window_size
    pad_l = pad_t = 0
    pad_r = (self.window_size-W%self.window_size)%self.window_size
    pad_b = (self.window_size-H%self.window_size)%self.window_size
    x = F.pad(x,(0,0,pad_l,pad_r,pad_t,pad_b))
    _,Hp,Wp,_ = x.shape

    if self.shift_size>0:
      shifted_x = torch.roll(x,shifts=(-self.shift_size,-self.shift_size),dims=(1,2))
    else:
      shifted_x = x
      attn_mask = None

    # partition windows
    x_windows = window_partition(shifted_x,self.window_size)
    x_windows = x_windows.view(-1,self.window_size*self.window_size,C)

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows,mask=attn_mask)

    attn_windows = attn_windows.view(-1,self.window_size,self.window_size,C)
    shifted_x = window_reverse(attn_windows,self.window_size,Hp,Wp)

    if self.shift_size>0:
      x = torch.roll(shifted_x,shifts=(self.shift_size,self.shift_size),dims=(1,2))
    else:
      x = shifted_x

    if pad_r>0 or pad_b>0:
      # remove the padding added earlier
      x = x[:,:H,:W,:].contiguous()

    x = x.view(B,H*W,C)

    # FFN
    x = shortcut+self.drop_path(x)
    x = x+self.drop_path(self.mlp(self.norm2(x)))

    return x


class BasicLayer(nn.Module):
  def __init__(self,dim,depth,num_heads,window_size,mlp_ratio=4.,
      qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
      norm_layer=nn.LayerNorm,downsample=None,use_checkpoint=False):
    super().__init__()
    self.dim = dim
    self.depth = depth
    self.window_size = window_size
    self.use_checkpoint = use_checkpoint
    self.shift_size = window_size//2

    self.blocks = nn.ModuleList([
        SwinTransformerBlock(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=0 if (i%2==0) else self.shift_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path[i] if isinstance(drop_path,list) else drop_path,
            norm_layer=norm_layer)
        for i in range(depth)])

    if downsample is not None:
      self.downsample = downsample(dim=dim,norm_layer=norm_layer)
    else:
      self.downsample = None

  def create_mask(self,x,H,W):
    # make sure Hp and Wp are multiples of window_size
    Hp = int(np.ceil(H/self.window_size))*self.window_size
    Wp = int(np.ceil(W/self.window_size))*self.window_size
    # same channel ordering as the feature map, convenient for the subsequent window_partition
    img_mask = torch.zeros((1,Hp,Wp,1),device=x.device)
    h_slices = (slice(0,-self.window_size),
          slice(-self.window_size,-self.shift_size),
          slice(-self.shift_size, None))
    w_slices = (slice(0,-self.window_size),
          slice(-self.window_size,-self.shift_size),
          slice(-self.shift_size,None))
    cnt = 0
    for h in h_slices:
      for w in w_slices:
        img_mask[:,h,w,:] = cnt
        cnt += 1

    mask_windows = window_partition(img_mask,self.window_size)
    mask_windows = mask_windows.view(-1,self.window_size*self.window_size)
    attn_mask = mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))
    return attn_mask

  def forward(self,x,H,W):
    attn_mask = self.create_mask(x,H,W)
    for blk in self.blocks:
      blk.H,blk.W = H,W
      if not torch.jit.is_scripting() and self.use_checkpoint:
        x = checkpoint.checkpoint(blk,x,attn_mask)
      else:
        x = blk(x,attn_mask)
    if self.downsample is not None:
      x = self.downsample(x,H,W)
      H,W = (H+1)//2,(W+1)//2

    return x,H,W


class SwinTransformer(nn.Module):
  def __init__(self,patch_size=4,in_chans=3,num_classes=1000,
      embed_dim=96,depths=(2,2,6,2),num_heads=(3,6,12,24),
      window_size=7,mlp_ratio=4.,qkv_bias=True,
      drop_rate=0.,attn_drop_rate=0.,drop_path_rate=0.1,
      norm_layer=nn.LayerNorm,patch_norm=True,
      use_checkpoint=False,**kwargs):
    super().__init__()

    self.num_classes = num_classes
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.patch_norm = patch_norm
    # number of channels of the feature map output by stage 4
    self.num_features = int(embed_dim*2**(self.num_layers-1))
    self.mlp_ratio = mlp_ratio

    # split the image into non-overlapping patches
    self.patch_embed = PatchEmbed(
      patch_size=patch_size,in_c=in_chans,embed_dim=embed_dim,
      norm_layer=norm_layer if self.patch_norm else None)
    self.pos_drop = nn.Dropout(p=drop_rate)

    dpr = [x.item() for x in torch.linspace(0,drop_path_rate,sum(depths))]

    # build layers
    self.layers = nn.ModuleList()
    for i_layer in range(self.num_layers):
      layers = BasicLayer(dim=int(embed_dim*2**i_layer),
                    depth=depths[i_layer],
                    num_heads=num_heads[i_layer],
                    window_size=window_size,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer+1])],
                    norm_layer=norm_layer,
                    downsample=PatchMerging if (i_layer<self.num_layers-1) else None,
                    use_checkpoint=use_checkpoint)
      self.layers.append(layers)

    self.norm = norm_layer(self.num_features)
    self.avgpool = nn.AdaptiveAvgPool1d(1)
    self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()

    self.apply(self._init_weights)

  def _init_weights(self,m):
    if isinstance(m,nn.Linear):
      nn.init.trunc_normal_(m.weight,std=.02)
      if isinstance(m,nn.Linear) and m.bias is not None:
        nn.init.constant_(m.bias,0)
    elif isinstance(m,nn.LayerNorm):
      nn.init.constant_(m.bias,0)
      nn.init.constant_(m.weight,1.0)

  def forward(self,x):
    x,H,W = self.patch_embed(x)
    x = self.pos_drop(x)

    for layer in self.layers:
      x,H,W = layer(x,H,W)

    x = self.norm(x)
    x = self.avgpool(x.transpose(1,2))
    x = torch.flatten(x,1)
    x = self.head(x)
    return x


def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,6,2),
              num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
  return model


def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,18,2),
              num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
  # trained ImageNet-1K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
              num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
  return model


def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=192,depths=(2,2,18,2),
              num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
  return model


def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
  # trained ImageNet-22K
  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
  model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=192,depths=(2,2,18,2),
              num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
  return model

Part 3: ConvNeXt

1 Network structure

1.1 Structures used

Macro design

① Stage compute ratio

Adjusting the stage stacking numbers of ResNet-50 from (3, 4, 6, 3) to (3, 3, 9, 3), consistent with Swin-T, improves accuracy significantly.

② tallo "parcheado"

Replacing the stem (the original downsampling module) with a convolutional layer with kernel size 4 and stride 4 slightly improves accuracy and also slightly reduces FLOPs.

ResNeXt-ify

Compared with ResNet, ResNeXt offers a better FLOPs/accuracy trade-off. Here the authors also adopt depthwise (DW) convolution. When the width of the input features is increased, accuracy improves considerably, and FLOPs increase as well.

Inverted bottleneck

The authors observe that the MLP module in the Transformer block is very similar to the inverted residual block, which is narrow at both ends and wide in the middle, so the bottleneck block is replaced with an inverted bottleneck; accuracy improves slightly and FLOPs also drop noticeably.

Large kernel size

The DW convolution is moved up: the ordering used to be 1×1 conv → DW conv → 1×1 conv, and is now DW conv → 1×1 conv → 1×1 conv; the DW kernel size is also increased from 3×3 to 7×7.
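
A minimal sketch of the two orderings (dim = 96 matches the first stage of ConvNeXt-T in the code below; depthwise convolutions use groups equal to their channel count):

import torch.nn as nn

dim = 96

# before: 1x1 conv -> 3x3 DW conv -> 1x1 conv (inverted bottleneck with DW in the middle)
before = nn.Sequential(
    nn.Conv2d(dim, 4 * dim, kernel_size=1),
    nn.Conv2d(4 * dim, 4 * dim, kernel_size=3, padding=1, groups=4 * dim),
    nn.Conv2d(4 * dim, dim, kernel_size=1),
)

# after: 7x7 DW conv -> 1x1 conv -> 1x1 conv (the ordering used by the Block class below)
after = nn.Sequential(
    nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim),
    nn.Conv2d(dim, 4 * dim, kernel_size=1),
    nn.Conv2d(4 * dim, dim, kernel_size=1),
)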

Various layer-level micro designs

ReLU is replaced with GELU and fewer activation functions are used; fewer normalization layers are used and BN is replaced with LN, which speeds up convergence and reduces overfitting; finally, a separate downsampling layer is used.

1.2 Network performance

Compared with a Swin Transformer of the same scale, ConvNeXt achieves higher accuracy and roughly 40% higher throughput (images per second).

1.3 Model variants

Here C denotes the channels of the input features at each stage, and B denotes the number of times the block is repeated in each stage:
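
As read from the factory functions in the code below:

  • ConvNeXt-T: C = (96, 192, 384, 768), B = (3, 3, 9, 3)
  • ConvNeXt-S: C = (96, 192, 384, 768), B = (3, 3, 27, 3)
  • ConvNeXt-B: C = (128, 256, 512, 1024), B = (3, 3, 27, 3)
  • ConvNeXt-L: C = (192, 384, 768, 1536), B = (3, 3, 27, 3)
  • ConvNeXt-XL: C = (256, 512, 1024, 2048), B = (3, 3, 27, 3)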

2 Building the network with PyTorch

# ConvNeXt

"""
original code from facebook research:
https://github.com/facebookresearch/ConvNeXt
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_path(x,drop_prob:float=0.,training:bool=False):
  if drop_prob==0. or not training:
    return x
  keep_prob = 1-drop_prob
  shape = (x.shape[0],)+(1,)*(x.ndim-1)  # work with diff dim
  random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
  random_tensor.floor_()
  output = x.div(keep_prob)*random_tensor
  return output


class DropPath(nn.Module):
  def __init__(self,drop_prob=None):
    super(DropPath,self).__init__()
    self.drop_prob = drop_prob

  def forward(self,x):
    return drop_path(x,self.drop_prob,self.training)


class LayerNorm(nn.Module):
  def __init__(self,normalized_shape,eps=1e-6,data_format="channels_last"):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(normalized_shape),requires_grad=True)
    self.bias = nn.Parameter(torch.zeros(normalized_shape),requires_grad=True)
    self.eps = eps
    self.data_format = data_format
    if self.data_format not in ["channels_last","channels_first"]:
      raise ValueError(f"not support data format '{self.data_format}'")
    self.normalized_shape = (normalized_shape,)

  def forward(self,x:torch.Tensor)->torch.Tensor:
    if self.data_format=="channels_last":
      return F.layer_norm(x,self.normalized_shape,self.weight,self.bias,self.eps)
    elif self.data_format=="channels_first":
      # [batch_size,channels,height,width]
      mean = x.mean(1,keepdim=True)
      var = (x-mean).pow(2).mean(1,keepdim=True)
      x = (x-mean)/torch.sqrt(var+self.eps)
      x = self.weight[:,None,None]*x+self.bias[:,None,None]
      return x


class Block(nn.Module):
  def __init__(self,dim,drop_rate=0.,layer_scale_init_value=1e-6):
    super().__init__()
    self.dwconv = nn.Conv2d(dim,dim,kernel_size=7,padding=3,groups=dim)  # depthwise (DW) convolution
    self.norm = LayerNorm(dim,eps=1e-6,data_format="channels_last")
    self.pwconv1 = nn.Linear(dim,4*dim)
    self.act = nn.GELU()
    self.pwconv2 = nn.Linear(4*dim,dim)
    self.gamma = nn.Parameter(layer_scale_init_value*torch.ones((dim,)),
                  requires_grad=True) if layer_scale_init_value>0 else None
    self.drop_path = DropPath(drop_rate) if drop_rate>0. else nn.Identity()

  def forward(self,x:torch.Tensor)->torch.Tensor:
    shortcut = x
    x = self.dwconv(x)
    x = x.permute(0,2,3,1)  # [N,C,H,W]->[N,H,W,C]
    x = self.norm(x)
    x = self.pwconv1(x)
    x = self.act(x)
    x = self.pwconv2(x)
    if self.gamma is not None:
      x = self.gamma*x
    x = x.permute(0,3,1,2)  # [N,H,W,C]->[N,C,H,W]

    x = shortcut+self.drop_path(x)
    return x


class ConvNeXt(nn.Module):
  def __init__(self,in_chans:int=3,num_classes:int=1000,depths:list=None,
              dims:list=None,drop_path_rate:float=0.,
              layer_scale_init_value:float=1e-6,head_init_scale:float=1.):
    super().__init__()
    self.downsample_layers = nn.ModuleList()
    stem = nn.Sequential(nn.Conv2d(in_chans,dims[0],kernel_size=4,stride=4),
              LayerNorm(dims[0],eps=1e-6,data_format="channels_first"))
    self.downsample_layers.append(stem)

    # the 3 downsampling layers before stage2-stage4
    for i in range(3):
      downsample_layer = nn.Sequential(LayerNorm(dims[i],eps=1e-6,data_format="channels_first"),
                      nn.Conv2d(dims[i],dims[i+1],kernel_size=2,stride=2))
      self.downsample_layers.append(downsample_layer)

    self.stages = nn.ModuleList()
    dp_rates = [x.item() for x in torch.linspace(0,drop_path_rate,sum(depths))]
    cur = 0
    # build the blocks stacked in each stage
    for i in range(4):
      stage = nn.Sequential(*[Block(dim=dims[i],drop_rate=dp_rates[cur+j],layer_scale_init_value=layer_scale_init_value)
            for j in range(depths[i])]
      )
      self.stages.append(stage)
      cur += depths[i]

    self.norm = nn.LayerNorm(dims[-1],eps=1e-6)
    self.head = nn.Linear(dims[-1],num_classes)
    self.apply(self._init_weights)
    self.head.weight.data.mul_(head_init_scale)
    self.head.bias.data.mul_(head_init_scale)

  def _init_weights(self,m):
    if isinstance(m,(nn.Conv2d,nn.Linear)):
      nn.init.trunc_normal_(m.weight,std=0.02)  # std=0.02, matching the official ConvNeXt weight init
      nn.init.constant_(m.bias,0)

  def forward_features(self,x:torch.Tensor)->torch.Tensor:
    for i in range(4):
      x = self.downsample_layers[i](x)
      x = self.stages[i](x)

    return self.norm(x.mean([-2,-1]))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.forward_features(x)
    x = self.head(x)
    return x


def convnext_tiny(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
  model = ConvNeXt(depths=[3,3,9,3],dims=[96,192,384,768],num_classes=num_classes)
  return model


def convnext_small(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[96,192,384,768],num_classes=num_classes)
  return model


def convnext_base(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
  # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[128,256,512,1024],num_classes=num_classes)
  return model


def convnext_large(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
  # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[192,384,768,1536],num_classes=num_classes)
  return model


def convnext_xlarge(num_classes: int):
  # https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
  model = ConvNeXt(depths=[3,3,27,3],dims=[256,512,1024,2048],num_classes=num_classes)
  return model

Part 4: Personal reflections

Natural language processing and computer vision share some similarities, which is why the Transformer is currently a popular research direction, with many aspects that can still be optimized and adapted. CNNs, however, are more mature and have well-established building blocks; when needed, multiple structures can be combined to achieve better results.

Source: blog.csdn.net/qq_55708326/article/details/126352703