MICCAI 2022 | PHTrans: Parallelly Aggregating Global and Local Representations for Medical Image Segmentation

0 Abstract

The success of Transformers in the field of computer vision has attracted growing attention from the medical imaging community. In medical image segmentation in particular, many hybrid architectures based on convolutional neural networks (CNNs) and Transformers have emerged and achieved impressive performance. However, most of these methods, which embed modular Transformers into CNNs, struggle to realize their full potential. In this paper, we propose a new hybrid architecture for medical image segmentation, PHTrans, which hybridizes Transformer and CNN in parallel within its main building blocks to produce hierarchical representations from global and local features and adaptively aggregate them, aiming to fully exploit the strengths of both for better segmentation performance. Specifically, PHTrans follows a U-shaped encoder-decoder design and introduces a parallel hybrid module in the deep stages, where convolution blocks and an improved 3D Swin Transformer learn local features and global dependencies, respectively; a **sequence-to-volume (S2V)** operation then unifies the dimensions of their outputs for feature aggregation. Extensive experimental results on the Beyond the Cranial Vault (BCV) multi-organ segmentation and Automated Cardiac Diagnosis Challenge (ACDC) datasets confirm its effectiveness, consistently outperforming state-of-the-art methods.

1 Introduction

Medical image segmentation aims to extract and quantify regions of interest in images of biological tissues/organs, which is critical for disease diagnosis, preoperative planning, and intervention. Thanks to the excellent representation learning ability of deep learning, convolutional neural networks have achieved great success in medical image analysis. Many strong network models (such as U-Net, 3D U-Net, and Attention U-Net) have appeared, continually raising the performance ceiling on various segmentation tasks. Despite achieving highly competitive results, CNN-based methods lack the ability to model long-range dependencies due to inherent **inductive biases such as locality and translation invariance**. Some researchers enlarge the convolution kernel, use dilated convolutions (Yu, F., Koltun, V.: Multi-scale context aggregation by dilated convolutions. arXiv preprint arXiv:1511.07122 (2015)), or embed self-attention mechanisms to alleviate this problem. However, as long as the convolution operation remains the core of the network architecture, the lack of global information cannot be fundamentally resolved.

[Figure 1]

Transformers, which rely entirely on attention mechanisms to model global dependencies without any convolution operations, have emerged as an alternative architecture in computer vision (CV); with pre-training on large-scale datasets, they can outperform CNNs. Among them, the Vision Transformer (ViT) splits an image into a sequence of tokens and models their global relationships with stacked Transformer blocks, which has had a revolutionary impact on the CV field. The Swin Transformer generates hierarchical feature representations with low computational complexity using shifted windows, achieving state-of-the-art performance on various CV tasks. However, medical image datasets are much smaller than the pre-training datasets used in the above works (such as ImageNet-21k and JFT-300M), because medical images are not always available and require expert annotation. As a result, the performance of pure Transformers on medical image segmentation is not ideal. Meanwhile, many hybrid structures combining CNN and Transformer have emerged. They offer a compromise that combines the advantages of both without pre-training on large datasets, and have gradually become a popular choice for medical image segmentation.

This paper summarizes several popular hybrid architectures based on Transformer and CNN in medical image segmentation. These hybrid architectures add a Transformer to a model with a CNN as the backbone, or replace some components of the architecture. For example:

  1. UNETR and Swin UNETR adopt an encoder-decoder structure, where the encoder consists of stacked blocks built from self-attention and multi-layer perceptrons (i.e., a Transformer), while the decoder is composed of stacked convolutional layers, as shown in Figure 1(a).
  2. TransBTS and TransUNet insert a Transformer between the encoder and decoder of a CNN-based network, as shown in Figure 1(b).
  3. MISSFormer and CoTr bridge all stages from encoder to decoder through a Transformer instead of only connecting adjacent stages, which captures multi-scale global dependencies, as shown in Figure 1(c).
  4. nnFormer interweaves Transformer and convolutional blocks into a hybrid model, where convolutions encode precise spatial information and self-attention captures global context, as shown in Figure 1(d).

As can be seen from Figure 1, these architectures realize the serial combination of Transformer and CNN from a macro perspective. However, in a serial combination, convolutions and self-attention cannot permeate the entire network architecture, making it difficult to continuously model local and global representations, thus failing to fully exploit their potential.

In this paper, we propose a Parallel Hybrid Transformer (PHTrans) for medical image segmentation, whose main building blocks consist of CNN and Swin Transformer to simultaneously aggregate global and local representations, see Fig. 1(e). In PHTrans, we extend the standard Swin Transformer to a 3D version by extracting 3D patches from partitioned volumes and building a 3D self-attention mechanism. Given that the hierarchical nature of the Swin Transformer makes it easy to adopt advanced dense-prediction designs such as U-Net, we follow the successful U-shaped architecture and introduce sequence-to-volume conversion operations to realize the parallel combination of Swin Transformer and CNN within a block. Compared with serial hybrid architectures, PHTrans can independently construct hierarchical local and global representations and fuse them at every stage, fully exploiting the potential of CNN and Transformer. Extensive experiments demonstrate that the proposed method outperforms other competing methods on various medical image segmentation tasks.

2 Method

2.1 Overall Architecture

An overview of the PHTrans architecture is shown in Fig. 2(a). PHTrans adopts a U-shaped encoder-decoder design, which mainly consists of pure convolution modules and parallel hybrid modules. Our original intention was to construct a fully hybrid architecture composed of Transformer and CNN, but due to the high computational complexity of the self-attention mechanism, the Transformer cannot directly take raw pixels as tokens. In our implementation, cascaded convolution blocks and downsampling operations are introduced to reduce the spatial size, gradually extracting high-resolution low-level features with fine spatial information. Likewise, these pure convolution modules are deployed in the decoder to recover the resolution of the original image via upsampling.

[Figure 2]

Given an input volume $x \in \mathbb{R}^{H \times W \times D}$, where $H$, $W$, and $D$ denote height, width, and depth, respectively, we first obtain a feature map $f \in \mathbb{R}^{\frac{H}{2^{N_1}} \times \frac{W}{2^{N_1}} \times \frac{D}{2^{N_1}} \times 2^{N_1}C}$ using several pure convolution modules,
where N1 and C denote the number of modules and the number of base channels, respectively. Then, parallel hybrid modules composed of Transformer and CNN are used to model hierarchical representations of local and global features, with $\frac{H}{2^{N_1+N_2}} \times \frac{W}{2^{N_1+N_2}} \times \frac{D}{2^{N_1+N_2}}$ as the output resolution and $2^{N_1+N_2}C$ as the number of channels; this computation is repeated N2 times. The symmetric decoder corresponding to the encoder is likewise built from pure convolution modules and parallel hybrid modules, and fuses the semantic information from the encoder through skip connections and addition operations. Furthermore, we apply deep supervision at each stage of the decoder during training, resulting in a total of N1 + N2 outputs, each supervised by a joint loss of cross-entropy and Dice loss. The architecture of PHTrans is simple and flexible: the number of each module, namely N1, N2, M1, and M2, can be adjusted for a given medical image segmentation task, where M1 and M2 are the numbers of Swin Transformer blocks and convolution blocks in the parallel hybrid module.
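
Below is a minimal sketch (not the authors' code) of the deep-supervision objective described above: a joint cross-entropy + Dice loss applied to every decoder output. The `ds_weights` halving scheme and the nearest-neighbour label downsampling are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_dice_loss(logits, target, eps=1e-5):
    """Soft Dice loss for one prediction head; `target` holds class indices."""
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)  # sum over batch and the three spatial dims
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    return 1.0 - ((2.0 * intersection + eps) / (denominator + eps)).mean()


def deep_supervision_loss(outputs, target, ds_weights=None):
    """`outputs`: list of logits, one per decoder stage, at decreasing resolution."""
    if ds_weights is None:
        # Assumed weighting: the full-resolution head counts most.
        ds_weights = [0.5 ** i for i in range(len(outputs))]
    total = 0.0
    for w, logits in zip(ds_weights, outputs):
        # Downsample the label map to this head's resolution (assumption).
        tgt = F.interpolate(target.unsqueeze(1).float(), size=logits.shape[2:],
                            mode="nearest").squeeze(1).long()
        total += w * (F.cross_entropy(logits, tgt) + soft_dice_loss(logits, tgt))
    return total / sum(ds_weights)
```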

2.2 Parallel Hybrid Module

The parallel hybrid module is deployed in the deeper stages of PHTrans, with the Trans&Conv block as its core, enabling hierarchical aggregation of local and global representations via CNN and Swin Transformer.

2.2.1 Trans&Conv block.

The scaled-down feature maps are fed into the Swin Transformer (ST) block and the convolution (Conv) block in parallel. We introduce **volume-to-sequence (V2S)** and **sequence-to-volume (S2V)** operations at the beginning and end of the ST block to convert between volume and sequence representations, so that its output lies in the same dimensional space as the output of the Conv block. Specifically, V2S reshapes the entire volume (3D image) into a sequence of 3D patches with a given window size, and S2V is the opposite operation.
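
The following is a minimal sketch (an assumption about the implementation, not the official code) of what V2S and S2V can look like: V2S cuts a (B, C, D, H, W) volume into non-overlapping 3D windows and flattens each window into a token sequence, and S2V inverts the reshaping. The window dimensions are assumed to divide the volume dimensions exactly (no padding handled here).

```python
import torch


def volume_to_sequence(x, window_size):
    """V2S: (B, C, D, H, W) -> (B * n_windows, Wd*Wh*Ww, C)."""
    b, c, d, h, w = x.shape
    wd, wh, ww = window_size
    x = x.view(b, c, d // wd, wd, h // wh, wh, w // ww, ww)
    # Group the window grid into the batch dim, keep channels last for attention.
    x = x.permute(0, 2, 4, 6, 3, 5, 7, 1).contiguous()
    return x.view(-1, wd * wh * ww, c)


def sequence_to_volume(seq, window_size, volume_size, channels):
    """S2V: (B * n_windows, Wd*Wh*Ww, C) -> (B, C, D, H, W)."""
    wd, wh, ww = window_size
    d, h, w = volume_size
    b = seq.shape[0] // ((d // wd) * (h // wh) * (w // ww))
    x = seq.view(b, d // wd, h // wh, w // ww, wd, wh, ww, channels)
    x = x.permute(0, 7, 1, 4, 2, 5, 3, 6).contiguous()
    return x.view(b, channels, d, h, w)
```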


As shown in Fig. 2(b), an ST block consists of a shifted-window-based multi-head self-attention (MSA) module followed by a 2-layer MLP with a GELU activation in between. A LayerNorm (LN) layer is applied before each MSA module and each MLP, and a residual connection is added after each module. In M1 consecutive ST blocks, MSAs with regular and shifted window configurations, i.e., W-MSA and SW-MSA, are embedded alternately to achieve cross-window connections while retaining the efficient computation of non-overlapping windows.
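
A schematic sketch of one ST block as described above (pre-LN, window MSA, residual, pre-LN, 2-layer MLP with GELU, residual). This is an illustration under simplifying assumptions: standard multi-head attention stands in for the windowed attention, and the cyclic-shift masking of SW-MSA is omitted.

```python
import torch.nn as nn


class SwinTransformerBlock3D(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, shifted=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # Placeholder: plain multi-head attention applied to each window sequence;
        # a faithful version would add the shifted-window mask and the relative
        # position bias (see the attention sketch further below).
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(),
                                 nn.Linear(hidden, dim))
        self.shifted = shifted  # W-MSA when False, SW-MSA when True (masking omitted)

    def forward(self, x):  # x: (B * n_windows, L, C), produced by V2S
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x
```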

For medical image segmentation, we modify the standard ST block into a 3D version, computing self-attention within local 3D windows that partition the volume uniformly in a non-overlapping manner. Assuming $x \in \mathbb{R}^{H \times W \times S \times C}$ is the input of the ST block, it is first reshaped to $N \times L \times C$, where $N$ and $L = W_h \times W_w \times W_s$ denote the number and the dimension of the 3D windows, respectively. The self-attention in each head is computed as:

$$\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\left(\frac{QK^T}{\sqrt{d}} + B\right)V,$$

where $Q, K, V \in \mathbb{R}^{L \times d}$ are the query, key, and value matrices, $d$ is the dimension of the query/key, and $B \in \mathbb{R}^{L \times L}$ is the relative position bias. We parameterize a smaller bias matrix $\hat{B} \in \mathbb{R}^{(2W_h-1) \times (2W_w-1) \times (2W_s-1)}$, and the values of $B$ are taken from $\hat{B}$.
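
Below is a hedged sketch of self-attention inside one 3D window with the relative position bias indexed from the smaller table $\hat{B}$, mirroring the formula above. The index construction follows the 2D Swin Transformer extended to three axes; it is an illustration, not the authors' exact code.

```python
import torch
import torch.nn as nn


class WindowAttention3D(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        wh, ww, ws = window_size
        # Learnable bias table B_hat, one entry per head.
        self.bias_table = nn.Parameter(
            torch.zeros((2 * wh - 1) * (2 * ww - 1) * (2 * ws - 1), num_heads))
        # For every pair of positions in the window, precompute which table
        # entry supplies its relative position bias.
        coords = torch.stack(torch.meshgrid(
            torch.arange(wh), torch.arange(ww), torch.arange(ws), indexing="ij"))
        coords = coords.flatten(1)                        # (3, L)
        rel = (coords[:, :, None] - coords[:, None, :])   # (3, L, L)
        rel = rel.permute(1, 2, 0).contiguous()           # (L, L, 3)
        rel[:, :, 0] += wh - 1
        rel[:, :, 1] += ww - 1
        rel[:, :, 2] += ws - 1
        index = (rel[:, :, 0] * (2 * ww - 1) * (2 * ws - 1)
                 + rel[:, :, 1] * (2 * ws - 1) + rel[:, :, 2])
        self.register_buffer("bias_index", index)         # (L, L)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                                  # x: (B_windows, L, C)
        b, l, c = x.shape
        qkv = self.qkv(x).reshape(b, l, 3, self.num_heads, c // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)               # each (B, heads, L, d)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        bias = self.bias_table[self.bias_index.view(-1)].view(l, l, -1)
        attn = attn + bias.permute(2, 0, 1).unsqueeze(0)   # add B per head
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, l, c)
        return self.proj(out)
```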

The convolution block, whose basic unit is a 3 × 3 × 3 convolutional layer followed by GELU nonlinearity and instance normalization, is repeated M2 times. The configuration of the convolution blocks is simple and flexible, so any off-the-shelf convolutional network can be applied. Finally, we fuse the outputs of the ST block and the Conv block by an addition operation. The computation of the Trans&Conv block in the encoder can be summarized as:

$$y_i = S2V(ST^{M_1}(V2S(x_{i-1}))) + \mathrm{Conv}^{M_2}(x_{i-1}),$$
where $x_{i-1}$ is the downsampling result of the (i−1)-th stage of the encoder. In the decoder, in addition to skip connections, we supplement the contextual information from the encoder with an addition operation. Therefore, the Trans&Conv block in the decoder can be expressed as:

$$z_i = S2V(ST^{M_1}(V2S(x_{i+1} + y_i))) + \mathrm{Conv}^{M_2}([x_{i+1}, y_i]),$$

where $x_{i+1}$ is the upsampling result of decoder stage i+1, and $y_i$ is the output of encoder stage i.
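
The two equations can be sketched as follows, reusing the V2S/S2V helpers and ST blocks sketched earlier (an illustration under assumptions, not the authors' code). In the encoder the two branch outputs are simply added; in the decoder the upsampled feature and the encoder skip are added before the ST branch and concatenated before the Conv branch, so the first decoder conv block is assumed to map the concatenated 2C channels back to C.

```python
import torch
import torch.nn as nn


class TransConvEncoderBlock(nn.Module):
    """y_i = S2V(ST^{M1}(V2S(x_{i-1}))) + Conv^{M2}(x_{i-1})"""

    def __init__(self, st_blocks, conv_blocks, window_size):
        super().__init__()
        self.st_blocks = st_blocks        # nn.Sequential of M1 ST blocks
        self.conv_blocks = conv_blocks    # nn.Sequential of M2 conv blocks
        self.window_size = window_size

    def forward(self, x):                 # x: (B, C, D, H, W)
        b, c, d, h, w = x.shape
        seq = self.st_blocks(volume_to_sequence(x, self.window_size))
        glb = sequence_to_volume(seq, self.window_size, (d, h, w), c)
        loc = self.conv_blocks(x)
        return glb + loc


class TransConvDecoderBlock(nn.Module):
    """z_i = S2V(ST^{M1}(V2S(x_{i+1} + y_i))) + Conv^{M2}([x_{i+1}, y_i])"""

    def __init__(self, st_blocks, conv_blocks, window_size):
        super().__init__()
        self.st_blocks = st_blocks
        self.conv_blocks = conv_blocks    # assumed to accept 2C input channels
        self.window_size = window_size

    def forward(self, x_up, y_skip):      # upsampled decoder feature, encoder skip
        b, c, d, h, w = x_up.shape
        seq = self.st_blocks(volume_to_sequence(x_up + y_skip, self.window_size))
        glb = sequence_to_volume(seq, self.window_size, (d, h, w), c)
        loc = self.conv_blocks(torch.cat([x_up, y_skip], dim=1))
        return glb + loc
```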

2.2.2 Downsampling and upsampling.

Downsampling consists of a strided convolution operation and an instance normalization layer, where the number of channels is doubled and the spatial size is halved. Similarly, upsampling is a strided deconvolution layer followed by an instance normalization layer, which halves the number of feature-map channels and doubles the spatial size. The stride is usually set to 2 in all dimensions; however, when the 3D medical image is anisotropic, the stride is set to 1 along the corresponding dimension.
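
A minimal sketch of such strided (de)convolution resampling with per-axis strides, so that an anisotropic axis (stride 1) keeps its resolution. Setting the kernel size equal to the stride is an assumption made here for simplicity.

```python
import torch.nn as nn


def downsample(in_channels, stride=(2, 2, 2)):
    # Doubles the channels; halves every spatial dim whose stride is 2.
    return nn.Sequential(
        nn.Conv3d(in_channels, in_channels * 2, kernel_size=stride, stride=stride),
        nn.InstanceNorm3d(in_channels * 2),
    )


def upsample(in_channels, stride=(2, 2, 2)):
    # Halves the channels; doubles every spatial dim whose stride is 2.
    return nn.Sequential(
        nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=stride, stride=stride),
        nn.InstanceNorm3d(in_channels // 2),
    )


# Example: for volumes with thick slices along depth, an early stage might
# keep the depth axis untouched: downsample(24, stride=(1, 2, 2)).
```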

3 Experiments

3.1 Dataset

The Multi-Atlas Labeling Beyond the Cranial Vault (BCV) multi-organ segmentation task includes 30 cases with 3779 axial abdominal clinical CT images. Following nnFormer, the dataset is split into 18 training samples and 12 test samples. Using the average Dice similarity coefficient (DSC) and the average Hausdorff distance (HD) as evaluation metrics, our method is evaluated on 8 abdominal organs (aorta, gallbladder, spleen, left kidney, right kidney, liver, pancreas, stomach).

The Automated Cardiac Diagnosis Challenge (ACDC) dataset was collected from different patients using an MRI scanner. For each patient's MR images, the left ventricle (LV), right ventricle (RV), and myocardium (MYO) were labeled. Following nnFormer, 70 samples are used for training, 10 for validation, and 20 for testing. The average DSC is used to evaluate our method on this dataset.

3.2 Implementation Details

For a fair comparison, we evaluate PHTrans within the nnU-Net code framework, consistent with CoTr and nnFormer.

All experiments are performed under the default configuration of nnU-Net. In PHTrans, we empirically set the hyperparameters [N1, N2, M1, M2] to [2, 4, 2, 2], and adopt nnU-Net's stride strategy for downsampling and upsampling rather than a carefully designed one. The base number of channels C is 24, and the numbers of heads of multi-head self-attention used in the different encoder stages are [3, 6, 12, 24]. For the BCV and ACDC datasets, we set the size of the 3D window [Wh, Ww, Ws] in the ST block to [3, 6, 6] and [2, 8, 7], respectively.

During the training phase, we randomly crop sub-volumes of size 48×192×192 and 16×256×224 from the BCV and ACDC datasets respectively as input. We implemented PHTrans under PyTorch 1.9 and performed experiments with a single GeForce RTX 3090 GPU.
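
For reference, the hyperparameters reported above can be collected into a configuration dict; the dict layout itself is an assumption for illustration, while the values are the ones stated in the text.

```python
phtrans_config = {
    "base_channels": 24,              # C
    "num_pure_conv_stages": 2,        # N1
    "num_hybrid_stages": 4,           # N2
    "num_st_blocks": 2,               # M1 per parallel hybrid module
    "num_conv_blocks": 2,             # M2 per parallel hybrid module
    "num_heads": [3, 6, 12, 24],      # self-attention heads per hybrid encoder stage
    "window_size": {                  # [Wh, Ww, Ws] of the 3D window in ST blocks
        "BCV": (3, 6, 6),
        "ACDC": (2, 8, 7),
    },
    "patch_size": {                   # randomly cropped training sub-volumes
        "BCV": (48, 192, 192),
        "ACDC": (16, 256, 224),
    },
}
```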

3.3 Results

3.3.1 Comparison with SOTA

We compare the performance of PHTrans with previous state-of-the-art methods. In addition to the hybrid architectures mentioned in the introduction, the comparison also includes LeViT-UNet, Swin-Unet, and nnU-Net. Furthermore, we reproduce Swin UNETR by replacing the ViT in the UNETR encoder with a Swin Transformer, and evaluate UNETR and Swin UNETR in the same way, using the same dataset partitions as ours.

The segmentation results on the BCV dataset are shown in Table 1.

[Table 1]

Our PHTrans achieves the best performance, with 88.55% average DSC (↑) and 8.68 HD (↓), outperforming the previous best model by 0.8% in average DSC and 1.15 in HD.

Representative samples in Figure 3 demonstrate the success of PHTrans in capturing organ details, for example, the stomach in rows 1 and 2 and the left kidney in row 2. The segmentation results on the ACDC dataset are shown in Table 2.

[Table 2]

On ACDC as well, PHTrans clearly achieves the highest average DSC among the state-of-the-art methods. It is worth mentioning that Swin-Unet, TransUNet, LeViT-UNet, and nnFormer use weights pre-trained on ImageNet to initialize their networks, while PHTrans is trained from scratch on both datasets.

In addition, we compare the number of parameters and FLOPs of the 3D methods nnFormer, CoTr, nnU-Net, UNETR, Swin UNETR, and PHTrans in the BCV experiments to evaluate model complexity. As shown in Table 3, PHTrans has fewer parameters (36.3M), and its FLOPs (187.4G) are significantly lower than those of CoTr, nnU-Net, and Swin UNETR. In summary, the results of PHTrans on the BCV and ACDC datasets fully demonstrate its excellent medical image segmentation and generalization capabilities while maintaining moderate model complexity.

The qualitative segmentation results are shown below.

[Figure 3]

3.3.2 Ablation experiment

Starting from the improved 3D Swin-Unet, the components of PHTrans are gradually integrated to explore the impact of each component on model performance. Table 4 provides the quantitative results of the ablation studies. "+PCM" denotes using stacked pure convolution modules instead of strided convolution operations for patch partitioning, while "w/o PCM" denotes the opposite. "w/o ST" denotes removing the Swin Transformer blocks from the parallel hybrid modules of PHTrans, resulting in an architecture similar to nnU-Net.

[Table 4]

From these results, it can be seen that PCM improves the performance of both 3D Swin-Unet and PHTrans, thanks to its ability to capture fine-grained details in the first few stages. In addition, compared with the single-architecture variants, PHTrans brings a more significant performance improvement: its average DSC is 3.6% and 0.84% higher than that of "3D Swin-Unet+PCM" and "PHTrans w/o ST", respectively, while its HD is lower by 10.84 and 5.69. The results show that aggregating global and local representations with a parallel combination strategy of CNN and Transformer is effective.

3.3.3 Discussion

In PHTrans, a vanilla Swin Transformer and simple convolution blocks are applied, which indicates that the significant performance improvement over state-of-the-art methods comes from the parallel hybrid architecture design rather than elaborate Transformer or CNN blocks. Furthermore, PHTrans is not pre-trained, as there is so far no sufficiently large general dataset of 3D medical images. Based on these considerations, in future work we will design the Transformer and CNN blocks more carefully and explore end-to-end pre-training of Transformers to further improve segmentation performance.

4 Conclusion

In this paper, we propose a parallel hybrid architecture (PHTrans) based on Swin Transformer and CNN for accurate medical image segmentation. Unlike other hybrid architectures that embed modular Transformers into CNNs, PHTrans builds hybrid modules consisting of Swin Transformer and CNN throughout the model, continuously aggregating hierarchical representations from global and local features and fully exploiting the strengths of both.

Extensive experiments on the BCV and ACDC datasets demonstrate that our method outperforms several state-of-the-art alternatives. As a general architecture, PHTrans is very flexible: its components can be replaced with off-the-shelf convolution and Transformer modules, which opens up new possibilities for more downstream medical image tasks.
