Vision Transformer Part 1
1 Network structure
The Transformer model is not only suited to NLP; the ViT model shows it can also achieve good results in computer vision (CV).
In the original paper, the authors compared three models: ViT, a "pure" Transformer model; the ResNet network; and a Hybrid model, which combines a traditional CNN with a Transformer. They found that when the number of training iterations is large, the accuracy of the ViT model exceeds that of the Hybrid model.
The architecture of the ViT (Vision Transformer) model is as follows:
The model first splits the image into patches, each of size 16*16. Each patch is fed through the Embedding layer, producing a vector called a token. A token used for classification (the [class] token) is then prepended, a Position Embedding is added to mark each token's position, and the tokens with position information are fed into the Transformer Encoder. The final classification result is obtained through the MLP Head.
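As a quick sanity check on the shapes involved, here is the token arithmetic for the standard 224*224 input with 16*16 patches (a sketch; the numbers follow directly from the description above):

```python
# Token arithmetic for ViT-Base on a 224x224 input with 16x16 patches.
img_size, patch_size, embed_dim = 224, 16, 768

num_patches = (img_size // patch_size) ** 2   # 14 * 14 = 196 patch tokens
seq_len = num_patches + 1                     # +1 for the [class] token -> 197

print(num_patches, seq_len)  # 196 197
```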
1.1 Linear Projection of Flattened Patches (Embedding layer)
This can be implemented directly with a convolutional layer. The output is the token sequence, i.e. a two-dimensional matrix [num_token, token_dim]; the [class] token is then concatenated (a cat operation) and the position embedding is added element-wise.
Experiments show that omitting the position embedding significantly reduces accuracy, but which type of position embedding is used has little effect; since the differences between position encodings are unimportant, the source code uses the default one-dimensional position encoding, which has fewer parameters.
The similarity between the learned position encodings is shown below; each row and each column exhibits high similarity:
1.2 Transformer Encoder
The layer structure and the MLP structure are as follows:
The embedded patches first pass through Layer Norm, then multi-head attention and Dropout, with a residual connection; then another Layer Norm, the MLP, and Dropout, with a second residual connection. This forms the Encoder Block, which is stacked L times.
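The pre-norm residual structure just described can be sketched functionally, with placeholder callables standing in for the real LayerNorm/attention/MLP modules (a structural sketch, not the actual implementation):

```python
def encoder_block(x, norm1, attn, norm2, mlp):
    # one Encoder Block: x = x + MSA(LN(x)); x = x + MLP(LN(x))
    x = x + attn(norm1(x))
    x = x + mlp(norm2(x))
    return x

# With identity placeholders, each residual add doubles the value:
identity = lambda v: v
print(encoder_block(1.0, identity, identity, identity, identity))  # 4.0
```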
1.3 MLP Head (final classification layers)
When training on ImageNet-21k or larger datasets, it consists of Linear + tanh activation + Linear. When transferring to ImageNet-1k or your own dataset, there is only a single Linear layer.
1.4 ViT variants
There are three variants, namely Base, Large, and Huge, with the following specifications:
- Layers: the number of times the Encoder Block is stacked in the Transformer Encoder
- Hidden Size: the length of the vector each token has after passing through the Embedding layer
- MLP Size: the number of nodes in the first fully connected layer of the MLP module, which is 4 times the Hidden Size
- Heads: the number of heads in multi-head attention
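The specifications from the ViT paper can be collected in a small table; note that MLP Size = 4 × Hidden Size holds for every variant:

```python
# ViT variant specifications from the original paper.
vit_configs = {
    "Base":  {"layers": 12, "hidden": 768,  "mlp": 3072, "heads": 12},
    "Large": {"layers": 24, "hidden": 1024, "mlp": 4096, "heads": 16},
    "Huge":  {"layers": 32, "hidden": 1280, "mlp": 5120, "heads": 16},
}

for name, cfg in vit_configs.items():
    assert cfg["mlp"] == 4 * cfg["hidden"]  # MLP Size is 4x Hidden Size
```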
2 Building the network with PyTorch
The code comes from the official implementation.
Reference link: ViT
Code link: (Colab) ViT
# Vision Transformer
"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
# stochastic depth (drop path)
def drop_path(x,drop_prob:float=0.,training:bool=False):
if drop_prob==0. or not training:
return x
keep_prob = 1-drop_prob
shape = (x.shape[0],)+(1,)*(x.ndim-1) # work with diff dim tensors
random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob)*random_tensor
return output
class DropPath(nn.Module):
def __init__(self,drop_prob=None):
super(DropPath,self).__init__()
self.drop_prob = drop_prob
def forward(self,x):
return drop_path(x,self.drop_prob,self.training)
# Patch Embedding
class PatchEmbed(nn.Module):
def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
super().__init__()
img_size = (img_size,img_size)
patch_size = (patch_size,patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0]//patch_size[0],img_size[1]//patch_size[1])
self.num_patches = self.grid_size[0]*self.grid_size[1]
self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B,C,H,W = x.shape
assert H==self.img_size[0] and W==self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# flatten:[B,C,H,W]->[B,C,HW]
# transpose:[B,C,HW]->[B,HW,C]
x = self.proj(x).flatten(2).transpose(1,2)
x = self.norm(x)
return x
class Attention(nn.Module):
def __init__(self,dim,num_heads=8,qkv_bias=False,qk_scale=None,attn_drop_ratio=0.,proj_drop_ratio=0.):
super(Attention,self).__init__()
self.num_heads = num_heads
head_dim = dim//num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.proj = nn.Linear(dim,dim)
self.proj_drop = nn.Dropout(proj_drop_ratio)
def forward(self, x):
B,N,C = x.shape
        # reshape and permute: [B,N,3C] -> [3,B,num_heads,N,head_dim], so q/k/v can be sliced out
qkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
q,k,v = qkv[0],qkv[1],qkv[2]
        # attention scores: scaled dot-product between queries and keys
        attn = (q@k.transpose(-2,-1))*self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn@v).transpose(1,2).reshape(B,N,C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Module):
def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features,hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features,out_features)
self.drop = nn.Dropout(drop)
def forward(self,x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_ratio=0.,
attn_drop_ratio=0.,
drop_path_ratio=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super(Block,self).__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio>0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim*mlp_ratio)
self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop_ratio)
def forward(self,x):
x = x+self.drop_path(self.attn(self.norm1(x)))
x = x+self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformer(nn.Module):
def __init__(self,img_size=224,patch_size=16,in_c=3,num_classes=1000,
embed_dim=768,depth=12,num_heads=12,mlp_ratio=4.0,qkv_bias=True,
qk_scale=None,representation_size=None,distilled=False,drop_ratio=0.,
attn_drop_ratio=0.,drop_path_ratio=0.,embed_layer=PatchEmbed,
norm_layer=None,act_layer=None):
super(VisionTransformer,self).__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim=embed_dim
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
act_layer = act_layer or nn.GELU
self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
self.blocks = nn.Sequential(*[
Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,
qk_scale=qk_scale,drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,
drop_path_ratio=dpr[i],norm_layer=norm_layer,act_layer=act_layer)
for i in range(depth)
])
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.has_logits = True
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
("fc",nn.Linear(embed_dim,representation_size)),
("act",nn.Tanh())
]))
else:
self.has_logits = False
self.pre_logits = nn.Identity()
self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim,self.num_classes) if num_classes>0 else nn.Identity()
nn.init.trunc_normal_(self.pos_embed,std=0.02)
if self.dist_token is not None:
nn.init.trunc_normal_(self.dist_token,std=0.02)
nn.init.trunc_normal_(self.cls_token,std=0.02)
self.apply(_init_vit_weights)
def forward_features(self, x):
# [B,C,H,W]->[B,num_patches,embed_dim]
x = self.patch_embed(x) # [B,196,768]
# [1,1,768]->[B,1,768]
cls_token = self.cls_token.expand(x.shape[0],-1,-1)
if self.dist_token is None:
x = torch.cat((cls_token,x),dim=1) # [B,197,768]
else:
x = torch.cat((cls_token,self.dist_token.expand(x.shape[0],-1,-1),x),dim=1)
x = self.pos_drop(x+self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:,0])
else:
return x[:,0], x[:,1]
def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None:
x, x_dist = self.head(x[0]),self.head_dist(x[1])
if self.training and not torch.jit.is_scripting():
return x,x_dist
else:
return (x+x_dist)/2
else:
x = self.head(x)
return x
def _init_vit_weights(m):
if isinstance(m,nn.Linear):
nn.init.trunc_normal_(m.weight,std=.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m,nn.Conv2d):
nn.init.kaiming_normal_(m.weight,mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m,nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
def vit_base_patch16_224(num_classes:int=1000):
model = VisionTransformer(img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=None,
num_classes=num_classes)
return model
def vit_base_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
model = VisionTransformer(img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=768 if has_logits else None,
num_classes=num_classes)
return model
def vit_base_patch32_224(num_classes:int=1000):
model = VisionTransformer(img_size=224,
patch_size=32,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=None,
num_classes=num_classes)
return model
def vit_base_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
model = VisionTransformer(img_size=224,
patch_size=32,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=768 if has_logits else None,
num_classes=num_classes)
return model
def vit_large_patch16_224(num_classes:int=1000):
model = VisionTransformer(img_size=224,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
representation_size=None,
num_classes=num_classes)
return model
def vit_large_patch16_224_in21k(num_classes:int=21843,has_logits:bool=True):
model = VisionTransformer(img_size=224,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
representation_size=1024 if has_logits else None,
num_classes=num_classes)
return model
def vit_large_patch32_224_in21k(num_classes:int=21843,has_logits:bool=True):
model = VisionTransformer(img_size=224,
patch_size=32,
embed_dim=1024,
depth=24,
num_heads=16,
representation_size=1024 if has_logits else None,
num_classes=num_classes)
return model
def vit_huge_patch14_224_in21k(num_classes:int=21843,has_logits:bool=True):
model = VisionTransformer(img_size=224,
patch_size=14,
embed_dim=1280,
depth=32,
num_heads=16,
representation_size=1280 if has_logits else None,
num_classes=num_classes)
return model
Swin Transformer Part 2
1 Network structure
1.1 Overall framework
Compared with ViT, the Swin Transformer is hierarchical: as the network gets deeper, the downsampling factor keeps increasing. It partitions the feature map into non-overlapping windows and performs multi-head self-attention within each window, which greatly reduces the amount of computation.
The overall framework of the Swin Transformer network is as follows:
For a three-channel image, a Patch Partition operation is performed first, followed by downsampling through 4 stages. Each stage doubles the downsampling factor, and every time the resolution is reduced by 2x the number of channels doubles accordingly. Stage 1 begins with Linear Embedding, while every other stage begins with Patch Merging. Patch Partition splits the image with a 4 x 4 window and then flattens each patch; Linear Embedding adjusts the dimension and applies Layer Norm over the channels. Both structures can be implemented with a convolutional layer.
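For a 224x224 input with Swin-T's embedding dimension of 96, the resolution/channel progression across the four stages works out as follows (a sketch of the arithmetic in the paragraph above):

```python
img_size, patch_size, embed_dim = 224, 4, 96

stages = []
res, ch = img_size // patch_size, embed_dim   # after Patch Partition + Linear Embedding
for stage in range(4):
    stages.append((res, ch))
    res, ch = res // 2, ch * 2                # each Patch Merging halves H,W and doubles C

print(stages)  # [(56, 96), (28, 192), (14, 384), (7, 768)]
```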
1.2 Patch Merging
Patch Merging performs a downsampling operation that halves the height and width of the feature map and doubles the number of channels:
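A minimal sketch on a toy single-channel grid: the four 2x2-strided slices are gathered and concatenated along the channel axis, so H and W halve while C quadruples (the Linear layer in the real module then reduces 4C back to 2C):

```python
# Toy Patch Merging on a 4x4 single-channel feature map (nested lists).
fm = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 grid of values 0..15

merged = [
    [
        # concatenate the four strided-slice positions of each 2x2 block -> 4 "channels"
        [fm[2*i][2*j], fm[2*i+1][2*j], fm[2*i][2*j+1], fm[2*i+1][2*j+1]]
        for j in range(2)
    ]
    for i in range(2)
]
print(merged[0][0])  # [0, 4, 1, 5] -- the top-left 2x2 block, now one "pixel" with 4 channels
```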
1.3 W-MSA
W-MSA stands for Windows Multi-head Self-Attention. Compared with the ordinary multi-head self-attention module, it partitions the feature map into non-overlapping windows and performs multi-head self-attention within each window, which reduces the amount of computation; at the same time, however, it prevents information exchange between windows, making the receptive field smaller.
The computational cost of the two is as follows, where h and w are the height and width of the feature map, c its depth, and m the size of each window:

Ω(MSA) = 4hwc² + 2(hw)²c
Ω(W-MSA) = 4hwc² + 2m²hwc
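Using the complexity formulas from the Swin paper, Ω(MSA) = 4hwc² + 2(hw)²c and Ω(W-MSA) = 4hwc² + 2m²hwc, plugging in typical Stage-1 numbers (h = w = 56, c = 96, m = 7, illustrative values) shows how much windowing saves:

```python
h, w, c, m = 56, 56, 96, 7

msa   = 4 * h * w * c**2 + 2 * (h * w) ** 2 * c   # quadratic in h*w
w_msa = 4 * h * w * c**2 + 2 * m**2 * h * w * c   # linear in h*w for fixed window m

print(f"MSA:   {msa:,} FLOPs")
print(f"W-MSA: {w_msa:,} FLOPs")
print(f"ratio: {msa / w_msa:.1f}x")
```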
1.4 SW-MSA
SW-MSA stands for Shifted Window Multi-head Self-Attention; the schematic is shown below. Building on W-MSA, the windows are shifted by a certain offset, enabling information exchange between different windows:
1.5 Relative Position Bias
The formula involved is as follows, where B is the relative position bias:

Attention(Q, K, V) = SoftMax(QKᵀ/√d + B)V

The schematic of the relative position bias is as follows:
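For a window of size M x M, the relative coordinate along each axis ranges over [-(M-1), M-1], so the learnable bias table needs (2M-1)² entries per head, indexed for all (M²)² query-key pairs inside a window. A quick check for M = 7, the window size Swin uses:

```python
M = 7  # window size

num_relative_positions = (2 * M - 1) ** 2   # distinct (dy, dx) offsets -> bias table rows
num_pairs = (M * M) ** 2                    # query-key pairs inside one window

print(num_relative_positions)  # 169
print(num_pairs)               # 2401
```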
1.6 Configuration parameters
2 Building the network with PyTorch
# Swin Transformer
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
- https://arxiv.org/pdf/2103.14030
Code/weights from https://github.com/microsoft/Swin-Transformer
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional
def drop_path_f(x,drop_prob:float=0.,training:bool=False):
if drop_prob==0. or not training:
return x
keep_prob = 1-drop_prob
shape = (x.shape[0],)+(1,)*(x.ndim-1) # work with diff dim
random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob)*random_tensor
return output
class DropPath(nn.Module):
def __init__(self,drop_prob=None):
super(DropPath,self).__init__()
self.drop_prob = drop_prob
def forward(self,x):
return drop_path_f(x,self.drop_prob,self.training)
def window_partition(x,window_size:int):
    # partition the feature map into non-overlapping windows of size window_size
B,H,W,C = x.shape
x = x.view(B,H//window_size,window_size,W//window_size,window_size,C)
windows = x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)
return windows
def window_reverse(windows,window_size:int,H:int,W:int):
    # reassemble the windows back into a feature map
B = int(windows.shape[0]/(H*W/window_size/window_size))
x = windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)
x = x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)
return x
class PatchEmbed(nn.Module):
def __init__(self,patch_size=4,in_c=3,embed_dim=96,norm_layer=None):
super().__init__()
patch_size = (patch_size,patch_size)
self.patch_size = patch_size
self.in_chans = in_c
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_,_,H,W = x.shape
# padding
pad_input = (H%self.patch_size[0]!=0) or (W%self.patch_size[1]!=0)
if pad_input:
x = F.pad(x,(0,self.patch_size[1]-W%self.patch_size[1],
0,self.patch_size[0]-H%self.patch_size[0],0,0))
        # downsample by a factor of patch_size
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1,2)
x = self.norm(x)
return x,H,W
class PatchMerging(nn.Module):
def __init__(self,dim,norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4*dim,2*dim,bias=False)
self.norm = norm_layer(4*dim)
def forward(self,x,H,W):
B,L,C = x.shape
assert L==H*W,"input feature has wrong size"
x = x.view(B,H,W,C)
# padding
pad_input = (H%2==1) or (W%2==1)
if pad_input:
x = F.pad(x,(0,0,0,W%2,0,H%2))
x0 = x[:,0::2,0::2,:] # [B,H/2,W/2,C]
x1 = x[:,1::2,0::2,:] # [B,H/2,W/2,C]
x2 = x[:,0::2,1::2,:] # [B,H/2,W/2,C]
x3 = x[:,1::2,1::2,:] # [B,H/2,W/2,C]
x = torch.cat([x0,x1,x2,x3],-1) # [B,H/2,W/2,4*C]
x = x.view(B,-1,4*C) # [B,H/2*W/2,4*C]
x = self.norm(x)
x = self.reduction(x) # [B,H/2*W/2,2*C]
return x
class Mlp(nn.Module):
def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features,hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop)
self.fc2 = nn.Linear(hidden_features,out_features)
self.drop2 = nn.Dropout(drop)
def forward(self,x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class WindowAttention(nn.Module):
def __init__(self,dim,window_size,num_heads,qkv_bias=True,attn_drop=0.,proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim//num_heads
self.scale = head_dim**-0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1),num_heads))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h,coords_w],indexing="ij"))
coords_flatten = torch.flatten(coords,1)
relative_coords = coords_flatten[:,:,None]-coords_flatten[:,None,:]
relative_coords = relative_coords.permute(1,2,0).contiguous()
relative_coords[:,:,0] += self.window_size[0]-1
relative_coords[:,:,1] += self.window_size[1]-1
relative_coords[:,:,0] *= 2*self.window_size[1]-1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index",relative_position_index)
self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim,dim)
self.proj_drop = nn.Dropout(proj_drop)
nn.init.trunc_normal_(self.relative_position_bias_table,std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self,x,mask:Optional[torch.Tensor]=None):
B_,N,C = x.shape
qkv = self.qkv(x).reshape(B_,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)
q,k,v = qkv.unbind(0)
q = q*self.scale
        attn = (q@k.transpose(-2,-1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0]*self.window_size[1],self.window_size[0]*self.window_size[1],-1)
relative_position_bias = relative_position_bias.permute(2,0,1).contiguous()
attn = attn+relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_//nW,nW,self.num_heads,N,N)+mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1,self.num_heads,N,N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn@v).transpose(1,2).reshape(B_,N,C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
def __init__(self,dim,num_heads,window_size=7,shift_size=0,
mlp_ratio=4.,qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
act_layer=nn.GELU,norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0<=self.shift_size<self.window_size,"shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,window_size=(self.window_size,self.window_size),num_heads=num_heads,
qkv_bias=qkv_bias,attn_drop=attn_drop,proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim*mlp_ratio)
self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop=drop)
def forward(self,x,attn_mask):
H,W = self.H, self.W
B,L,C = x.shape
assert L==H*W,"input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B,H,W,C)
        # pad the feature map to a multiple of window_size
pad_l = pad_t = 0
pad_r = (self.window_size-W%self.window_size)%self.window_size
pad_b = (self.window_size-H%self.window_size)%self.window_size
x = F.pad(x,(0,0,pad_l,pad_r,pad_t,pad_b))
_,Hp,Wp,_ = x.shape
if self.shift_size>0:
shifted_x = torch.roll(x,shifts=(-self.shift_size,-self.shift_size),dims=(1,2))
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x,self.window_size)
x_windows = x_windows.view(-1,self.window_size*self.window_size,C)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows,mask=attn_mask)
attn_windows = attn_windows.view(-1,self.window_size,self.window_size,C)
shifted_x = window_reverse(attn_windows,self.window_size,Hp,Wp)
if self.shift_size>0:
x = torch.roll(shifted_x,shifts=(self.shift_size,self.shift_size),dims=(1,2))
else:
x = shifted_x
if pad_r>0 or pad_b>0:
            # remove the padding added earlier
x = x[:,:H,:W,:].contiguous()
x = x.view(B,H*W,C)
# FFN
x = shortcut+self.drop_path(x)
x = x+self.drop_path(self.mlp(self.norm2(x)))
return x
class BasicLayer(nn.Module):
def __init__(self,dim,depth,num_heads,window_size,mlp_ratio=4.,
qkv_bias=True,drop=0.,attn_drop=0.,drop_path=0.,
norm_layer=nn.LayerNorm,downsample=None,use_checkpoint=False):
super().__init__()
self.dim = dim
self.depth = depth
self.window_size = window_size
self.use_checkpoint = use_checkpoint
self.shift_size = window_size//2
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i%2==0) else self.shift_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path,list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
if downsample is not None:
self.downsample = downsample(dim=dim,norm_layer=norm_layer)
else:
self.downsample = None
def create_mask(self,x,H,W):
        # ensure Hp and Wp are multiples of window_size
Hp = int(np.ceil(H/self.window_size))*self.window_size
Wp = int(np.ceil(W/self.window_size))*self.window_size
        # same channel layout as the feature map, so window_partition can be reused below
img_mask = torch.zeros((1,Hp,Wp,1),device=x.device)
h_slices = (slice(0,-self.window_size),
slice(-self.window_size,-self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0,-self.window_size),
slice(-self.window_size,-self.shift_size),
slice(-self.shift_size,None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:,h,w,:] = cnt
cnt += 1
mask_windows = window_partition(img_mask,self.window_size)
mask_windows = mask_windows.view(-1,self.window_size*self.window_size)
attn_mask = mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))
return attn_mask
def forward(self,x,H,W):
attn_mask = self.create_mask(x,H,W)
for blk in self.blocks:
blk.H,blk.W = H,W
if not torch.jit.is_scripting() and self.use_checkpoint:
x = checkpoint.checkpoint(blk,x,attn_mask)
else:
x = blk(x,attn_mask)
if self.downsample is not None:
x = self.downsample(x,H,W)
H,W = (H+1)//2,(W+1)//2
return x,H,W
class SwinTransformer(nn.Module):
def __init__(self,patch_size=4,in_chans=3,num_classes=1000,
embed_dim=96,depths=(2,2,6,2),num_heads=(3,6,12,24),
window_size=7,mlp_ratio=4.,qkv_bias=True,
drop_rate=0.,attn_drop_rate=0.,drop_path_rate=0.1,
norm_layer=nn.LayerNorm,patch_norm=True,
use_checkpoint=False,**kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
        # channels of the feature matrix output by stage 4
self.num_features = int(embed_dim*2**(self.num_layers-1))
self.mlp_ratio = mlp_ratio
        # split into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,in_c=in_chans,embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0,drop_path_rate,sum(depths))]
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layers = BasicLayer(dim=int(embed_dim*2**i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer+1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer<self.num_layers-1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layers)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features,num_classes) if num_classes>0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self,m):
if isinstance(m,nn.Linear):
nn.init.trunc_normal_(m.weight,std=.02)
if isinstance(m,nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias,0)
elif isinstance(m,nn.LayerNorm):
nn.init.constant_(m.bias,0)
nn.init.constant_(m.weight,1.0)
def forward(self,x):
x,H,W = self.patch_embed(x)
x = self.pos_drop(x)
for layer in self.layers:
x,H,W = layer(x,H,W)
x = self.norm(x)
x = self.avgpool(x.transpose(1,2))
x = torch.flatten(x,1)
x = self.head(x)
return x
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,6,2),
num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
return model
def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2,2,18,2),
num_heads=(3,6,12,24),num_classes=num_classes,**kwargs)
return model
def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
return model
def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
return model
def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=128,depths=(2,2,18,2),
num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
return model
def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=128,depths=(2,2,18,2),
num_heads=(4,8,16,32),num_classes=num_classes,**kwargs)
return model
def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=192,depths=(2,2,18,2),
num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
return model
def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
model = SwinTransformer(in_chans=3,patch_size=4,window_size=12,embed_dim=192,depths=(2,2,18,2),
num_heads=(6,12,24,48),num_classes=num_classes,**kwargs)
return model
Part 3 ConvNeXt
1 Network structure
1.1 Structures used
Macro design
① Stage ratio
Adjusting ResNet-50's stage stacking counts from (3,4,6,3) to (3,3,9,3), which is consistent with Swin-T, improves accuracy significantly.
② "Patchify" stem
Replacing the stem (the original downsampling module) with a convolutional layer with kernel size 4 and stride 4 improves accuracy slightly and also reduces FLOPs slightly.
ResNeXt-ify
Compared with ResNet, ResNeXt has a better FLOPs/accuracy trade-off. Here the authors also use depthwise (DW) convolution. Increasing the width of the input features improves accuracy considerably, although FLOPs increase as well.
Inverted bottleneck
The authors note that the MLP module in the Transformer block is very similar to the inverted residual module, narrow at both ends and wide in the middle, so the bottleneck block is replaced with an inverted residual module; accuracy improves slightly and FLOPs also drop noticeably.
Large kernel size
The DW convolution is moved up: the order changes from 1*1 conv -> DW conv -> 1*1 conv to DW conv -> 1*1 conv -> 1*1 conv, and the DW kernel size grows from 3*3 to 7*7.
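Moving the spatial mixing into a depthwise convolution is what keeps the 7*7 kernel cheap: a depthwise 7*7 layer has kernel parameters proportional to c, versus c² for a dense 7*7 convolution (bias terms ignored in this sketch):

```python
c, k = 96, 7  # channels and kernel size (illustrative values)

dense_params = k * k * c * c   # standard conv: every output channel sees every input channel
dw_params    = k * k * c       # depthwise conv: one 7x7 filter per channel

print(dense_params, dw_params, dense_params // dw_params)
```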
Various layer-wise micro designs
ReLU is replaced with GELU and fewer activation functions are used; BN is used less often and replaced with LN, which speeds up convergence and reduces overfitting; finally, a separate downsampling layer is used.
1.2 Network performance
Compared with a Swin Transformer of the same scale, ConvNeXt achieves higher accuracy with roughly 40% higher throughput (images per second).
1.3 Variants
Here C denotes the number of channels of the input features of each stage and B the number of Block repetitions per stage:
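The variants listed in the ConvNeXt paper, in the (C, B) notation above:

```python
# ConvNeXt variants: C = channels per stage, B = blocks per stage.
convnext_configs = {
    "T":  {"C": (96, 192, 384, 768),    "B": (3, 3, 9, 3)},
    "S":  {"C": (96, 192, 384, 768),    "B": (3, 3, 27, 3)},
    "B":  {"C": (128, 256, 512, 1024),  "B": (3, 3, 27, 3)},
    "L":  {"C": (192, 384, 768, 1536),  "B": (3, 3, 27, 3)},
    "XL": {"C": (256, 512, 1024, 2048), "B": (3, 3, 27, 3)},
}

for cfg in convnext_configs.values():
    # channels double at every stage in all variants
    assert all(cfg["C"][i + 1] == 2 * cfg["C"][i] for i in range(3))
```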
2 Building the network with PyTorch
# ConvNeXt
"""
original code from facebook research:
https://github.com/facebookresearch/ConvNeXt
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def drop_path(x,drop_prob:float=0.,training:bool=False):
if drop_prob==0. or not training:
return x
keep_prob = 1-drop_prob
shape = (x.shape[0],)+(1,)*(x.ndim-1) # work with diff dim
random_tensor = keep_prob+torch.rand(shape,dtype=x.dtype,device=x.device)
random_tensor.floor_()
output = x.div(keep_prob)*random_tensor
return output
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise ValueError(f"not support data format '{self.data_format}'")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # [batch_size, channels, height, width]
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            x = (x - mean) / torch.sqrt(var + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
class Block(nn.Module):
    def __init__(self, dim, drop_rate=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # DW conv
        self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_last")
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise conv, implemented as Linear
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim,)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # [N,C,H,W] -> [N,H,W,C]
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x  # layer scale
        x = x.permute(0, 3, 1, 2)  # [N,H,W,C] -> [N,C,H,W]
        x = shortcut + self.drop_path(x)
        return x
class ConvNeXt(nn.Module):
    def __init__(self, in_chans: int = 3, num_classes: int = 1000, depths: list = None,
                 dims: list = None, drop_path_rate: float = 0.,
                 layer_scale_init_value: float = 1e-6, head_init_scale: float = 1.):
        super().__init__()
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                             LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))
        self.downsample_layers.append(stem)
        # the 3 downsample layers that precede stage2-stage4
        for i in range(3):
            downsample_layer = nn.Sequential(LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                                             nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2))
            self.downsample_layers.append(downsample_layer)
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        # build the blocks stacked in each stage
        for i in range(4):
            stage = nn.Sequential(*[Block(dim=dims[i], drop_rate=dp_rates[cur + j],
                                          layer_scale_init_value=layer_scale_init_value)
                                    for j in range(depths[i])])
            self.stages.append(stage)
            cur += depths[i]
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)  # std=0.02 as in the original repo
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling over H, W

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.head(x)
        return x
def convnext_tiny(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], num_classes=num_classes)
    return model

def convnext_small(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], num_classes=num_classes)
    return model

def convnext_base(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], num_classes=num_classes)
    return model

def convnext_large(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], num_classes=num_classes)
    return model

def convnext_xlarge(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], num_classes=num_classes)
    return model
Personal experience part 4
Natural language processing and computer vision share some similarities, which is why the Transformer is currently a popular research direction with many components still open to optimization and adaptation. CNNs, however, are more mature and come with more established building blocks; when needed, multiple structures can be combined to achieve better results.