Welcome to the official WeChat account of "CVHub"!
Title: STU-Net: Scalable and Transferable Medical Image Segmentation Models Empowered by Large-Scale Supervised Pre-training
Paper: https://arxiv.org/pdf/2304.06716.pdf
Code: https://github.com/Ziyan-Huang/STU-Net
Introduction
Deep learning models for medical image segmentation fall into two families: CNN-based and Transformer-based. U-Net is the pioneering CNN model, and subsequent work has built on it with residual connections, attention mechanisms, and different feature aggregation strategies. More recently, vision Transformers have been introduced into medical image segmentation. For example, UNETR and SwinUNETR use a Transformer and a Swin Transformer, respectively, as the encoder to extract features.
However, these existing models cannot adapt to different computing budgets, and they cannot handle diverse medical image segmentation tasks. In addition, large deep learning models have performed very well in many application domains. In medical image segmentation, however, the state-of-the-art models are still small, with only tens of millions of parameters.
Therefore, this paper proposes STU-Net, a scalable and transferable model, and explores the possibility of training large-scale deep learning models on large-scale medical image segmentation datasets.
Background
Medical image segmentation is an important intermediate step for automatically labeling anatomical structures and lesions in medical images, and it underpins many downstream clinical tasks. In recent years, many specific medical image segmentation tasks have been studied extensively, and many deep learning based models have achieved great success.
However, these models usually need careful tuning for each task, which greatly limits their transferability. There is therefore a need for a single model that can handle a wide range of medical segmentation tasks. Such a model should cover different input modalities (e.g., CT, MRI, and PET) and different segmentation targets, such as organs and tumors. The key to this problem is to pre-train large models on large-scale datasets so that they become transferable. On the data side, several large-scale public medical image segmentation datasets have recently emerged.
In addition, large models usually incur a higher computational cost, and the problem gets worse when training on high-resolution 3D medical images. The paper therefore aims for a large model that can be scaled to different sizes to fit different computing budgets.
To achieve this goal, this paper proposes STU-Net, a family of scalable and transferable U-Net models with parameter counts ranging from 14 million to 1.4 billion. Furthermore, the models are pre-trained on a large-scale dataset with supervised learning to ensure their transferability.
This paper builds these models on the nnU-Net framework because it offers state-of-the-art baseline performance and is widely used by researchers. Developing large models with this framework faces two obstacles:
- Its basic convolutional block may not be suitable for scaling.
- nnU-Net architectures cannot easily be fine-tuned across tasks, because the architecture is treated as a set of hyperparameters and is therefore task-specific.
To address these issues, this paper improves and extends nnU-Net and proposes STU-Net, a new scalable and transferable large-scale medical image segmentation model. Experiments verify the effectiveness of STU-Net, which also exhibits excellent transfer performance on multiple downstream datasets.
Method
STU-Net is built on top of the nnU-Net framework, which automatically configures task-specific hyperparameters and achieves state-of-the-art performance on a variety of tasks.
nnU-Net Architecture
nnU-Net employs a symmetric encoder-decoder architecture with skip connections and multiple resolution stages. Each stage consists of two convolutional layers, and each layer is followed by Instance Normalization and Leaky ReLU (Conv-IN-LeakyReLU).
Because the architecture contains no residual connections, stacking more layers in each stage can cause gradient diffusion, which makes the whole model hard to optimize. This limits the depth of nnU-Net and, in turn, its scalability.
On the other hand, nnU-Net first determines a dataset-specific input patch size. It then uses this patch size and the voxel spacing to set hyperparameters related to the network architecture, such as the number of resolution stages, the convolution kernel sizes, and the downsampling/upsampling ratios. These architecture-related hyperparameters therefore vary between tasks, and different tasks end up with different network architectures. As a result, a model trained on one task cannot be transferred directly to another, which limits the evaluation of the model's transferability.
Improvements based on nnU-Net
The task-specific hyperparameters of nnU-Net fall into two groups. Some are tied to the model weights, such as convolution kernel sizes and the number of resolution stages. Others are independent of the model weights, such as pooling kernel sizes, input patch size, and spacing.
To make the architecture transferable to other tasks, this paper fixes the weight-related hyperparameters: every task uses 6 resolution stages, and all convolutional layers use isotropic 3×3×3 kernels. For the weight-independent hyperparameters, the default nnU-Net settings are kept to retain state-of-the-art performance on various tasks. This paper also compares this setup with nnU-Net and 3D U-Net.
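The split between weight-related and weight-independent hyperparameters can be checked by comparing weight-tensor shapes, which is the same condition a strict weight-loading step checks. The widths, strides, and anisotropic kernel below are illustrative assumptions, not values from the paper.

```python
# Hypothetical per-stage widths; real STU-Net widths depend on the model scale.
WIDTHS = (32, 64, 128, 256, 512, 512)


def encoder_weight_shapes(config, in_channels=1, widths=WIDTHS):
    """Weight shape (C_out, C_in, kD, kH, kW) of the first conv in each stage."""
    # config["strides"] is deliberately unused: a stride changes the output
    # size of a convolution, not the shape of its weight tensor.
    shapes, c_in = [], in_channels
    for c_out, k in zip(widths, config["kernels"]):
        shapes.append((c_out, c_in, *k))
        c_in = c_out
    return shapes


def transferable(src, dst):
    """Weights can be copied one-to-one only if every tensor shape matches."""
    return encoder_weight_shapes(src) == encoder_weight_shapes(dst)


FIXED_KERNELS = [(3, 3, 3)] * 6  # 6 stages, isotropic 3x3x3 kernels for every task

# Two hypothetical tasks whose (weight-independent) downsampling strides differ.
task_a = {"strides": [(1, 1, 1)] + [(2, 2, 2)] * 5, "kernels": FIXED_KERNELS}
task_b = {"strides": [(1, 1, 1), (1, 2, 2)] + [(2, 2, 2)] * 3 + [(1, 2, 2)],
          "kernels": FIXED_KERNELS}
# A hypothetical task-specific plan with an anisotropic first-stage kernel.
task_c = {"strides": task_b["strides"], "kernels": [(1, 3, 3)] + [(3, 3, 3)] * 5}

if __name__ == "__main__":
    print(transferable(task_a, task_b))  # True: strides differ, weight shapes do not
    print(transferable(task_a, task_c))  # False: the first conv weight changes shape
```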
Basic block
Each stage of nnU-Net consists of a basic block, and each basic block consists of two Conv-Instance Normalization-LeakyReLU layers. However, when the number of basic blocks per stage increases, gradient diffusion causes optimization problems.
To solve this problem, residual connections are introduced into the basic blocks. To make the whole architecture more compact, downsampling is also integrated into the first residual block of each stage. This downsampling block has a residual structure similar to a conventional residual block, with a left and a right branch. The left branch has two 3×3×3 convolutions, while the right branch uses a 1×1×1 convolution with a stride of 2. This improved basic block makes the whole architecture more compact and also solves the gradient diffusion problem.
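The residual downsampling block described above can be sketched in NumPy on a single (C, D, H, W) tensor. This is a minimal, hypothetical sketch, not STU-Net's code. The naive `conv3d` stands in for a framework convolution. The first left-branch convolution is assumed to carry the stride of 2, because the two branches must end at the same resolution before they can be added. The Instance Normalization on the shortcut, the LeakyReLU slope of 0.01, and the absence of biases are also assumptions.

```python
import numpy as np


def conv3d(x, w, stride=1):
    """Naive zero-padded 3D convolution (cross-correlation), for illustration only.

    x: (C_in, D, H, W); w: (C_out, C_in, k, k, k) with odd k.
    """
    c_out, _, k, _, _ = w.shape
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p), (p, p)))
    d, h, wd = x.shape[1:]
    out = np.zeros((c_out, (d - 1) // stride + 1, (h - 1) // stride + 1, (wd - 1) // stride + 1))
    for a in range(k):
        for b in range(k):
            for c in range(k):
                # input voxels seen by kernel tap (a, b, c) for every output position
                patch = xp[:, a:a + d:stride, b:b + h:stride, c:c + wd:stride]
                out += np.einsum("oi,idhw->odhw", w[:, :, a, b, c], patch)
    return out


def instance_norm(x, eps=1e-5):
    # normalize each channel over its own spatial volume (no learned affine here)
    mean = x.mean(axis=(1, 2, 3), keepdims=True)
    var = x.var(axis=(1, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def leaky_relu(x, slope=0.01):
    return np.where(x > 0, x, slope * x)


def residual_down_block(x, w1, w2, w_skip):
    """Halve the resolution and change the channel count, with a residual shortcut."""
    # left branch: 3x3x3 conv (stride 2) -> IN -> LeakyReLU -> 3x3x3 conv -> IN
    left = leaky_relu(instance_norm(conv3d(x, w1, stride=2)))
    left = instance_norm(conv3d(left, w2))
    # right branch: 1x1x1 conv with stride 2 so the shortcut matches the left branch
    right = instance_norm(conv3d(x, w_skip, stride=2))
    return leaky_relu(left + right)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 16, 16, 16))  # (C_in, D, H, W)
    w1 = rng.standard_normal((16, 8, 3, 3, 3)) * 0.1
    w2 = rng.standard_normal((16, 16, 3, 3, 3)) * 0.1
    w_skip = rng.standard_normal((16, 8, 1, 1, 1)) * 0.1
    print(residual_down_block(x, w1, w2, w_skip).shape)  # (16, 8, 8, 8)
```

The rest of each stage would follow with stride-1 residual blocks that use an identity shortcut.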
Upsampling
By default, nnU-Net performs upsampling with transposed convolutions. However, the kernel size and stride of the same resolution stage may differ between tasks. This changes the weight shape of the transposed convolution and causes a weight mismatch when weights are transferred between tasks.
To solve this problem, transposed convolution is replaced by interpolation followed by a 1×1×1 convolution. Interpolation has no weights, so this avoids the weight-shape mismatch. Nearest neighbor interpolation is used for upsampling. Experiments show that it is faster than cubic linear interpolation and achieves comparable performance.
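A minimal NumPy sketch of the replacement upsampling path (nearest neighbor interpolation followed by a 1×1×1 convolution) shows why it transfers across tasks: the only weight is a (C_out, C_in) matrix, whatever the upsampling factor. The shapes and factors are illustrative assumptions. In a real model this would be a framework call such as PyTorch's `torch.nn.functional.interpolate` with `mode="nearest"`, followed by a `Conv3d` with `kernel_size=1`.

```python
import numpy as np


def upsample_nearest(x, scale):
    """Nearest-neighbor upsampling of x with shape (C, D, H, W).

    scale holds integer factors per spatial axis, e.g. (2, 2, 2) or (1, 2, 2).
    There are no weights, so nothing here depends on the task.
    """
    for axis, s in zip((1, 2, 3), scale):
        x = np.repeat(x, s, axis=axis)
    return x


def pointwise_conv(x, w):
    """1x1x1 convolution: a per-voxel linear map over channels, w: (C_out, C_in)."""
    return np.einsum("oi,idhw->odhw", w, x)


def up_block(x, w, scale):
    # interpolate first, then let the 1x1x1 conv change the channel count
    return pointwise_conv(upsample_nearest(x, scale), w)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((64, 4, 4, 4))
    w = rng.standard_normal((32, 64))  # one weight tensor, reused for both factors
    print(up_block(x, w, (2, 2, 2)).shape)  # (32, 8, 8, 8)
    print(up_block(x, w, (1, 2, 2)).shape)  # (32, 4, 8, 8)
```

For comparison, a PyTorch `ConvTranspose3d` stores its weight as (C_in, C_out, kD, kH, kW). Assuming its kernel equals its stride, the two calls above would need weights of shape (64, 32, 2, 2, 2) and (64, 32, 1, 2, 2), and these cannot be exchanged.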
Scaling strategy
Deeper networks usually have larger receptive fields and better representation capability, while wider networks tend to extract richer multi-scale features at each layer. According to the findings of EfficientNet, depth scaling and width scaling are not independent. For better accuracy and efficiency, it is better to scale network depth and width jointly (compound scaling).
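A compound-scaling rule of this kind can be sketched as one depth multiplier and one width multiplier applied uniformly to every resolution stage. Several parts of the sketch are hypothetical placeholders rather than the values in the paper's Table 2: the base widths, the base depth of one residual block per stage, and the S/B/L/H multipliers. The parameter count only covers the encoder's 3×3×3 convolutions.

```python
BASE_WIDTHS = (32, 64, 128, 256, 512, 512)  # hypothetical base widths, 6 stages


def stage_config(depth_mult, width_mult, base_widths=BASE_WIDTHS):
    """Apply the same depth and width multipliers to every resolution stage.

    Assumes a base depth of one residual block per stage; the decoder mirrors
    the encoder, so only the encoder is described here.
    """
    depths = [depth_mult] * len(base_widths)
    widths = [int(w * width_mult) for w in base_widths]
    return depths, widths


def approx_conv_params(depths, widths, in_channels=1, k=3):
    """Rough count of 3x3x3 conv weights in the encoder (2 convs per block, no biases)."""
    total, c_in = 0, in_channels
    for d, w in zip(depths, widths):
        for _ in range(d):
            total += k ** 3 * (c_in * w + w * w)  # conv1: c_in -> w, conv2: w -> w
            c_in = w
    return total


# Hypothetical (depth, width) multipliers for the four sizes.
VARIANTS = {"S": (1, 0.5), "B": (1, 1), "L": (2, 2), "H": (3, 3)}

if __name__ == "__main__":
    base = approx_conv_params(*stage_config(*VARIANTS["B"]))
    for name, (d, w) in VARIANTS.items():
        n = approx_conv_params(*stage_config(d, w))
        print(f"STU-Net-{name}: ~{n / 1e6:.1f}M encoder conv weights ({n / base:.1f}x B)")
```

Convolution weights grow linearly with depth and quadratically with width. Doubling both multipliers should therefore multiply the count by roughly 2 × 2² = 8. With these numbers the factor is slightly larger (about 8.7), because the added blocks are full w→w convolutions.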
To simplify the scaling problem, this paper adopts a symmetric structure. The encoder and decoder are scaled simultaneously, and depth and width are scaled by the same ratio in every resolution stage. Table 2 shows the different scales of STU-Net, where the suffixes "S", "B", "L", and "H" stand for "Small", "Base", "Large", and "Huge", respectively.
Supervised pre-training at scale
STU-Net is pre-trained on the TotalSegmentator dataset. Its final 1×1×1 convolutional layer has 105 output channels, which correspond to the total number of annotated target categories in TotalSegmentator (104 classes plus background).
To make the pre-trained models more general and transferable, this paper modifies the standard nnU-Net training procedure in several ways. nnU-Net trains for 1000 epochs by default, whereas this model is pre-trained for 4000 epochs. Furthermore, the authors find that mirror data augmentation improves the model's transfer performance on downstream tasks.
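Mirror augmentation amounts to flipping the image and its label map together along randomly chosen spatial axes. The sketch below is a minimal NumPy version. The per-axis probability of 0.5 and the (C, D, H, W) layout are assumptions, and nnU-Net's own augmentation pipeline is more elaborate.

```python
import numpy as np


def random_mirror(image, label, rng, p=0.5):
    """Mirror a training patch along each spatial axis with probability p.

    image: (C, D, H, W) array; label: (D, H, W) integer array.
    The same flips are applied to both so voxels and labels stay aligned.
    """
    for axis in range(3):
        if rng.random() < p:
            image = np.flip(image, axis=axis + 1)  # +1 skips the channel axis
            label = np.flip(label, axis=axis)
    return image, label


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = rng.standard_normal((1, 4, 5, 6))
    seg = (img[0] > 0).astype(np.int64)
    aug_img, aug_seg = random_mirror(img, seg, rng)
    # labels still match the (flipped) intensities voxel by voxel
    print(np.array_equal(aug_seg, (aug_img[0] > 0).astype(np.int64)))  # True
```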
Without further adjustment, the pre-trained model can perform direct inference on downstream datasets that consist of CT images and whose target classes are among the 104 upstream classes.
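For direct inference, the pre-trained model's output (background plus 104 upstream classes, hence 105 channels) must be translated into the downstream dataset's label IDs. One simple way to do this is a lookup table. The mapping below uses made-up IDs; in practice it would be built by matching class names between the two datasets.

```python
import numpy as np

NUM_UPSTREAM = 105  # 104 upstream classes + background (index 0)


def remap_labels(pred, mapping, num_upstream=NUM_UPSTREAM):
    """Convert a label map predicted with upstream class IDs to downstream IDs.

    mapping: {upstream_id: downstream_id}. Upstream classes that the
    downstream dataset does not annotate are mapped to background (0).
    """
    lut = np.zeros(num_upstream, dtype=np.int64)
    for up_id, down_id in mapping.items():
        lut[up_id] = down_id
    return lut[pred]


if __name__ == "__main__":
    # Hypothetical IDs: upstream class 5 -> downstream 1, upstream 12 -> downstream 2.
    pred = np.array([[0, 5, 12], [7, 5, 0]])
    print(remap_labels(pred, {5: 1, 12: 2}))
```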
For downstream tasks with new labels or different modalities, the pre-trained model serves as the initialization, and the segmentation output layer is randomly re-initialized to match the number of target categories. During fine-tuning, only the segmentation head is randomly initialized. The weights of all other layers are loaded from the pre-trained model and fine-tuned with a smaller learning rate (0.1 times that of the segmentation head), which leads to better results.
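The weight-transfer and learning-rate rules above can be sketched with plain dictionaries of NumPy arrays standing in for a model's state. The `seg_head.` prefix, the tensor names, and the plain SGD update are hypothetical. In PyTorch, the same idea is usually expressed with optimizer parameter groups, each carrying its own `lr`.

```python
import numpy as np

HEAD = "seg_head."  # hypothetical name prefix of the segmentation output layer


def init_from_pretrained(target, pretrained):
    """Load every pretrained tensor with a matching name and shape, except the head.

    The head keeps its random initialization because its output channels
    (number of downstream classes) differ from the 105 upstream channels.
    """
    loaded = []
    for name, w in list(target.items()):
        if name.startswith(HEAD):
            continue
        src = pretrained.get(name)
        if src is not None and src.shape == w.shape:
            target[name] = src.copy()
            loaded.append(name)
    return loaded


def lr_for(name, head_lr, backbone_factor=0.1):
    """Pretrained layers get 0.1x the learning rate of the new head."""
    return head_lr if name.startswith(HEAD) else head_lr * backbone_factor


def sgd_step(params, grads, head_lr):
    for name in params:
        params[name] = params[name] - lr_for(name, head_lr) * grads[name]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    pretrained = {
        "enc.conv1.weight": rng.standard_normal((32, 1, 3, 3, 3)),
        "seg_head.weight": rng.standard_normal((105, 32, 1, 1, 1)),
    }
    model = {
        "enc.conv1.weight": np.zeros((32, 1, 3, 3, 3)),
        "seg_head.weight": rng.standard_normal((3, 32, 1, 1, 1)) * 0.01,  # new 3-class head
    }
    print(init_from_pretrained(model, pretrained))  # ['enc.conv1.weight']
    grads = {k: np.ones_like(v) for k, v in model.items()}
    sgd_step(model, grads, head_lr=1e-2)
```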
Experiments
STU-Net-B outperforms nnU-Net, the best CNN-based model, by 0.36% in average DSC. It also outperforms SwinUNETR-B, the best Transformer-based model, by 4.48%.
Further scaling the base model up to the Large and Huge sizes improves the average DSC by 1.59% and 2.94%, respectively.
STU-Net-H achieves the highest average DSC across all categories and all five category subgroups of the TotalSegmentator dataset. These results show the effectiveness of the architectural improvements to nnU-Net and of the scaling strategy.
When pre-trained on TotalSegmentator, larger models typically achieve higher average DSC scores across 14 downstream datasets.
On downstream datasets, fine-tuning the STU-Net models pre-trained on TotalSegmentator gives better segmentation performance than training models from scratch.
Visually, the segmentation results of STU-Net are more complete and finer than those of other models. This demonstrates the strength and versatility of STU-Net in medical image segmentation.
Summary
This paper introduces STU-Net, a scalable and transferable medical image segmentation model based on the nnU-Net framework. With up to 1.4 billion parameters, it is by far the largest medical image segmentation model. By training STU-Net on the large-scale TotalSegmentator dataset, the authors show that scaling up the model yields significant performance gains when transferring to various downstream tasks. This validates the potential of large models in medical image segmentation.
Furthermore, the STU-Net-H model trained on TotalSegmentator transfers well across multiple downstream datasets, both in direct inference and after fine-tuning. This underscores the practical value of large-scale pre-trained models for medical image segmentation tasks.
In conclusion, the scalable and transferable STU-Net models are expected to advance medical image segmentation techniques and to open new avenues for research and innovation in the medical image segmentation community.
If you are interested in artificial intelligence and computer vision across the full stack, we strongly recommend following the public account "CVHub". It brings you informative, original, multi-field, in-depth interpretations of cutting-edge papers and mature industrial solutions every day! You are welcome to add the editor on WeChat (cv_huber) with the remark "CSDN" to join the official CVHub academic & technical exchange group and discuss more interesting topics together!