MSA [2]: Medical SAM Adapter


Foreword

With the deepening of research, SAM has been extended to the field of medical image segmentation. However, studies have shown that directly applying SAM to medical image segmentation yields very poor results. At the same time, the difficulty of obtaining medical data and the high cost of annotation urgently call for a foundation model, not only at the level of image segmentation but also for data annotation.

This paper presents the first model that fine-tunes SAM with the Adaption method, achieving remarkable performance on 19 datasets. It provides an effective reference and guidance for subsequent research on fine-tuning SAM.

Original paper link: Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation


1. Abstract & Introduction

1.1. Abstract

Many recent evaluations have shown that SAM's performance in medical image segmentation is not satisfactory (see, for example, another paper I previously studied: MSA [1]: Segment Anything Model for Medical Image Analysis: an Experimental Study). A natural question is how to transfer SAM's excellent zero-shot capability to the medical image domain.

Therefore, it is necessary to find the missing pieces needed to extend SAM's segmentation ability. This paper proposes the Medical SAM Adapter (MSA), which integrates medical-specific domain knowledge into the segmentation model through a simple and effective adaptation technique, rather than simply fine-tuning SAM. This solution, fine-tuning a pre-trained SAM model with Adapters under the parameter-efficient fine-tuning paradigm, shows surprisingly good performance on medical image segmentation.

1.2. Introduction

Although SAM is a powerful general-purpose visual segmentation model capable of generating fine segmentation masks from user prompts, many recent studies have shown that it performs poorly on medical image segmentation.

The main reason SAM fails on medical images is a lack of training data. Although a complex and efficient data engine was built during SAM's training process, it collected few medical cases.


In order to extend SAM to medical image segmentation, this paper chooses to fine-tune the pre-trained SAM with Adaption under the parameter-efficient fine-tuning (PEFT) paradigm.

Adaption is a popular and widely used technique in natural language processing (NLP) for fine-tuning large pre-trained models for specific purposes. The main idea is to insert several parameter-efficient Adapter modules into the original base model and then tune only the Adapter parameters, keeping all pre-trained parameters frozen.

1.2.1. Why do we need SAM for medical image segmentation?

Interactive segmentation is a general paradigm for segmentation tasks, and SAM provides an excellent framework for it, making it a natural foundation for prompt-based medical image segmentation.

1.2.2. Why fine-tuning?

The pre-trained SAM model was trained on the world's largest segmentation dataset, built through a carefully designed data engine. At the same time, quite a few studies have shown that pre-training on natural images is also beneficial for medical image segmentation.

1.2.3. Why PEFT and Adaption?

  • PEFT is an effective strategy for fine-tuning large foundation models for specific purposes
    • Compared with full fine-tuning, PEFT keeps most of the parameters frozen, so the number of learned parameters is greatly reduced, usually to less than 5% of the total (a minimal freezing sketch follows this list)
    • Higher learning efficiency and faster updates
    • PEFT methods often perform better than full fine-tuning because they avoid catastrophic forgetting and generalize better to out-of-domain scenarios, especially in low-data regimes
  • Adaption is an effective tool for fine-tuning large foundation vision models for downstream tasks
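
To make the idea concrete, here is a minimal sketch (not the official MSA code) of freezing a pre-trained model and leaving only Adapter parameters trainable; the `adapter` naming convention and the helper functions are assumptions of this sketch:

```python
import torch.nn as nn

def mark_adapter_trainable(model: nn.Module) -> None:
    """Freeze every pre-trained weight; leave only Adapter parameters trainable.

    Assumes adapter submodules carry "adapter" in their parameter names,
    which is a convention of this sketch, not of the official MSA code.
    """
    for name, param in model.named_parameters():
        param.requires_grad = "adapter" in name.lower()

def trainable_fraction(model: nn.Module) -> float:
    """Fraction of parameters that will actually be updated during fine-tuning."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable / total

# After inserting Adapter modules into a pre-trained SAM and calling
# mark_adapter_trainable(sam), trainable_fraction(sam) should stay well below 0.05.
```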

2. Method

2.1. Preliminary: SAM architecture

SAM consists of three main parts: an image encoder, a prompt encoder, and a mask decoder.

  • The image encoder is a standard Vision Transformer (ViT) pre-trained with MAE
    • The image encoder uses the ViT-H/16 variant, which applies $14 \times 14$ windowed attention and four equally spaced global attention blocks
    • The output of the image encoder is an embedding of the input image downsampled by a factor of $16$
    • For a detailed introduction to ViT, please refer to my other blog: CV-Model [6]: Vision Transformer
  • Prompts to the prompt encoder can be sparse (point, box, text) or dense (mask)
  • The mask decoder is a Transformer decoder block
    • It contains a dynamic mask prediction head
    • SAM uses bidirectional cross-attention in each block, one direction for prompt-to-image embedding and the other for image-to-prompt embedding, to learn the interaction between the prompt and the image embedding
    • After running both blocks, SAM upsamples the image embedding, and an MLP maps the output tokens to a dynamic linear classifier, which predicts the target mask for the given image

For a detailed introduction to the Segment Anything Model, please refer to my other blog: SAM [1]: Segment Anything
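
For orientation, here is a minimal usage sketch of the released segment_anything package with a single point prompt; the checkpoint path, the image, and the click location are placeholders:

```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Checkpoint path is a placeholder; use the officially released ViT-H weights.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for an RGB slice
predictor.set_image(image)

# One positive point prompt (label 1) at pixel (x=256, y=256).
masks, scores, logits = predictor.predict(
    point_coords=np.array([[256, 256]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
```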

2.2. MSA architecture

To fine-tune the SAM architecture for medical image segmentation, instead of fully tuning all parameters, the authors freeze the pre-trained SAM parameters and insert Adapter modules at specific positions in the architecture.

An Adapter is a bottleneck structure consisting of, in order, a down-projection, a ReLU activation, and an up-projection. The down-projection uses a simple MLP layer to compress a given embedding into a smaller dimension; the up-projection uses another MLP layer to expand the compressed embedding back to its original dimension.
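
As a rough illustration of this bottleneck structure (the reduction ratio and the optional residual connection are assumptions of this sketch, not values taken from the paper):

```python
import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck adapter: down-projection -> ReLU -> up-projection.

    The reduction ratio of 4 and the residual flag are illustrative choices.
    """

    def __init__(self, dim: int, reduction: int = 4, residual: bool = True):
        super().__init__()
        self.residual = residual
        self.down = nn.Linear(dim, dim // reduction)   # compress embedding
        self.act = nn.ReLU()
        self.up = nn.Linear(dim // reduction, dim)     # restore original dimension

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.up(self.act(self.down(x)))
        return x + out if self.residual else out
```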


2.2.1. 2D Medical Image Adaption

In the SAM encoder, two Adapters are deployed in each ViT block.

The standard ViT block (a) is modified to obtain the 2D Medical Image Adaption block (b):

  • The first Adapter is placed after the multi-head attention and before the residual connection
  • The second Adapter is placed on the residual path of the MLP layer that follows the multi-head attention
  • Immediately after the second Adapter, the embedding is scaled by a proportionality factor $s$ (see the sketch after this list)
    • The scaling factor $s$ is introduced to balance task-independent and task-specific features
    • The default value is 0.1 (the best setting in the referenced paper)
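
A rough sketch of such a modified encoder block, reusing the Adapter sketch above; the layer sizes, normalization placement, and the exact form of the residual paths are assumptions here, not the authors' implementation:

```python
import torch
import torch.nn as nn

class AdaptedViTBlock2D(nn.Module):
    """Sketch of the 2D Medical Image Adaption block (all sizes illustrative)."""

    def __init__(self, dim: int, num_heads: int = 12, mlp_ratio: int = 4, s: float = 0.1):
        super().__init__()
        self.s = s  # scaling factor balancing task-independent / task-specific features
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.adapter1 = Adapter(dim, residual=False)   # after attention, before residual add
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio), nn.GELU(), nn.Linear(dim * mlp_ratio, dim)
        )
        self.adapter2 = Adapter(dim, residual=False)   # on the MLP residual path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        h, _ = self.attn(h, h, h)
        x = x + self.adapter1(h)                         # first Adapter before the residual add
        h = self.norm2(x)
        x = x + self.mlp(h) + self.s * self.adapter2(h)  # second Adapter output scaled by s
        return x
```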

2.2.2. Decoder Adaption

In the SAM decoder, three Adapters are deployed in each block.

The standard decoder block (a) is modified to obtain the Decoder Adaption block (b):

  • The first Adapter is deployed after the prompt-to-image multi-head cross-attention and the residual addition of the prompt embedding
    • This paper uses an additional down-projection to compress the prompt embedding and adds it, after the ReLU activation, to the Adapter's embedding
    • This helps the Adapter tune its parameters based on the prompt information, making it more flexible and general across different modalities and downstream tasks
  • The second Adapter is deployed in exactly the same way as in the encoder and is used to adapt the MLP-enhanced embedding
  • The third Adapter is deployed after the residual connection of the image-to-prompt cross-attention
  • Another residual connection and layer normalization follow the adaptation to produce the final output (a rough sketch of the prompt-conditioned Adapter follows this list)
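
A very rough sketch of the first, prompt-conditioned Adapter; the dimensions, the pooled prompt shape, and the exact injection point are assumptions of this sketch rather than the paper's implementation:

```python
import torch
import torch.nn as nn

class PromptConditionedAdapter(nn.Module):
    """Bottleneck adapter that injects a compressed prompt embedding after the ReLU."""

    def __init__(self, dim: int, reduction: int = 4):
        super().__init__()
        hidden = dim // reduction
        self.down = nn.Linear(dim, hidden)         # compress the token embedding
        self.prompt_down = nn.Linear(dim, hidden)  # separate down-projection for the prompt
        self.act = nn.ReLU()
        self.up = nn.Linear(hidden, dim)

    def forward(self, tokens: torch.Tensor, prompt: torch.Tensor) -> torch.Tensor:
        # tokens: (B, N, dim); prompt: (B, 1, dim) pooled prompt embedding (an assumption),
        # so the compressed prompt broadcasts over all tokens.
        h = self.act(self.down(tokens))
        p = self.act(self.prompt_down(prompt))
        return tokens + self.up(h + p)             # adapter output conditioned on the prompt
```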

2.2.3. 3D Medical Image Adaption

Although SAM can be applied slice by slice to obtain the final segmentation of a volume, this ignores correlations along the depth dimension.

This paper proposes a new adaptation method inspired by image-to-video adaptation; the specific architecture is shown in (c).

  • In each block, this paper divides the attention operation into two branches: the spatial branch and the depth branch
  • For a given 3D sample with depth $D$:
    • Data of shape $D \times N \times L$ is sent to the Multi-head Attention of the spatial branch
      • where $N$ is the number of embeddings, $L$ is the embedding length, and $D$ is the number of parallel attention operations
      • The interaction is applied over $N \times L$, learning and abstracting spatial correlations between embeddings
      • In the multi-head attention mechanism, each head's attention computation involves three linear transformations: query, key, and value, which map the input sequence into different representation subspaces; within these subspaces, interactions occur that let the model learn and capture different features
      • Specifically, the interaction occurs in two phases:
        • Intra-head interaction: within each head, attention weights are computed as the dot product of queries with keys. During this computation, queries and keys interact to determine their relevance, so each head can selectively attend to or ignore different parts of the input sequence
        • Inter-head interaction: in multi-head attention, the attention-weighted values computed by each head are merged by a weighted combination. This merging can be seen as interaction between heads, letting each head contribute its feature representation; by combining multiple heads, the model obtains a more comprehensive and richer context representation
    • In the depth branch, the input matrix is first transposed to obtain data of shape $N \times D \times L$, which is then sent to the same Multi-head Attention
      • Using the same attention mechanism, the interaction is now applied over $D \times L$
      • In this way, dependencies along the depth dimension are learned and abstracted
  • Finally, the results from the depth branch are transposed back to the original shape and added to the results from the spatial branch (see the sketch below)
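
A minimal sketch of this two-branch idea, sharing one multi-head attention module between the spatial and depth branches as described above; shapes follow the $D \times N \times L$ convention and all sizes are illustrative:

```python
import torch
import torch.nn as nn

class SpaceDepthAttention(nn.Module):
    """Run the same attention over the spatial axis and the depth axis, then sum."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (D, N, L) -- depth, number of embeddings, embedding length
        spatial, _ = self.attn(x, x, x)       # interactions over N x L (per slice)

        xt = x.transpose(0, 1)                # (N, D, L): depth becomes the sequence axis
        depth, _ = self.attn(xt, xt, xt)      # interactions over D x L (per spatial token)
        depth = depth.transpose(0, 1)         # back to (D, N, L)

        return spatial + depth                # fuse the spatial and depth branches
```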

2.3. Training Strategy

2.3.1. Encoder Pretraining

Pre-training the encoder using medical images

This paper combines a variety of self-supervised learning methods for pre-training

  • Contrastive Embedding-Mixup (e-Mix)
    • e-Mix is a contrastive objective for unsupervised representation learning from text data (a small sketch of the mixing step follows this list)
      • A contrastive objective is an objective function for self-supervised learning, often used to train representation learning models on unlabeled data
        • In a contrastive objective, each sample is transformed into multiple views, usually called the anchor and the positive sample
        • They all come from the same original sample but may have undergone different transformations (data augmentation) to produce different views
        • Then, for each anchor, we compare it with its corresponding positive samples so that their representations move closer in the representation space and further away from the representations of other samples
      • e-Mix mixes a batch of raw input embeddings with different weighting coefficients. The encoder is then trained so that the embedding it produces for the mixed input approximates the original inputs' embeddings in proportion to their mixing coefficients
  • Shuffled Embedding Prediction (ShED)
    • ShED shuffles a subset of embeddings and trains the encoder, together with a classifier, to predict which embeddings were perturbed
    • ShED aims at learning useful contextual representations from unlabeled text data
    • The ShED objective is based on the idea of predicting the original order of shuffled word or token embeddings in a sentence
  • MAE
    • MAE masks a given portion of the input embeddings and trains the model to reconstruct them
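
As an informal illustration of the e-Mix mixing step only (not the full contrastive loss), a batch of embeddings might be mixed as follows; the uniform mixing distribution is an assumption of this sketch:

```python
import torch

def e_mix_batch(embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Weighted mix of a batch of input embeddings.

    embeddings: (B, N, L) batch of raw input embeddings.
    Returns the mixed embeddings and the (B, B) mixing coefficients to which
    the contrastive target is expected to be proportional.
    """
    b = embeddings.size(0)
    coeffs = torch.rand(b, b)
    coeffs = coeffs / coeffs.sum(dim=1, keepdim=True)         # rows sum to 1
    mixed = torch.einsum("ij,jnl->inl", coeffs, embeddings)   # weighted sum over the batch
    return mixed, coeffs
```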

2.3.2. Training with Prompt

This section adapts SAM to new medical image datasets. The prompting process is basically the same as in SAM. Only two kinds of prompts are considered: point and text (the original authors do not seem to have released some of the text-prompt code yet).

  • point
    • This paper trains this prompt using a combination of random and iterative click sampling strategies
    • The specific process is as follows:
      • First, random sampling is used for initialization; then an iterative sampling procedure adds more points
      • The iterative sampling strategy mimics interaction with real users, since in practice each new point is placed in a region where the network's prediction from the previous set of clicks was wrong
      • Training therefore alternates between generating random samples and simulating iterative sampling
  • text
    • In SAM, the authors use CLIP image embeddings of crops of the target object, since these embeddings lie close to the embeddings of the corresponding textual descriptions or definitions in CLIP. However, since CLIP is rarely trained on medical image datasets, it is difficult to associate organs/lesions in images with their corresponding textual definitions
    • Medical image segmentation requires precise identification and labeling of different structures in the image, and inaccurate labels can lead to wrong diagnosis and treatment
    • To overcome this limitation, this paper proposes a different training strategy: ChatGPT is used to generate free text containing the target definitions as keywords, and CLIP is then used to extract text embeddings from this text as training prompts (a minimal sketch follows this list)
      • This approach ensures that the textual definitions are relevant to the medical images, and that the model can accurately relate different structures in images to their corresponding textual definitions
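
As a rough sketch of the text-prompt idea (the checkpoint name and the prompt wording are placeholders, and this is not the authors' released code), CLIP text embeddings could be extracted with Hugging Face transformers like this:

```python
import torch
from transformers import CLIPModel, CLIPProcessor

# Placeholder checkpoint and definitions; the paper's definitions come from ChatGPT.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

definitions = [
    "the optic cup, the bright central depression of the optic disc in a fundus image",
    "a brain tumour region with irregular borders on a T2-weighted MRI slice",
]

inputs = processor(text=definitions, return_tensors="pt", padding=True)
with torch.no_grad():
    text_embeds = model.get_text_features(**inputs)   # (num_prompts, embed_dim)

# These embeddings would then be fed to the (fine-tuned) prompt encoder as text prompts.
```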

Summary

By using PEFT & Adaption (a cost-effective fine-tuning technique), this paper achieves a significant improvement over the original SAM and reaches state-of-the-art performance on 19 medical image segmentation tasks across 5 different image modalities.

These results demonstrate the effectiveness of the method in adapting SAM to medical images, and the potential of transferring powerful general-purpose segmentation models to medical applications.


Origin blog.csdn.net/HoraceYan/article/details/131761266