Overall structure:
As the figure above shows, Swin-Unet is mainly composed of Swin Transformer Blocks, Patch Merging, and Patch Expanding modules. The left half (the encoder) comes from the paper "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"; I give a detailed analysis and source-code walkthrough of Swin Transformer in another article (Swin Transformer Interpretation).
Patch Expanding
This module performs upsampling: it expands the spatial resolution and adjusts the number of channels. (The last Patch Expanding layer expands the resolution by 4x to recover the full input size.)
import torch.nn as nn
from einops import rearrange

class PatchExpand(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # First double the channel dimension; the rearrange in forward() then
        # trades channels for spatial resolution (2*dim / 4 = dim/2 output channels).
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: [B, H*W, C]
        """
        H, W = self.input_resolution
        x = self.expand(x)  # [B, H*W, 2C]
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # Redistribute channel groups into 2x2 spatial neighborhoods:
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c',
                      p1=2, p2=2, c=C // 4)  # [B, 2H, 2W, C//4]
        x = x.view(B, -1, C // 4)  # [B, 2H*2W, C//4]
        x = self.norm(x)
        return x
This operation is essentially the inverse operation of Patch Merging, as shown below.
Experimental results
With 32x downsampling and the corresponding upsampling, segmentation experiments on multiple groups of medical organs show that this U-Net built purely from Swin Transformer blocks outperforms both fully convolutional networks and hybrid architectures that combine Transformers with convolution.