Transformar【ViT】

referencia

 ¡tutor! La reproducción del blogger es demasiado detallada. Haz una nota.

Pequeño registro de aprendizaje de redes neuronales de capa 67: explicación detallada de la versión de Pytorch de la reaparición del modelo de transformador de visión (VIT)

Resumen de las ideas de innovación del modelo de transformador en visión artificial_Tom Hardy's Blog-CSDN博客

Transformador de visión detallada

vit

preprocesamiento

estructura de red

pensamiento general

Detección de objetos DETR (2020.5) --> Clasificación ViT (2020.10) --> Segmentación SETR (2020.12) --> Transformador Swin (2021.3) -->

La dificultad de usar transformer en el campo de la visión por computadora es que la secuencia es demasiado larga, el trabajo anterior incluye la operación de transformer después de usar CNN para extraer características, la operación de autoatención en una ventana pequeña y el mecanismo de usar atención propia para la longitud y el ancho de la imagen respectivamente.Todos son pasables. La autoatención reemplaza por completo a la convolución y se ha aplicado antes que el campo CV, pero no ha aparecido la red que usa directamente el transformador en el campo visual.

El transformador carece de un sesgo inductivo (la localidad del núcleo de convolución y la no interferencia de la operación de convolución, la traducción deslizante, estos sesgos inductivos le dan a CNN cierta información previa) y requieren una mayor cantidad de entrenamiento.

ViT solo utiliza algunos sesgos inductivos específicos de la imagen para segmentar bloques de imágenes y codificar posiciones, lo que demuestra que el transformador estándar en el campo de la PNL puede realizar tareas visuales tanto como sea posible.

estructura especifica

extracción de características

(224, 224, 3)-->(14, 14, 768)/(196, 768)/(197, 768)-->(197, 768)

1 parche

Convolución 16*16 con un tamaño de paso de 16/mosaico de dimensiones de alto y ancho + Cls Token (1, 768)

Cls Token realizará la extracción de características juntos.

2. Incrustación de posición

Agregue información de ubicación a todas las funciones para que la red tenga la capacidad de distinguir diferentes regiones.

nn.Parameter() genera un tensor aprendible (196, 768) y el gato Token Cls anterior para obtener (197, 768). Luego suma al tensor obtenido por 1.

3. Codificador de transformador

(1. Descripción

1) La L de arriba indica cuántos bloques de transformadores se van a superponer.

2) Después de multiplicar el qk interno de la atención, la caída se establecerá después de la conexión completa. Fuera de la atención, el abandono también se establecerá fuera del mlp.El abandono aquí es establecer todos los valores de píxeles del mapa de características de entrada en 0, y como se superpone el número de capas, la probabilidad de establecer 0 es menor (se especula que la operación aquí es destruir la red. Efecto de ajuste, evitar el sobreajuste, la probabilidad de daño es muy baja, puede estudiar el código fuente). Dropout también se configura antes de ingresar al codificador.

3) La longitud de la secuencia es solo 3 y la longitud característica de cada secuencia unitaria es solo 3. En el codificador de transformador de VIT, la longitud de la secuencia es 197 y la longitud característica de cada secuencia unitaria es 768 // num_heads.Por favor agregue una descripción de la imagen

(2) Detalles de ejecución interna

(3) Módulos específicos

Norma

nn.Norma de capa

Atención de múltiples cabezas

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads  = num_heads
        self.scale      = (dim // num_heads) ** -0.5

        self.qkv        = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop  = nn.Dropout(attn_drop)
        self.proj       = nn.Linear(dim, dim)
        self.proj_drop  = nn.Dropout(proj_drop)

    def forward(self, x):
        # batchsize, 197, 768
        B, N, C     = x.shape
        # 通过全连接层扩充维度为3倍,再将维度拆分为num_head份:3(qkv), batchsize, 12(nums_head), 197(patch), 64(768//12)
        qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # 分配:batchsize, 12(nums_head), 197(patch), 64(768//12)
        q, k, v     = qkv[0], qkv[1], qkv[2]

        # q,k矩阵相乘
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # softmax求每个元素在每个行上的占比是多少
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # 得到的attn与v矩阵相乘
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # 进入线性层
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

MLP

2 nn.Linear(), la función de activación en el medio usa GELU

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs      = (drop, drop)

        self.fc1    = nn.Linear(in_features, hidden_features)
        self.act    = act_layer()
        self.drop1  = nn.Dropout(drop_probs[0])
        self.fc2    = nn.Linear(hidden_features, out_features)
        self.drop2  = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

agregar

Observe si hay dos líneas de conexión en el medio, la conexión residual.

(Un) parche se multiplica por todos los parches para calcular la importancia, y luego la importancia de la retroalimentación de este parche se multiplica por todos los parches para obtener (un) parche. Reemplazar (uno) con todo es todo el proceso de autoatención.

Resumen: cada parche tiene la suma ponderada de otros parches en relación con el parche.

Clasificación

(1. Descripción

(197, 768)-->(, 768)

En este punto, se creará el token Cls. Como se mencionó anteriormente, el token Cls tiene la información para interactuar con todos los demás parches. Solo haz una conexión completa. hecho.

(2) Detalles de ejecución interna

(3) Módulos específicos

Postprocesamiento

función de pérdida

Supongo que te gusta

Origin blog.csdn.net/qq_41804812/article/details/131083819
Recomendado
Clasificación