Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation paper interpretation


Paper: [2105.05537] Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation (arxiv.org)

Code: HuCaoFighting/Swin-Unet: The codes for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation" (github.com)

Journal/Conference: ECCV 2022 Workshops

Summary

Over the past few years, convolutional neural networks (CNNs) have achieved milestones in medical image analysis. In particular, deep neural networks based on U-shaped structures and skip connections have been widely used in various medical image tasks. However, although CNNs achieve excellent performance, they cannot learn global and long-range semantic information interactions well due to the locality of convolution operations. In this paper, we propose Swin-Unet, a Unet-like pure Transformer for medical image segmentation. The tokenized image patches are fed into a Transformer-based U-shaped encoder-decoder architecture with skip connections for local and global semantic feature learning. Specifically, we use a hierarchical Swin Transformer with shifted windows as the encoder to extract contextual features, and design a symmetric Swin Transformer-based decoder with a patch expanding layer to upsample the feature maps and restore their spatial resolution. With direct 4× downsampling of the inputs and 4× upsampling of the outputs, experiments on multi-organ and cardiac segmentation tasks show that the pure Transformer-based U-shaped encoder-decoder network outperforms methods based on full convolution or on combinations of Transformer and convolution.

1. Introduction

Thanks to the development of deep learning, computer vision technology has been widely used in medical image analysis. Image segmentation is an important part of medical image analysis. In particular, accurate and robust medical image segmentation plays a crucial role in computer-aided diagnosis and image-guided clinical surgery.

Existing medical image segmentation methods mainly rely on fully convolutional neural networks (FCNNs) with U-shaped structures. The typical U-shaped network, U-Net, consists of a symmetric encoder-decoder with skip connections. In the encoder, a series of convolutional layers and successive downsampling layers are used to extract deep features with large receptive fields. The decoder then upsamples the extracted deep features to the input resolution for pixel-level semantic prediction, and skip connections fuse high-resolution features at different scales from the encoder to alleviate the spatial information loss caused by downsampling. With such an elegant structural design, U-Net has achieved great success in a variety of medical imaging applications. Following this technical route, algorithms such as 3D U-Net, Res-UNet, U-Net++ and UNet3+ have been developed for image and volumetric segmentation across medical imaging modalities. The excellent performance of these FCNN-based methods in cardiac segmentation, organ segmentation and lesion segmentation proves that CNNs have strong feature learning and discrimination capabilities.

At present, although CNN-based methods have achieved excellent performance in the field of medical image segmentation, they still cannot fully meet the strict requirements of medical applications for segmentation accuracy. Image segmentation remains a challenging task in medical image analysis. Due to the inherent locality of convolution operations, it is difficult for CNN-based methods to learn explicit global and long-range semantic information interactions. Some studies attempt to address this problem with atrous convolutional layers, self-attention mechanisms, and image pyramids, but these methods still have limitations in modeling long-range dependencies. Recently, inspired by the great success of the Transformer in natural language processing (NLP), researchers have tried to bring the Transformer into the vision domain. The ViT paper proposed the Vision Transformer (ViT) for image recognition tasks; taking 2D image patches with position embeddings as input and pre-training on large datasets, it achieves performance comparable to CNN-based methods. In addition, the data-efficient image transformer (DeiT) showed that Transformers can be trained on mid-sized datasets and that combining them with distillation yields more robust image transformers. The Swin Transformer paper developed a hierarchical Swin Transformer; [19] used it as the vision backbone and achieved state-of-the-art performance in image classification, object detection, and semantic segmentation. The success of ViT, DeiT and Swin Transformer in image recognition tasks demonstrates the potential of the Transformer in the vision field.

Motivated by the success of the Swin Transformer, we propose Swin-Unet in this work to leverage the power of the Transformer for 2D medical image segmentation. To the best of our knowledge, Swin-Unet is the first purely Transformer-based U-shaped architecture, consisting of an encoder, a bottleneck, a decoder and skip connections. The encoder, bottleneck and decoder are all built on Swin Transformer blocks. The input medical image is split into non-overlapping image patches; each patch is treated as a token and fed into the Transformer-based encoder to learn deep feature representations. The decoder uses patch expanding layers to upsample the extracted contextual features and fuses them via skip connections with the multi-scale features of the encoder to restore the spatial resolution of the feature maps, followed by segmentation prediction. Extensive experiments on multi-organ and cardiac segmentation datasets demonstrate that the method has good segmentation accuracy and robust generalization ability. Specifically, our contributions can be summarized as follows: (1) Based on the Swin Transformer block, we construct a symmetric encoder-decoder architecture with skip connections; in the encoder, self-attention from local to global is realized, and in the decoder, global features are upsampled to the input resolution for pixel-level segmentation prediction. (2) A patch expanding layer is designed to achieve upsampling and feature-dimension increase without using convolution or interpolation operations. (3) Experiments show that skip connections are also effective for the Transformer, so a U-shaped encoder-decoder architecture with skip connections purely based on the Transformer is finally constructed, named Swin-Unet.

2. Related work

CNN-based methods: Early medical image segmentation methods were mainly contour-based and traditional machine learning-based algorithms. With the development of deep CNNs, U-Net was proposed for medical image segmentation. Thanks to its simple U-shaped structure and superior performance, various Unet-like methods continue to emerge, such as Res-UNet, Dense-UNet, U-Net++ and UNet3+. This line of work has also been introduced into three-dimensional medical image segmentation, e.g., 3D U-Net and V-Net. Currently, CNN-based methods have achieved great success in the field of medical image segmentation due to their powerful representation capabilities.

Vision Transformer: The Transformer was originally proposed for machine translation tasks, and Transformer-based methods have since achieved state-of-the-art performance in various natural language processing tasks. Motivated by the Transformer's success, researchers introduced the groundbreaking Vision Transformer (ViT), which achieves an impressive speed-accuracy trade-off on image recognition tasks. Compared with CNN-based methods, the drawback of ViT is that it requires pre-training on large datasets. To alleviate the difficulty of training ViT, DeiT describes several training strategies that allow ViT to train well on ImageNet. Recently, several excellent works based on ViT have appeared. It is worth mentioning that an efficient and effective hierarchical vision Transformer, the Swin Transformer, was proposed as a vision backbone. Based on the shifted window mechanism, the Swin Transformer achieves state-of-the-art performance on various vision tasks such as image classification, object detection, and semantic segmentation. In this work, we attempt to use the Swin Transformer block as the basic unit to build a U-shaped encoder-decoder architecture with skip connections for medical image segmentation, thereby providing a benchmark comparison for the development of the Transformer in the medical imaging field.

Self-attention/Transformer vs. CNN: In recent years, researchers have tried to introduce self-attention mechanisms into CNNs to improve network performance. Some works integrated skip connections with additional attention gates into a U-shaped structure for medical image segmentation, but these are still CNN-based methods. Currently, there are efforts to combine CNNs and Transformers to break the dominance of CNNs in medical image segmentation. Some researchers have combined the Transformer with CNNs to form a strong encoder for 2D medical image segmentation, and others have exploited the complementarity of the Transformer and CNN to improve the model's segmentation ability. Various combinations of Transformer and CNN are also applied to multi-modal brain tumor segmentation and 3D medical image segmentation. Different from the above methods, we explore the application potential of the pure Transformer in medical image segmentation.

3. Method

3.1 Overview of model architecture

The overall architecture of the proposed Swin-Unet is shown in Figure 1. Swin-Unet consists of an encoder, a bottleneck, a decoder and skip connections, with the Swin Transformer block as its basic unit. The input medical image is first split into non-overlapping patches of size 4 × 4, which converts the input into a sequence of embeddings; with this partitioning, the feature dimension of each patch is 4 × 4 × 3 = 48. A linear embedding layer then projects the feature dimension to an arbitrary dimension (denoted as C). The transformed patch tokens pass through several Swin Transformer blocks and patch merging layers to generate hierarchical feature representations: the patch merging layer is responsible for downsampling and increasing the dimension, while the Swin Transformer block is responsible for feature representation learning. Inspired by U-Net, a symmetric Transformer-based decoder is designed, composed of Swin Transformer blocks and patch expanding layers. The extracted contextual features are fused with the multi-scale features of the encoder through skip connections to compensate for the loss of spatial information caused by downsampling. In contrast to the patch merging layer, the patch expanding layer is specially designed to perform upsampling: it reshapes feature maps of adjacent dimensions into higher-resolution feature maps (2× upsampling). Finally, the last patch expanding layer performs 4× upsampling to restore the feature maps to the input resolution (W × H), and a linear projection layer is applied to these upsampled features to output the pixel-level segmentation prediction. Each block is explained in detail below.
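To make the tokenization step concrete, below is a minimal PyTorch sketch of the patch partition plus linear embedding. This is not the authors' code; the strided convolution and the choice C = 96 (the Swin-T default) are implementation assumptions.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Patch partition + linear embedding: split the image into
    non-overlapping 4x4 patches (4*4*3 = 48 values each) and project
    every patch to a C-dimensional token."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):  # embed_dim C is assumed
        super().__init__()
        # A convolution with kernel = stride = patch_size is equivalent to
        # flattening each 4x4 patch and applying one shared linear layer.
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (B, 3, H, W)
        x = self.proj(x)                     # (B, C, H/4, W/4)
        return x.flatten(2).transpose(1, 2)  # (B, H/4 * W/4, C) token sequence

tokens = PatchEmbed()(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 3136, 96])
```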

[Figure 1: The overall architecture of Swin-Unet, consisting of encoder, bottleneck, decoder and skip connections.]

3.2 Swin Transformer block

[Figure 2: Two successive Swin Transformer blocks (W-MSA and SW-MSA).]

Different from the conventional multi-head self-attention (MSA) module, the Swin Transformer block is constructed based on shifted windows. Figure 2 shows two successive Swin Transformer blocks. Each Swin Transformer block consists of a LayerNorm (LN) layer, a multi-head self-attention module, a residual connection, and a 2-layer MLP with GELU nonlinearity. The window-based multi-head self-attention (W-MSA) module and the shifted-window-based multi-head self-attention (SW-MSA) module are applied in the two successive blocks, respectively. With this window partitioning mechanism, the successive Swin Transformer blocks can be formulated as:
$$\hat{z}^{l} = \text{W-MSA}\big(\text{LN}(z^{l-1})\big) + z^{l-1}$$

$$z^{l} = \text{MLP}\big(\text{LN}(\hat{z}^{l})\big) + \hat{z}^{l}$$

$$\hat{z}^{l+1} = \text{SW-MSA}\big(\text{LN}(z^{l})\big) + z^{l}$$

$$z^{l+1} = \text{MLP}\big(\text{LN}(\hat{z}^{l+1})\big) + \hat{z}^{l+1}$$

where $\hat{z}^{l}$ and $z^{l}$ denote the outputs of the (S)W-MSA module and the MLP module of the $l$-th block, respectively. As in previous works, self-attention is computed as:

$$\text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{QK^{T}}{\sqrt{d}} + B\right)V$$

where $Q, K, V \in \mathbb{R}^{M^{2} \times d}$ are the query, key and value matrices, $M^{2}$ and $d$ denote the number of patches in a window and the dimension of the query/key, respectively, and the values of the relative position bias $B$ are taken from the bias matrix $\hat{B} \in \mathbb{R}^{(2M-1) \times (2M-1)}$.
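The residual structure of the two successive blocks can be sketched directly from the four equations. In this simplified sketch, plain multi-head attention stands in for W-MSA/SW-MSA (the real modules restrict attention to M × M windows, shift the window grid in the second block, and add the relative position bias B inside the softmax); it assumes a recent PyTorch with `batch_first` MultiheadAttention.

```python
import torch
import torch.nn as nn

class SwinBlockPair(nn.Module):
    """Two successive blocks: pre-LN attention and pre-LN MLP,
    each wrapped in a residual connection, as in the equations above."""
    def __init__(self, dim=96, num_heads=3, mlp_ratio=4.0):
        super().__init__()
        self.norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(4))
        # Stand-ins for W-MSA (regular windows) and SW-MSA (shifted windows).
        self.w_msa = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.sw_msa = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        hidden = int(dim * mlp_ratio)
        self.mlp1 = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        self.mlp2 = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, z):                    # z: (B, L, C) token sequence
        h = self.norms[0](z)
        z = self.w_msa(h, h, h)[0] + z       # z_hat^l     = W-MSA(LN(z^{l-1})) + z^{l-1}
        z = self.mlp1(self.norms[1](z)) + z  # z^l         = MLP(LN(z_hat^l)) + z_hat^l
        h = self.norms[2](z)
        z = self.sw_msa(h, h, h)[0] + z      # z_hat^{l+1} = SW-MSA(LN(z^l)) + z^l
        z = self.mlp2(self.norms[3](z)) + z  # z^{l+1}     = MLP(LN(z_hat^{l+1})) + z_hat^{l+1}
        return z
```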

3.3 Encoder

In the encoder, the C-dimensional tokenized inputs with resolution $\frac{H}{4} \times \frac{W}{4}$ are fed into two consecutive Swin Transformer blocks for representation learning, with the feature dimension and resolution kept unchanged. Meanwhile, the patch merging layer reduces the number of tokens (2× downsampling) and increases the feature dimension to 2× the original dimension. This procedure is repeated three times in the encoder.

Patch merging layer: The patch merging layer divides the input patches into four parts (the 2×2 neighboring patches) and concatenates them together. Such processing reduces the feature resolution by 2×. Since the concatenation increases the feature dimension by 4×, a linear layer is applied to the concatenated features to unify the feature dimension to 2× the original dimension.
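A sketch of the patch merging layer, closely following the reference Swin Transformer implementation; the interleaved slicing gathers the four members of every 2×2 neighborhood.

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Concatenate each 2x2 patch neighborhood (C -> 4C), then apply a
    linear layer so the output dimension is unified to 2C; the spatial
    resolution is halved."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):              # x: (B, H*W, C)
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]             # top-left of each 2x2 block
        x1 = x[:, 1::2, 0::2, :]             # bottom-left
        x2 = x[:, 0::2, 1::2, :]             # top-right
        x3 = x[:, 1::2, 1::2, :]             # bottom-right
        x = torch.cat([x0, x1, x2, x3], -1)  # (B, H/2, W/2, 4C)
        x = x.view(B, (H // 2) * (W // 2), 4 * C)
        return self.reduction(self.norm(x))  # (B, H/2 * W/2, 2C)
```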

3.4 Bottleneck

Since a Transformer that is too deep cannot converge easily, only two consecutive Swin Transformer blocks are used to construct the bottleneck and learn the deep feature representation. In the bottleneck, the feature dimension and resolution remain unchanged.

3.5 Decoder

Corresponding to the encoder is the symmetric decoder, which is also built on Swin Transformer blocks. To this end, in contrast to the patch merging layer used in the encoder, we use a patch expanding layer in the decoder to upsample the extracted deep features. The patch expanding layer reshapes feature maps of adjacent dimensions into higher-resolution feature maps (2× upsampling) and accordingly reduces the feature dimension to half of the original dimension.

Patch expanding layer: Take the first patch expanding layer as an example. Before upsampling, a linear layer is applied to the input features ($\frac{W}{32} \times \frac{H}{32} \times 8C$) to increase the feature dimension to 2× the original dimension ($\frac{W}{32} \times \frac{H}{32} \times 16C$). Then, a rearrange operation expands the resolution of the input features to 2× the input resolution and reduces the feature dimension to a quarter of the input dimension ($\frac{W}{32} \times \frac{H}{32} \times 16C \to \frac{W}{16} \times \frac{H}{16} \times 4C$). The impact of performing upsampling with the patch expanding layer is discussed in Section 4.5.
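A sketch of the patch expanding layer under the same conventions, mirroring the rearrange-based design described above (einops is assumed to be available):

```python
import torch.nn as nn
from einops import rearrange

class PatchExpand(nn.Module):
    """A linear layer doubles the channels (C -> 2C), then a rearrange
    moves channels into a 2x2 spatial block, doubling H and W and leaving
    C/2 channels: overall (H, W, C) -> (2H, 2W, C/2)."""
    def __init__(self, dim):
        super().__init__()
        self.expand = nn.Linear(dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(dim // 2)

    def forward(self, x, H, W):              # x: (B, H*W, C)
        x = self.expand(x)                   # (B, H*W, 2C)
        B, L, C2 = x.shape
        x = x.view(B, H, W, C2)
        # Each token's 2C channels become a 2x2 block of C/2 channels.
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c',
                      p1=2, p2=2, c=C2 // 4)
        return self.norm(x.reshape(B, 4 * H * W, C2 // 4))  # (B, 2H*2W, C/2)
```

For the first patch expanding layer, dim = 8C, so the layer maps (W/32, H/32, 8C) to (W/16, H/16, 4C) exactly as described above.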

3.6 Skip connection

Similar to U-Net, skip connections are used to fuse the multi-scale features of the encoder with the upsampled features. Shallow features and deep features are concatenated together to reduce the spatial information loss caused by downsampling. A linear layer follows, keeping the dimension of the concatenated features the same as that of the upsampled features. Section 4.5 discusses in detail the impact of the number of skip connections on model performance.
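A minimal sketch of one fusion step, assuming the decoder features x_up and the encoder features x_enc share the same token count and channel dimension C (the class name and interface are illustrative, not from the repository):

```python
import torch
import torch.nn as nn

class SkipFusion(nn.Module):
    """Concatenate upsampled decoder features with same-scale encoder
    features (C + C -> 2C), then project back to C so the fused features
    keep the dimension of the upsampled features."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, x_up, x_enc):          # both: (B, L, C)
        return self.proj(torch.cat([x_up, x_enc], dim=-1))  # (B, L, C)
```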

4. Experiment

4.1 Dataset

Synapse multi-organ segmentation dataset (Synapse): The dataset includes 3779 axial abdominal clinical CT images from 30 cases, with 18 cases used for training and 12 cases for testing. Eight abdominal organs (aorta, gallbladder, spleen, left kidney, right kidney, liver, pancreas, stomach) are evaluated, using the Dice similarity coefficient (DSC) and the average Hausdorff distance (HD) as evaluation metrics.

Automated cardiac diagnosis challenge dataset (ACDC): The ACDC dataset is collected from different patients using MRI scanners. For each patient's MR image, the left ventricle (LV), right ventricle (RV) and myocardium (MYO) are labeled. The dataset is split into 70 training samples, 10 validation samples and 20 test samples. Only the average DSC is used to evaluate methods on this dataset.

4.2 Implementation details

Swin-Unet is implemented based on Python 3.6 and PyTorch 1.7.0. For all training cases, data augmentations such as flipping and rotation are used to increase data diversity. The input image size is set to 224×224 and the patch size to 4. We train the model on an Nvidia V100 GPU with 32GB of memory, initializing the model parameters with weights pretrained on ImageNet. During training, the model is optimized via back-propagation with a batch size of 24 and an SGD optimizer with momentum 0.9 and weight decay 1e-4.
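A sketch of the optimizer setup matching the stated hyper-parameters; the learning rate is not given in this write-up, so the value below is an assumption, and `model` is only a placeholder module:

```python
import torch
import torch.nn as nn

model = nn.Linear(48, 96)  # placeholder standing in for the Swin-Unet network

# SGD with momentum 0.9 and weight decay 1e-4, as stated above.
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,  # assumed; not specified in this write-up
                            momentum=0.9, weight_decay=1e-4)
```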

4.3 Experimental results on Synapse data set

[Table 1: Segmentation accuracy (DSC and HD) of different methods on the Synapse multi-organ CT dataset.]

Table 1 compares the proposed Swin-Unet with previous state-of-the-art methods on the Synapse multi-organ CT dataset. Different from TransUnet, we add the test results of our own implementations of U-Net and Att-UNet on the Synapse dataset. Experimental results show that the Unet-like pure Transformer method proposed in this paper achieves the best segmentation accuracy, with 79.13% (DSC↑) and 21.55 (HD↓). Compared with Att-Unet and the recent TransUnet, although our algorithm does not improve much on the DSC metric, it improves the HD metric by about 4% and 10% respectively, indicating that our method achieves better edge prediction. The segmentation results of different methods on the Synapse multi-organ CT dataset are shown in Figure 3. As can be seen from the figure, CNN-based methods are prone to over-segmentation, which may be caused by the locality of convolution operations. In this work, we demonstrate that by integrating the Transformer with a U-shaped architecture with skip connections, a pure Transformer approach without convolution can better learn global and long-range semantic information interactions, leading to better segmentation results.

[Figure 3: Qualitative segmentation results of different methods on the Synapse multi-organ CT dataset.]

4.4 Experimental results on the ACDC data set

Similar to the Synapse dataset, the proposed Swin-Unet is trained on the ACDC dataset to perform medical image segmentation; the experimental results are shown in Table 2. Using MR-modality image data as input, Swin-Unet still achieves excellent performance with an accuracy of 90.00%, indicating that our method has good generalization ability and robustness.

[Table 2: Segmentation accuracy of different methods on the ACDC dataset.]

4.5 Ablation experiment

To explore the impact of different factors on model performance, we conducted ablation studies on the Synapse dataset. Specifically, the up-sampling method, the number of skip connections, the input size and the model scale are discussed below.

Effect of up-sampling: Corresponding to the patch merging layer in the encoder, we specially designed a patch expanding layer in the decoder to perform upsampling and feature-dimension increase. To explore the effectiveness of the proposed patch expanding layer, we ran Swin-Unet experiments with bilinear interpolation, transposed convolution, and the patch expanding layer on the Synapse dataset. The experimental results in Table 3 show that Swin-Unet combined with the patch expanding layer achieves the best segmentation accuracy.

[Table 3: Ablation study on the up-sampling method on the Synapse dataset.]

Effect of the number of skip connections: Swin-Unet's skip connections are added at the 1/4, 1/8 and 1/16 resolution scales. We explore the impact of skip connections on segmentation performance by changing their number to 0, 1, 2 and 3. Table 4 shows that the segmentation performance of the model improves as the number of skip connections increases. Therefore, to make the model more robust, the number of skip connections is set to 3 in this work.

[Table 4: Ablation study on the number of skip connections on the Synapse dataset.]

Effect of input size: Table 5 shows the test results of the proposed Swin-Unet with input resolutions of 224 × 224 and 384 × 384. When the input size increases from 224 × 224 to 384 × 384 while the patch size stays at 4, the input token sequence of the Transformer becomes longer, which improves the segmentation performance of the model. However, although the segmentation accuracy improves slightly, the computational load of the whole network increases significantly. To ensure the running efficiency of the algorithm, the experiments in this paper use the 224 × 224 resolution as input.
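The growth of the token sequence follows directly from the fixed patch size:

$$\left(\frac{224}{4}\right)^2 = 56^2 = 3136 \ \text{tokens} \qquad \text{vs.} \qquad \left(\frac{384}{4}\right)^2 = 96^2 = 9216 \ \text{tokens},$$

i.e., nearly 3× more tokens to process at every stage, which accounts for the significant increase in computational load.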

[Table 5: Ablation study on input size on the Synapse dataset.]

Effect of model scale: We discuss the impact of deepening the network on model performance. As can be seen from Table 6, increasing the model size does not improve performance but instead increases the computational cost of the whole network. Considering the accuracy-speed trade-off, we adopt the tiny-scale model for medical image segmentation.

[Table 6: Ablation study on model scale on the Synapse dataset.]

4.6 Discussion

It is well known that the performance of Transformer-based models is severely affected by pre-training. In this work, we directly use the weights of the Swin Transformer pretrained on ImageNet to initialize the network encoder and decoder, which may be a suboptimal solution. This initialization scheme is simple; in the future, we will explore end-to-end pre-training of the Transformer for medical image segmentation. In addition, since the inputs in this paper are 2D images while most medical image data are 3D, we will explore the application of Swin-Unet to three-dimensional medical image segmentation in subsequent research.

5. Summary

This paper introduces a novel pure Transformer-based U-shaped encoder-decoder for medical image segmentation. To fully exploit the power of the Transformer, the Swin Transformer block is used as the basic unit for learning feature representations and long-range semantic information interaction. Extensive experiments on multi-organ and cardiac segmentation tasks demonstrate that the proposed Swin-Unet has excellent performance and generalization ability.


Origin: blog.csdn.net/qq_45041871/article/details/128717551