Swin Transformer code implementation details


1. Patch merging
Code (the slicing trick here is really neat):

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]  even rows, even cols: position 1 (top-left) of every 2x2 block
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]  odd rows, even cols: position 3 (bottom-left)
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]  even rows, odd cols: position 2 (top-right)
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]  odd rows, odd cols: position 4 (bottom-right)
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]  concatenate along channels, so channels become 4x

        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]  self.reduction = nn.Linear(4*dim, 2*dim, bias=False), a linear map from 4C to 2C channels
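To check the slicing, here is a minimal self-contained sketch that wraps the excerpt in a module. Several details are assumptions: the class layout, `nn.LayerNorm` as the norm layer, and the requirement that H and W be even. A full implementation would typically pad odd sizes instead. The 4×4 demo stores each pixel's row-major index as its value, so the first merged token shows exactly which four pixels were gathered.

```python
import torch
import torch.nn as nn


class PatchMerging(nn.Module):
    """Sketch: merge each 2x2 patch neighbourhood into one token and map 4C -> 2C channels.

    Assumes x is [B, H*W, C] with even H and W (odd sizes would need padding first).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)  # assumed norm layer
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, L, C = x.shape
        assert L == H * W and H % 2 == 0 and W % 2 == 0
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)  # [B, H/2*W/2, 4C]
        return self.reduction(self.norm(x))                      # [B, H/2*W/2, 2C]


# Which pixels land in each merged token? Give every pixel its row-major index as value.
grid = torch.arange(16).view(1, 4, 4, 1)
merged = torch.cat([grid[:, 0::2, 0::2, :], grid[:, 1::2, 0::2, :],
                    grid[:, 0::2, 1::2, :], grid[:, 1::2, 1::2, :]], -1)  # [1, 2, 2, 4]
print(merged[0, 0, 0].tolist())  # [0, 4, 1, 5]
print(merged[0, 1, 0].tolist())  # [8, 12, 9, 13]

torch.manual_seed(0)
out = PatchMerging(dim=8)(torch.randn(2, 56 * 56, 8), H=56, W=56)
print(out.shape)  # torch.Size([2, 784, 16])
```

The first merged token collects pixels 0, 4, 1 and 5, which together form the top-left 2×2 block. The spatial size halves in each direction (56×56 becomes 28×28 = 784 tokens), and the channel count doubles.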

2. Creating the attention mask (the slightly confusing part)
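![figure](https://img-blog.csdnimg.cn/ebc36327a9b84806b96d6d50c9f12dcd.png)

Then divide the feature map into windows.

Cells with the same number belong to one contiguous region. After the cyclic shift, a window can contain tokens from regions that were not adjacent in the original feature map. Tokens with different numbers must therefore not attend to each other, and the mask below enforces that.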
Code:

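    def create_mask(self, x, H, W):
        # Hp, Wp: H and W rounded up to multiples of window_size
        Hp = (H + self.window_size - 1) // self.window_size * self.window_size
        Wp = (W + self.window_size - 1) // self.window_size * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        # e.g. window_size=3, shift_size=1: slice [0, -3); counting forward the first element is 0,
        # counting backward the last element is -1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),  # [-3, -1)
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:  # give every region its own number
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        # partition into windows
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]: number of windows, window height, window width, channels
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # broadcasting gives [nW, Mh*Mw, Mh*Mw]; an entry is 0 only if both tokens are in the same region
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask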
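Here is a standalone sketch of the same mask construction on a tiny 6×6 map with window_size=3 and shift_size=1, so the numbers can be checked by hand. `build_shift_mask` and this `window_partition` are illustrative re-implementations with hypothetical names, not imports. H and W are assumed to already be multiples of window_size. Window 0 lies entirely inside region 0, so nothing is masked there. Window 3 contains four regions. The last lines show the typical use of the mask: it is added to the attention logits before softmax, so each -100 entry becomes a weight of essentially zero.

```python
import torch


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """[B, H, W, C] -> [B*nW, window_size, window_size, C], windows in row-major order."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def build_shift_mask(H: int, W: int, window_size: int, shift_size: int) -> torch.Tensor:
    """Hypothetical standalone version of create_mask; assumes H, W are multiples of window_size."""
    img_mask = torch.zeros((1, H, W, 1))
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    cnt = 0
    for h in slices:  # same slices for rows and columns -> 3 x 3 = 9 regions
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)  # [nW, N]
    diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, N, N], 0 iff same region
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)


mask = build_shift_mask(H=6, W=6, window_size=3, shift_size=1)
print(mask.shape)                            # torch.Size([4, 9, 9])
print((mask == 0).sum(dim=(1, 2)).tolist())  # [81, 45, 45, 25]

# Typical use: add the mask to the attention logits, then softmax.
logits = torch.zeros(4, 9, 9)  # equal raw scores, to isolate the mask's effect
attn = torch.softmax(logits + mask, dim=-1)
print([round(v, 2) for v in attn[3, 0].tolist()])
# [0.25, 0.25, 0.0, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0]
```

Token 0 of window 3 ends up attending only to the four tokens of its own region. In the Swin reference code the mask is also broadcast over the batch and head dimensions before being added; this sketch leaves those dimensions out.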

3. Window attention
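Relative position encoding

The overall process (figure excerpted from another blog): build the absolute (row, col) coordinate of every token in the window, subtract all pairs to get relative offsets, shift the offsets so they start at 0, then fold each (row, col) offset into a single index into a learnable bias table.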
Next, new dimensions are inserted into the flattened coordinates so they can be broadcast against each other (this dimension manipulation is really clever!).

Subtracting the two with broadcasting gives the relative position of every pair of tokens (figure taken from a Bilibili tutorial video). Each token's coordinates are subtracted from every other token's coordinates, separately for the row and the column coordinate.
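A tiny check of the broadcasting step for a 2×2 window, so N = Mh*Mw = 4 tokens. It is a standalone sketch rather than the class code; the window size is chosen only to keep the output small. Token k sits at (row, col) = (k // 2, k % 2). Entry [i, j] of each output plane is token i's coordinate minus token j's. The final `permute` regroups the two planes into one (row, col) offset pair per token pair.

```python
import torch

Mh, Mw = 2, 2  # a 2x2 window: token k sits at (row, col) = (k // 2, k % 2)
coords = torch.stack(torch.meshgrid(torch.arange(Mh), torch.arange(Mw), indexing="ij"))  # [2, Mh, Mw]
coords_flatten = torch.flatten(coords, 1)  # [2, 4]: rows [0, 0, 1, 1], cols [0, 1, 0, 1]

a = coords_flatten[:, :, None]  # [2, 4, 1]: "this token" varies along dim 1
b = coords_flatten[:, None, :]  # [2, 1, 4]: "other token" varies along dim 2
relative_coords = a - b         # broadcast to [2, 4, 4]; [:, i, j] = coord(i) - coord(j)

print(relative_coords[0].tolist())  # row offsets: [[0, 0, -1, -1], [0, 0, -1, -1], [1, 1, 0, 0], [1, 1, 0, 0]]
print(relative_coords[1].tolist())  # col offsets: [[0, -1, 0, -1], [1, 0, 1, 0], [0, -1, 0, -1], [1, 0, 1, 0]]

pairs = relative_coords.permute(1, 2, 0)  # [4, 4, 2]: one (row, col) offset per token pair
print(pairs[3, 0].tolist())  # [1, 1]: token 3 is one row down and one column right of token 0
```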
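The `permute` changes the layout from two separate planes ([2, Mh*Mw, Mh*Mw]), one holding all row offsets and one holding all column offsets, to one (row, column) offset pair per token pair ([Mh*Mw, Mh*Mw, 2]). The two offsets are then combined into a single index by shifting, scaling and summing.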

Code:

        # relative position encoding
        # get pair-wise relative position index for each token inside the window
        # first, build the absolute coordinates of every token
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # build the coordinate grid and stack the row/col indices
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw] flatten the spatial dims
        # coords_flatten[:, None, :] inserts a new axis at dim 1, coords_flatten[:, :, None] at dim 2
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]: broadcasting subtraction gives the relative offset of every token pair
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2] move (row, col) last
        # turn the 2-D (row, col) offset into a 1-D index
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift row offsets to start at 0
        relative_coords[:, :, 1] += self.window_size[1] - 1  # shift col offsets to start at 0
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # multiply the row offset by (2*Mw - 1)
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw] sum over the last dim
        # registered as a buffer: saved with the model but not learned; it is used to look up
        # the matching learnable relative position bias for each token pair
        self.register_buffer("relative_position_index", relative_position_index)
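The next sketch puts the index to use. It recomputes the index as a standalone function (`relative_position_index` is a hypothetical helper name) and gathers one learnable bias per head from a table with (2Mh-1)(2Mw-1) rows. The table shape follows the usual Swin setup; the truncated-normal init with std=0.02 is an assumption here. Token pairs with the same relative offset read the same table row, which is the point of the index.

```python
import torch
import torch.nn as nn


def relative_position_index(Mh: int, Mw: int) -> torch.Tensor:
    """Hypothetical helper: same steps as the excerpt above, for an Mh x Mw window."""
    coords = torch.stack(torch.meshgrid(torch.arange(Mh), torch.arange(Mw), indexing="ij"))  # [2, Mh, Mw]
    coords_flatten = torch.flatten(coords, 1)                      # [2, N], N = Mh*Mw
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, N, N]
    rel = rel.permute(1, 2, 0).contiguous()                        # [N, N, 2]
    rel[:, :, 0] += Mh - 1      # row offsets now in [0, 2*Mh - 2]
    rel[:, :, 1] += Mw - 1      # col offsets now in [0, 2*Mw - 2]
    rel[:, :, 0] *= 2 * Mw - 1  # row-major: each row step skips 2*Mw - 1 column values
    return rel.sum(-1)          # [N, N], values in [0, (2*Mh-1)*(2*Mw-1) - 1]


Mh, Mw, num_heads = 2, 2, 3
index = relative_position_index(Mh, Mw)
print(index.tolist())  # [[4, 3, 1, 0], [5, 4, 2, 1], [7, 6, 4, 3], [8, 7, 5, 4]]

# One learnable bias per (relative offset, head); the std=0.02 init is an assumption.
table = nn.Parameter(torch.zeros((2 * Mh - 1) * (2 * Mw - 1), num_heads))
nn.init.trunc_normal_(table, std=0.02)
bias = table[index.view(-1)].view(Mh * Mw, Mh * Mw, -1).permute(2, 0, 1)  # [num_heads, N, N]
print(bias.shape)  # torch.Size([3, 4, 4])

# Pairs (0, 1) and (2, 3) both have offset (0, -1), so they share the same bias values.
print(torch.equal(bias[:, 0, 1], bias[:, 2, 3]))  # True
```

The diagonal is constant (4 here) because every token has offset (0, 0) to itself. The resulting [num_heads, N, N] bias is then added to the attention logits of every window.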


Source: blog.csdn.net/weixin_44040169/article/details/126911018