Restormer de recuperación de imágenes: comprensión profunda del documento y el código fuente (comentado exhaustivamente)

1. Papel restaurador

Insertar descripción de la imagen aquí
Trabajo principal:
[1] MDTA (Atención transpuesta de cabeza multi-Dconv), que agrega interacciones de píxeles locales y no locales, puede procesar imágenes de alta resolución de manera efectiva.
[2] GDFN (Gated-Dconv Feed-Forward Network), controla la conversión de funciones, suprime funciones con poco contenido de información y solo permite que información útil ingrese a la siguiente red.
Documento: https://arxiv.org/pdf/2111.09881.pdf
Código fuente:
[1] https://github.com/swz30/Restormer
[2] https://download.csdn.net/download/Wenyuanbo/83592489
Anotación de detalles de la red y código de prueba de entrenamiento personalizado: https://download.csdn.net/download/Wenyuanbo/83617599

2. Estructura de la red Restormer

2.1 Marco general

Insertar descripción de la imagen aquí
El principal punto de innovación del artículo es mejorar tanto MSA como FFN en el Transformer clásico y adoptar la arquitectura Encoder-Decoder. Las operaciones de muestreo ascendente involucradas se implementan usando nn.PixelShuffle () y las operaciones de muestreo de reducción involucradas se implementan usando nn. PixelUnshuffle(). Para lograr esto, el contexto general del artículo es muy claro.

2.2 MDTA

A diferencia del Transformer general, el documento no utiliza el parche común al calcular los tokens a partir de la plantilla de atención, sino el píxel. Primero, use la convolución 1 1 para aumentar la dimensión, luego use la convolución de grupo 3 3 para dividir las características en tres bloques y, finalmente, realice el cálculo clásico de autoatención.
Insertar descripción de la imagen aquí

2.3 GDFN

El artículo propone una red de activación de doble canal para reemplazar a FFN, que realiza un aumento de dimensionalidad de 1 1, luego usa convolución de 3 3 grupos para extraer características, luego usa la activación de la función GELU y, finalmente, la convolución 1 * 1 reduce la salida de dimensión.
Insertar descripción de la imagen aquí

3. Comprensión del código principal

3.1 MDTA

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads  # 注意力头的个数
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))  # 可学习系数
        
        # 1*1 升维
        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        # 3*3 分组卷积
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        # 1*1 卷积
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b,c,h,w = x.shape  # 输入的结构 batch 数,通道数和高宽

        qkv = self.qkv_dwconv(self.qkv(x))
        q,k,v = qkv.chunk(3, dim=1)  #  第 1 个维度方向切分成 3 块
        # 改变 q, k, v 的结构为 b head c (h w),将每个二维 plane 展平
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)  # C 维度标准化,这里的 C 与通道维度略有不同
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)  # 注意力图(严格来说不算图)
        
        # 将展平后的注意力图恢复
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        # 真正的注意力图
        out = self.project_out(out)
        return out

3.2 GDFN

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
        
        # 隐藏层特征维度等于输入维度乘以扩张因子
        hidden_features = int(dim*ffn_expansion_factor)
        # 1*1 升维
        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
        # 3*3 分组卷积
        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
        # 1*1 降维
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)  # 第 1 个维度方向切分成 2 块
        x = F.gelu(x1) * x2  # gelu 相当于 relu+dropout
        x = self.project_out(x)
        return x

3.3 Bloque transformador

## 就是标准的 Transformer 架构
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)  # 层标准化
        self.attn = Attention(dim, num_heads, bias)  # 自注意力
        self.norm2 = LayerNorm(dim, LayerNorm_type)  # 层表转化
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)  # FFN

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # 残差
        x = x + self.ffn(self.norm2(x))  # 残差

        return x

3.4 Un ejemplo de prueba

model = Restormer()
print(model)  # 打印网络结构

x = torch.randn((1, 3, 64, 64))  #随机生成输入图像
x = model(x)  # 送入网络
print(x.shape) # 打印网络输入的图像结构

referencias

[1] Zamir SW, Arora A, Khan S y otros. Restormer: Efficient Transformer for High-Resolution Image Restoration[J], preimpresión de arXiv arXiv:2111.09881, 2021. [2] Grupo de Teoría de Fronteras de IA de la Universidad Oceánica de China.
[ ARXIV2111 】Restormer: Transformador eficiente para restauración de imágenes de alta resolución.

Conclusión y pensamientos

  1. Los experimentos han demostrado que Restormer ha logrado un muy buen rendimiento en tareas como la eliminación de lluvia de imágenes, eliminación de imágenes borrosas y eliminación de ruido de imágenes. Sin embargo, el documento no compara los parámetros y la eficiencia con otros algoritmos. Hasta donde yo sé, el número de parámetros de MPRNet es 3,64 M, mientras que Restormer es 25,3 M. Si confiamos en acumular parámetros y quemar dinero para obtener resultados SOTA, nuestro equipo no tendrá más remedio que mantenerse alejado.
  2. Hay un factor de expansión γ = 2,66 \gamma=2,66 en GDFNC=2.66La explicación del artículo es muy simple: hacer que los parámetros de la red y la carga computacional sean consistentes con el FFN general .
  3. Para obtener comentarios completos y códigos de prueba y capacitación personalizados, visite: https://download.csdn.net/download/Wenyuanbo/83617599

Supongo que te gusta

Origin blog.csdn.net/Wenyuanbo/article/details/123306095
Recomendado
Clasificación