Detalles de implementación del código Swin Transformer

transformador giratorio

1.código de pieza de fusión de parches
Insertar descripción de la imagen aquí
: [increíble]

		x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]  对应图片所有 1 的位置
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]  对应图片所有 3 的位置
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]  对应图片所有 2 的位置
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]  对应图片所有 4 的位置
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C] 拼在一起,通道变为4倍

		x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]  self.reduction = nn.Linear(4*dim, 2*dim, bias=False)一个线性映射使通道变为2倍

2. ¡Crea la parte de la máscara (un poco confusa)
! [Inserta una descripción de la imagen aquí](https://img-blog.csdnimg.cn/ebc36327a9b84806b96d6d50c9f12dcd.pngDividir Insertar descripción de la imagen aquí
ventanas
Insertar descripción de la imagen aquí

Los números idénticos son áreas consecutivas
Código:

		h_slices = (slice(0, -self.window_size), #切片 [0,-3) 正着数是从第一个开始记为0,倒着数从最后一个开始记为-1
                    slice(-self.window_size, -self.shift_size),# [-3,-1)
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices: # 给区域标号
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
    # 划分window窗口
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]窗口个数,窗口宽,高,通道数
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw] 利用广播机制
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

3.
Codificación de posición relativa de atención de ventana
El proceso general (extraído del blog)
Insertar descripción de la imagen aquí
agrega dimensiones, como se muestra en la siguiente figura [¡Las siguientes operaciones dimensionales son asombrosas! ! !
Insertar descripción de la imagen aquí

Utilice el mecanismo de transmisión para restar y obtener la codificación de posición relativa (extraída del video de la guía B). Se
restan las coordenadas correspondientes a los colores en la figura a continuación.
Insertar descripción de la imagen aquí
Este es el cambio antes y después de la transformación de permutación, de la separación de coordenadas horizontales y verticales a la suma de las coordenadas horizontales y verticales.
Insertar descripción de la imagen aquí

Código:

 # 相对位置编码
        # get pair-wise relative position index for each token inside the window
        #首先 生成绝对位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])   # 生成网格坐标索引    堆叠
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw] 并展开为2D向量
        # coords_flatten[:, None, :] 在一维处插入新维度  , coords_flatten[:, :, None] 在二维处插入新维度
                                    # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]  利用广播机制 就是通过相减得到他们的相对位置关系
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2] 调换位置
        #把二元索引变成一元索引
        relative_coords[:, :, 0] += self.window_size[0] - 1  # 坐标转换为从0开始
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 #行坐标乘(2M-1)
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw] 最后一个维度求和
        self.register_buffer("relative_position_index", relative_position_index) #注册为不参与网络学习的变量,
                                                    # #作用是根据最终的相对位置索引 找到对应的可学习的相对位置编码

Supongo que te gusta

Origin blog.csdn.net/weixin_44040169/article/details/126911018
Recomendado
Clasificación