Foreword
As research has deepened, SAM has been extended to the field of medical image segmentation. However, studies have shown that directly applying SAM to medical image segmentation works very poorly. At the same time, the difficulty of obtaining medical data and the high cost of annotation urgently call for a foundation model to step in, not only at the level of image segmentation but also for data annotation.
This paper presents the first model that fine-tunes SAM with the Adaption method, achieving astonishing performance on 19 datasets. This provides an effective reference and guidance for subsequent SAM fine-tuning work.
Original paper link: Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation
1. Abstract & Introduction
1.1. Abstract
Many recent evaluations have shown that SAM's performance on medical image segmentation is not satisfactory (see, for example, another paper I previously studied: MSA [1]: Segment Anything Model for Medical Image Analysis: an Experimental Study). A natural question is how to transfer SAM's excellent zero-shot capability to the medical image domain.
Therefore, it is necessary to find the missing pieces and extend SAM's segmentation ability. This paper proposes the Medical SAM Adapter (MSA), which integrates medical-specific domain knowledge into the segmentation model through simple yet effective adaptation techniques, rather than simply fine-tuning SAM. This solution, fine-tuning the pre-trained SAM with Adapters under the parameter-efficient fine-tuning paradigm, shows surprisingly good performance on medical image segmentation.
1.2. Introduction
SAM is a powerful general-purpose visual segmentation model capable of generating various fine segmentation masks from user prompts, yet many recent studies have shown that it performs poorly on medical image segmentation.
The main reason SAM fails on medical images is a lack of training data: although a sophisticated and efficient data engine was built during SAM's training, it collected few medical cases.
To extend SAM to medical image segmentation, this paper chooses to fine-tune the pre-trained SAM with Adaption under the parameter-efficient fine-tuning (PEFT) paradigm.
Adaption is a popular and widely used technique in Natural Language Processing (NLP) for fine-tuning foundation pre-trained models for specific purposes. The main idea is to insert several parameter-efficient Adapter modules into the original base model and then tune only the Adapter parameters, freezing all pre-trained parameters.
1.2.1. Why do we need SAM for medical image segmentation?
Interactive segmentation is a paradigm that covers all segmentation tasks, and SAM provides an excellent framework for it, making it a perfect basis for implementing prompt-based medical image segmentation.
1.2.2. Why fine-tuning?
The pre-trained SAM model was trained on the world's largest segmentation dataset through a well-designed data engine. At the same time, quite a few studies show that pre-training on natural images is also beneficial for medical image segmentation.
1.2.3. Why PEFT and Adaption?
PEFT is an effective strategy for fine-tuning large foundation models for specific purposes:
- Compared with full fine-tuning, most parameters are kept frozen and far fewer parameters are learned, usually less than 5% of the total, giving higher learning efficiency and faster updates
- PEFT methods often perform better than full fine-tuning because they avoid catastrophic forgetting and generalize better to out-of-domain scenarios, especially in low-data regimes

Among PEFT methods, Adaption is an effective tool for fine-tuning large vision foundation models for downstream tasks.
2. Method
2.1. Preliminary: SAM architecture
SAM consists of three main parts: the image encoder, the prompt encoder and the mask decoder.
- The image encoder is a standard Vision Transformer (ViT) pre-trained with MAE
  - It uses the ViT-H/16 variant, with 14×14 window attention and four equally spaced global attention blocks
  - The output of the image encoder is a 16× downsampled embedding of the input image
  - For a detailed introduction to ViT, please refer to my other blog: CV-Model [6]: Vision Transformer
- The prompt encoder accepts sparse prompts (point, box, text) or dense prompts (mask)
- The mask decoder is a Transformer decoder block containing a dynamic mask prediction head
  - SAM uses bidirectional cross-attention in each block, one direction for prompt-to-image embedding and the other for image-to-prompt embedding, to learn the interaction between prompt and image embeddings
  - After running both blocks, SAM upsamples the image embedding, and an MLP maps the output token to a dynamic linear classifier, which predicts the target mask for the given image
For a detailed introduction to the Segment Anything Model, please refer to my other blog: SAM [1]: Segment Anything
2.2. MSA architecture
To fine-tune the SAM architecture for medical image segmentation, instead of fully tuning all parameters, the authors freeze the pre-trained SAM parameters and insert Adapter modules at specific positions in the architecture.
An Adapter is a bottleneck structure consisting of, in order: a down-projection, a ReLU activation, and an up-projection. The down-projection compresses a given embedding to a smaller dimension with a simple MLP layer; the up-projection expands the compressed embedding back to its original dimension with another MLP layer.
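The bottleneck structure can be sketched in a few lines. The dimensions below (ViT width 768, bottleneck 64) are hypothetical placeholders for illustration, not values taken from the paper:

```python
import numpy as np

def adapter(x, W_down, W_up):
    """Bottleneck Adapter: down-projection -> ReLU -> up-projection.
    x: (N, L) token embeddings; W_down: (L, r) with r << L; W_up: (r, L)."""
    h = np.maximum(x @ W_down, 0.0)   # down-projection, then ReLU
    return h @ W_up                   # up-projection back to dimension L

rng = np.random.default_rng(0)
L, r = 768, 64                        # hypothetical: ViT width 768, bottleneck 64
x = rng.standard_normal((16, L))
out = adapter(x, 0.01 * rng.standard_normal((L, r)), 0.01 * rng.standard_normal((r, L)))
print(out.shape)                      # the output keeps the input's shape
```

Because the output dimension matches the input, such a module can be inserted between any two frozen sub-layers without changing the surrounding shapes.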
2.2.1. 2D Medical Image Adaption
In the SAM encoder, we deploy two Adapters to modify the standard ViT block (a) and obtain the 2D Medical Image Adaption block (b):
- The first Adapter is placed after the multi-head attention and before the residual connection
- The second Adapter is placed on the residual path of the MLP layer, after the multi-head attention
- Immediately after the second Adapter, the embedding is scaled by a factor s
  - The scaling factor s is introduced to balance task-independent and task-dependent features
  - Its default value is 0.1 (the best case in the reference paper)
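The placement of the two Adapters and the scale factor s can be sketched as below. The attention, MLP and norm sub-layers are identity placeholders standing in for the frozen SAM weights, and the exact wiring relative to the layer norms is an assumption; only the adapter placement is the point:

```python
import numpy as np

def adapter(x, W_down, W_up):
    return np.maximum(x @ W_down, 0.0) @ W_up   # bottleneck: down -> ReLU -> up

# Identity placeholders for the frozen SAM sub-layers (really multi-head
# attention, a two-layer MLP and layer norms inside the pre-trained ViT).
attention = mlp = norm = lambda x: x

def adapted_vit_block(x, A1, A2, s=0.1):
    """2D Medical Image Adaption block (b). A1/A2 are (W_down, W_up) pairs;
    they are the only trainable parameters, and s defaults to 0.1."""
    # First Adapter: after multi-head attention, before the residual add
    h = x + adapter(attention(norm(x)), *A1)
    # Second Adapter: on the MLP residual path, its output scaled by s
    return h + s * adapter(mlp(norm(h)), *A2)

rng = np.random.default_rng(0)
L, r = 768, 64
A = lambda: (0.01 * rng.standard_normal((L, r)), 0.01 * rng.standard_normal((r, L)))
y = adapted_vit_block(rng.standard_normal((16, L)), A(), A())
```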
2.2.2. Decoder Adaption
In the SAM decoder, we deploy three Adapters to modify the standard block (a) and obtain the Decoder Adaption block (b):
- The first Adapter is deployed after the prompt-to-image multi-head cross-attention and before the residual addition of the prompt embedding
  - This paper uses an additional down-projection to compress the prompt embedding and adds the result to the Adapter's bottleneck after the ReLU activation
  - This helps the Adapter tune its parameters based on the prompt information, making it more flexible and general for different modalities and downstream tasks
- The second Adapter is deployed in exactly the same way as in the encoder and is used to tune the MLP-enhanced embedding
- The third Adapter is deployed after the residual connection of the image-embedding-to-prompt cross-attention
  - Another residual connection and layer normalization follow the adaptation to produce the final output
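A minimal sketch of the first, prompt-conditioned Adapter. The mean-pooling of the compressed prompt before it is added into the bottleneck is an assumption for illustration (the fusion detail is not spelled out above), and all dimensions are hypothetical:

```python
import numpy as np

def prompt_adapter(x, prompt, W_down, W_up, Wp_down):
    """Prompt-conditioned Adapter: a separate down-projection compresses the
    prompt embedding; its pooled result is added after the ReLU, so the
    Adapter's output depends on the prompt.
    x: (N, L) image tokens after prompt-to-image cross-attention
    prompt: (M, L) prompt tokens."""
    h = np.maximum(x @ W_down, 0.0)                      # (N, r) bottleneck
    p = np.maximum(prompt @ Wp_down, 0.0).mean(axis=0)   # (r,) pooled prompt
    return (h + p) @ W_up                                # fuse, then up-project

rng = np.random.default_rng(0)
L, r = 256, 32                                  # hypothetical decoder dims
x, prompt = rng.standard_normal((64, L)), rng.standard_normal((3, L))
W = lambda a, b: 0.01 * rng.standard_normal((a, b))
out = prompt_adapter(x, prompt, W(L, r), W(r, L), W(L, r))
```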
2.2.3. 3D Medical Image Adaption
Although SAM can be applied to each slice of a lesion to obtain the final segmentation, this ignores correlations along the depth dimension.
This paper proposes a new adaptation method inspired by image-to-video adaptation; the specific architecture is shown in (c):
- In each block, the attention operation is split into two branches: a spatial branch and a depth branch
- For a given 3D sample of depth D:
  - Data of dimension D × N × L is sent to the multi-head attention of the spatial branch
    - Here N is the number of embeddings, L is the embedding length, and D is the number of attention operations (slices)
    - The interaction is applied on N × L to learn and abstract spatial correlations as embeddings
    - In the multi-head attention mechanism, each head's attention computation involves three linear transformations: query, key and value, which map the input sequence into different representation subspaces where the model learns and captures different features
    - Specifically, the interaction happens in two stages:
      - Intra-head interaction: within each head, attention weights are computed as the dot product of queries and keys. During this computation, queries and keys interact to determine their relevance, so each head can selectively attend to or ignore different parts of the input sequence
      - Inter-head interaction: in multi-head attention, the attention weights and corresponding values computed by each head are combined by weighted summation. This merging can be seen as an interaction between the heads, letting each head's feature representation be synthesized; by combining multiple heads, the model obtains a more comprehensive and richer context representation
  - In the depth branch, the input matrix is first transposed to dimension N × D × L and then sent to the same multi-head attention
    - Using the same attention mechanism, the interaction is now applied on D × L
    - In this way, depth dependencies are learned and abstracted
- Finally, the results from the depth branch are transposed back to their original shape and added to the results from the spatial branch
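The two branches are essentially a transpose trick: the same attention runs once over the token axis and once over the depth axis. A shape-level sketch, with a uniform-mixing stub standing in for the frozen multi-head attention:

```python
import numpy as np

def attention_stub(x):
    """Stand-in for the shared multi-head attention: uniformly mixes the
    tokens along axis -2 (a real block would use SAM's frozen weights)."""
    T = x.shape[-2]
    return (np.ones((T, T)) / T) @ x

D, N, L = 4, 9, 32                        # depth slices, tokens per slice, length
x = np.random.default_rng(0).standard_normal((D, N, L))

# Spatial branch: interactions applied on N x L within each of the D slices
spatial = attention_stub(x)

# Depth branch: transpose to (N, D, L) so the same attention now runs over
# the D slices, i.e. interactions applied on D x L; then transpose back
depth = attention_stub(x.transpose(1, 0, 2)).transpose(1, 0, 2)

out = spatial + depth                     # fuse the two branches
```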
2.3. Training Strategy
2.3.1. Encoder Pretraining
Pre-training the encoder using medical images
This paper combines a variety of self-supervised learning methods for pre-training
Contrastive Embedding-Mixup (e-Mix)
e-Mix is a contrastive objective for unsupervised representation learning from text data
- A contrastive objective is an objective function for self-supervised learning, often used to train representation-learning models on unlabeled data
  - In a contrastive objective, each sample is transformed into multiple views, usually called the anchor and the positive sample
    - They all come from the same original sample, but may have undergone different transformations (data augmentations) to produce different views
  - Then, for each anchor, we compare it with its corresponding positive samples so that their representations move closer in the representation space and further from the representations of other samples
- e-Mix mixes a batch of raw input embeddings, weighting them with different coefficients. The encoder is then trained to produce an embedding for the mixed input that approximates the original inputs' embeddings in proportion to their mixing coefficients
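The e-Mix idea can be sketched as follows. The Dirichlet mixing coefficients and the squared-error stand-in for the actual contrastive loss are assumptions for illustration (with an identity encoder the loss is exactly zero by construction):

```python
import numpy as np

rng = np.random.default_rng(0)
B, L = 8, 32                               # batch size, embedding length
inputs = rng.standard_normal((B, L))       # one batch of raw input embeddings

# Per-sample mixing coefficients over the batch (hypothetical: Dirichlet)
coeffs = rng.dirichlet(np.ones(B), size=B) # (B, B), each row sums to 1
mixed = coeffs @ inputs                    # weighted mixtures of the batch

encoder = lambda x: x                      # identity stand-in for the encoder

# e-Mix target: the embedding of each mixture should approximate the
# original inputs' embeddings combined with the same mixing coefficients
target = coeffs @ encoder(inputs)
loss = np.mean((encoder(mixed) - target) ** 2)
```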
Shuffled Embedding Prediction (ShED)
ShED shuffles a subset of embeddings and trains the encoder with a classifier to predict which embeddings were perturbed
- ShED aims to learn useful contextual representations from unlabeled text data
- The objective is based on the idea of predicting the original order of shuffled word or token embeddings in a sentence
MAE
MAE masks a given portion of the input embeddings and trains the model to reconstruct them
2.3.2. Training with Prompt
This article adapts SAM to the new medical image datasets; the process is basically the same as in SAM. Only 2 kinds of prompts are considered: point & text (some of the text-prompt code seems not to have been released by the original authors yet)
point
- This paper trains this prompt using a combination of random and iterative click-sampling strategies
- The specific process is as follows:
  - First use random sampling for initialization, then use an iterative sampling procedure to add more points
  - The iterative sampling strategy is similar to interaction with real users, since in practice each new point is placed in a region where the network's prediction, given the previous set of clicks, was wrong
  - This both generates random samples and simulates iterative sampling
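The iterative step can be sketched as: find the pixels where the current prediction disagrees with the ground truth and place the next point there, labeled positive or negative by the ground truth. Names here are illustrative, not from the released code:

```python
import numpy as np

def next_click(pred, gt, rng):
    """Place the next point prompt in the error region of the current
    prediction, mimicking a real user correcting the model."""
    ys, xs = np.nonzero(pred != gt)        # wrongly predicted pixels
    if len(ys) == 0:
        return None                        # prediction already perfect
    i = rng.integers(len(ys))
    y, x = int(ys[i]), int(xs[i])
    return (y, x), bool(gt[y, x])          # positive click if pixel is object

rng = np.random.default_rng(0)
gt = np.zeros((8, 8), dtype=bool); gt[2:6, 2:6] = True   # toy ground truth
pred = np.zeros((8, 8), dtype=bool)        # model initially predicts nothing
click = next_click(pred, gt, rng)          # lands inside the missed object
```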
text
- In SAM, the authors use CLIP image embeddings of target-object crops as prompt embeddings, since in CLIP these are close to the embeddings of their corresponding textual descriptions or definitions. However, since CLIP is rarely trained on medical image datasets, it is difficult to associate organs/lesions in the images with the corresponding textual definitions
  - Medical image segmentation requires precise identification and labeling of different structures in an image, and inaccurate labels can lead to wrong diagnosis and treatment
- To overcome this limitation, this paper proposes a different training strategy: generate free text containing the target definition as a keyword with ChatGPT, then extract text embeddings with CLIP and use them as training prompts
  - This ensures that the textual definitions are relevant to the medical images and that the model can accurately relate different structures in images to their corresponding textual definitions
Summarize
By using PEFT & Adaption (a cost-effective fine-tuning technique), this paper achieves a significant improvement over the original SAM model and reaches state-of-the-art performance on 19 medical image segmentation tasks across 5 different image modalities. These results demonstrate the effectiveness of this method for adapting SAM to medical images and the potential of transferring powerful general segmentation models to medical applications.