Registro de aprendizaje de PVT (transformador de visión piramidal)

Introducción e inspiración

Desde ViT, la investigación sobre el transformador de visión se ha disparado a lo grande. Desde la perspectiva del pensamiento, sigue principalmente dos direcciones. Una es mejorar el efecto de ViT en la clasificación de imágenes; la otra es aplicar ViT a otras tareas de imagen. como la segmentación y En términos de tareas de detección, el PVT (Pyramid Vision Transformer) presentado aquí pertenece a este último. En comparación con ViT, PVT presenta una estructura piramidal similar a CNN, por lo que PVT se utiliza como columna vertebral en tareas de predicción densas (segmentación y detección, etc.) como CNN.

inserte la descripción de la imagen aquí

Ideas de diseño

La idea de diseño de PVT es que para obtener funciones de escala múltiple en CNN, FPN (Feature Pyramid Network) se puede usar para obtener funciones de escala múltiple, por lo que Transformer puede hacer lo mismo, por lo que PVT (Pyramid Vision Transformer ) Se propone el módulo combinado con modelos similares a DETR para lograr una detección de extremo a extremo.
De hecho, la idea de PVT es muy simple, que es combinar Transformer y FPN, y reducir el mapa de características a través de la convolución para reducir los cálculos.
inserte la descripción de la imagen aquí
Combinado con la figura anterior, podemos ver que sus principales novedades son:

  1. En comparación con ViT, utiliza parches de imagen de grano fino y puede aprender representaciones de características de alta resolución.
  2. Basándose en CNN, una estructura piramidal está diseñada para aprender características de múltiples escalas
  3. Se presenta el módulo SRA, que se utiliza principalmente para mejorar el módulo de atención multicabezal para reducir la cantidad de cálculo de QKV

proceso modelo

Los parámetros de inicialización del modelo, que se explicarán más adelante.

def pvt_small():
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])

Primero, nuestra imagen de entrada es torch.Size([2, 3, 224, 224]), es decir, tamaño de lote = 2, canal = 3, W = H = 224 y luego se envía a la etapa 1
:

        x, (H, W) = self.patch_embed1(x)
        pos_embed1 = self._get_pos_embed(self.pos_embed1, self.patch_embed1, H, W)
        x = x + pos_embed1
        x = self.pos_drop1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

operación de parche

En primer lugar, la operación de parche se realiza en la imagen de entrada, lo que completa la segmentación de la imagen y la
etapa de mapeo lineal (conversión de dimensiones) 1. PatchEmbed se define como:

PatchEmbed(
  (proj): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
  def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]
        return x, (H, W)

Se puede ver que el vector enviado al módulo patch_embed primero obtiene sus parámetros (B, C, W, H)
y luego llama self.proja la operación

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

Es una convolución bidimensional:Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))

Tamaño de salida de la capa de convolución: o = ⌊(i + 2p - k) / s⌋ + 1

El relleno predeterminado es 0, por lo que el tamaño de salida es [2, 64, 56, 56]
y luego se realiza la operación de aplanar para obtener [2, 64, 56X56],
y luego la transposición realiza la conversión de dimensión a: [2, 56X56, 64], que es torch.Size([2, 3136, 64])
entonces operación de Normalización y cambio y retorno de tamaño W y H, en este momento W=H=56

código de localización

En términos de codificación de posición, utiliza un método de codificación de posición aprendible.
pos_embed se inicializa con el siguiente método y el tamaño en este momento es:torch.Size([1, 3136, 64])

self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))

Luego proceda con:

pos_embed1 = self._get_pos_embed(self.pos_embed1, self.patch_embed1, H, W)

El proceso de procesamiento es:

def _get_pos_embed(self, pos_embed, patch_embed, H, W):
    if H * W == self.patch_embed1.num_patches:
        return pos_embed
    else:
        return F.interpolate(
            pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
            size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

Luego agregue directamente la información de características semánticas procesadas y la información de codificación de posición, tenga en cuenta:
pos_embed1 en este momento es: torch.Size([1, 3136, 64])x es torch.Size([2, 3136, 64]), y también se puede agregar

x = x + pos_embed1

La siguiente prueba utiliza el mecanismo de transmisión para expandir la dimensión para el cálculo

import torch
data = torch.randn((2, 1, 2, 2))
data1 = torch.randn((1, 1, 2, 2))
print(data)
print(data1)
print(data+data1)

inserte la descripción de la imagen aquí

calculo de atencion

Ingrese al módulo de cálculo de atención, hay múltiples capas de atención en cada etapa

for blk in self.block1:
    x = blk(x, H, W)

Luego comience el cálculo de la atención:
Construcción Q: después de una construcción lineal y cálculos separados (atención de múltiples cabezas)

self.q = nn.Linear(dim, dim, bias=qkv_bias)
 q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

Obtenga q como: torch.Size([2, 1, 3136, 64])
luego ingrese su módulo innovador, que es reducir la muestra de K y V para reducir su número: self.sr_ratio es reducir el tamaño
x a torch.Size ([2, 3136, 64]), primero Después de permutar, la dimensión se transforma en torch.Size([2, 64,3136]), y luego la remodelación es: torch.Size([2, 64, 56, 56])

x_ = x.permute(0, 2, 1).reshape(B, C, H, W)

Luego, la reducción de la dimensionalidad se realiza mediante operaciones de convolución:

self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)

self.sr es: Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
a través del cual 卷积层输出尺寸: o = ⌊(i + 2p - k) / s⌋ + 1
se puede : 7 o 49 El tensor correspondiente es:torch.Size([2, 49, 64])

Luego usa x_ para obtener kv

self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

El kv obtenido es torch.Size([2, 2, 1, 49, 64]), k y v son ambos torch.Size([2, 1, 49, 64]), pero los valores son diferentes, el
código completo es el siguiente:

  if self.sr_ratio > 1:
        x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
        x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
        x_ = self.norm(x_)
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    else:
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

Luego se realizan una serie de cálculos:

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
@实际为x@y=x.matmul(y)

El resultado x del cálculo final del mecanismo de atención sigue siendotorch.Size([2, 3136, 64])

El código completo del módulo del mecanismo de atención es el siguiente:

def forward(self, x, H, W):
    B, N, C = x.shape
    q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
    if self.sr_ratio > 1:
        x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
        x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
        x_ = self.norm(x_)
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    else:
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    k, v = kv[0], kv[1]
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)

    return x

Finalmente, x se restaura a un resultado bidimensional después de la remodelación y luego ingresa a la siguiente etapa para el cálculo

x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

Luego, el tamaño del mapa de características se reduce a través de la convolución, y la pérdida de espacio se compensa con la dimensión. La siguiente etapa x es:

torch.Size([2, 784, 128])

A través de este proceso, a su vez, se reduce su tamaño y también se obtiene a través de él información multiescala. Vale la pena señalar que solo el parche = 4 en la etapa 1, y el parche en las siguientes tres etapas son todos 2, por lo que también se refiere a la convolución, que es una relación de tamaño doble.

En general, el diseño de Pyramid Version Transformer es relativamente fácil de entender. El siguiente blogger comienza con el código para explicar el proceso de construcción de PVT en detalle.

El código completo es el siguiente:

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {
      
      dim} should be divided by num_heads {
      
      num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {
      
      img_size} should be divided by patch_size {
      
      patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape

        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]

        return x, (H, W)


class PyramidVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], F4=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.F4 = F4

        # patch_embed
        self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                                       embed_dim=embed_dims[0])
        self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
                                       embed_dim=embed_dims[1])
        self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
                                       embed_dim=embed_dims[2])
        self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
                                       embed_dim=embed_dims[3])

        # pos_embed
        self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))
        self.pos_drop1 = nn.Dropout(p=drop_rate)
        self.pos_embed2 = nn.Parameter(torch.zeros(1, self.patch_embed2.num_patches, embed_dims[1]))
        self.pos_drop2 = nn.Dropout(p=drop_rate)
        self.pos_embed3 = nn.Parameter(torch.zeros(1, self.patch_embed3.num_patches, embed_dims[2]))
        self.pos_drop3 = nn.Dropout(p=drop_rate)
        self.pos_embed4 = nn.Parameter(torch.zeros(1, self.patch_embed4.num_patches + 1, embed_dims[3]))
        self.pos_drop4 = nn.Dropout(p=drop_rate)

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])

        # init weights
        trunc_normal_(self.pos_embed1, std=.02)
        trunc_normal_(self.pos_embed2, std=.02)
        trunc_normal_(self.pos_embed3, std=.02)
        trunc_normal_(self.pos_embed4, std=.02)
        self.apply(self._init_weights)

    def init_weights(self, pretrained=True):
        import torch

        # if isinstance(pretrained, str):
        #     logger = get_root_logger()
        #     load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def forward_features(self, x):
        outs = []

        B = x.shape[0]

        # stage 1
        x, (H, W) = self.patch_embed1(x)
        pos_embed1 = self._get_pos_embed(self.pos_embed1, self.patch_embed1, H, W)
        x = x + pos_embed1
        x = self.pos_drop1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, (H, W) = self.patch_embed2(x)
        pos_embed2 = self._get_pos_embed(self.pos_embed2, self.patch_embed2, H, W)
        x = x + pos_embed2
        x = self.pos_drop2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, (H, W) = self.patch_embed3(x)
        pos_embed3 = self._get_pos_embed(self.pos_embed3, self.patch_embed3, H, W)
        x = x + pos_embed3
        x = self.pos_drop3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, (H, W) = self.patch_embed4(x)
        pos_embed4 = self._get_pos_embed(self.pos_embed4[:, 1:], self.patch_embed4, H, W)
        x = x + pos_embed4
        x = self.pos_drop4(x)
        for blk in self.block4:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)

        if self.F4:
            x = x[3:4]

        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {
    
    }
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v

    return out_dict

def pvt_small():
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])

    return model

model = pvt_small()
data = torch.randn((2, 3, 224, 224))
feature = model(data)
print(model)
for out in feature:
    print(out.shape)

Supongo que te gusta

Origin blog.csdn.net/pengxiang1998/article/details/130642249
Recomendado
Clasificación