Paper: https://arxiv.org/abs/2307.03942, MICCAI 2023
This paper is essentially a network-structure optimization of LViT, and the changes are not large, but I thought it was worth writing up this relatively new direction. I won't give much background here; you can refer directly to my previous blog post on LViT. I find the name quite interesting: Ariadne's Thread comes from ancient Greek mythology, the story of Theseus escaping the labyrinth with the help of Ariadne's golden thread. The extended study later in the paper is also very interesting.
Summary
Segmentation of lung infection regions is critical for quantifying the severity of lung diseases such as pulmonary infections. Existing medical image segmentation methods are almost all image-only, single-modality methods. However, these image-only methods tend to produce inaccurate results unless they are trained with large amounts of annotated data. To overcome this challenge, we propose a language-driven segmentation method that uses text prompts to improve segmentation results. Experiments on the QaTa-COV19 dataset show that our method improves the Dice score by at least 6.09% compared with unimodal methods. Furthermore, our extended study reveals the flexibility of multimodal methods with respect to the granularity of the textual information, and shows that multimodal methods have a significant advantage over image-only methods in terms of the amount of training data required.
Background
For background, refer directly to my previous post: LViT: Language and Vision Transformer in Medical Image Segmentation (_Scabbards_'s Blog - CSDN Blog).
This paper is follow-up work to LViT that optimizes the model structure.
Main contributions
1) We propose a language-driven segmentation method for segmenting infection regions from lung X-ray images.
2) The GuideDecoder designed in this method can adaptively propagate sufficient semantic information from the text prompts into the pixel-level visual features, which promotes consistency between the two modalities.
3) We cleaned up the errors contained in the text annotations of QaTa-COV19 [17] and contacted the LViT authors to release a new version.
4) The extended study reveals the impact of the granularity of the text prompt information on the segmentation performance of our method, and demonstrates the significant advantage of multimodal methods over image-only methods in terms of the required training data size.
Model structure
Our proposed method adopts a modular design: the model mainly consists of an image encoder, a text encoder, and several GuideDecoders. The GuideDecoder is used to adaptively propagate semantic information from the text features into the visual features and to output the decoded visual features.
Compared with the early fusion used in LViT, our proposed modular design is more flexible. For example, when the method is applied to brain MRI images, the modular design allows us to first load pre-trained weights (trained on the corresponding data) into the separate visual and text encoders, and then train only the GuideDecoders, as sketched below.
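As a concrete illustration of this flexibility, here is a minimal PyTorch sketch of training only the GuideDecoders while the pre-trained encoders stay frozen. The attribute names (image_encoder, text_encoder, guide_decoders) are my own placeholders, not the authors' released code.

```python
import torch

def freeze_encoders(model: torch.nn.Module):
    """Freeze the pre-trained image/text encoders so that only the GuideDecoders
    (and the segmentation head, if any) receive gradient updates.
    `image_encoder` / `text_encoder` are assumed attribute names."""
    for p in model.image_encoder.parameters():
        p.requires_grad = False
    for p in model.text_encoder.parameters():
        p.requires_grad = False
    return [p for p in model.parameters() if p.requires_grad]

# Usage (assuming `model` is already built and the encoder weights are loaded):
# trainable_params = freeze_encoders(model)
# optimizer = torch.optim.AdamW(trainable_params, lr=3e-4)
```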
Image Encoder
ConvNeXt-Tiny
Input: an image $x \in \mathbb{R}^{3 \times H \times W}$, where $H$ and $W$ are the height and width of the original image.
The four stages of ConvNeXt-Tiny produce visual features $f_i \in \mathbb{R}^{C_i \times \frac{H}{2^{i+1}} \times \frac{W}{2^{i+1}}}$, $i = 1, \dots, 4$, where $C_i$ is the feature dimension of stage $i$.
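As a rough, self-contained illustration (not the authors' code), the four stage-wise feature maps of a ConvNeXt-Tiny backbone can be obtained with the timm library's features_only mode:

```python
import timm
import torch

# Sketch: extract the four stage-wise feature maps of ConvNeXt-Tiny.
# With a 224x224 input, timm's convnext_tiny yields features downsampled by
# 4 / 8 / 16 / 32 with 96 / 192 / 384 / 768 channels respectively.
backbone = timm.create_model("convnext_tiny", pretrained=False, features_only=True)

x = torch.randn(1, 3, 224, 224)              # dummy chest X-ray batch
features = backbone(x)                       # list of 4 tensors, shallow -> deep
for i, f in enumerate(features, start=1):
    print(f"stage {i}: {tuple(f.shape)}")    # (1, 96, 56, 56) ... (1, 768, 7, 7)
```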
Text Encoder
CXR-BERT
Input: the text prompt.
Output: text features of shape $L \times C$, where $C$ is the feature dimension and $L$ is the length (number of tokens) of the text prompt.
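The paper's text encoder is CXR-BERT (Microsoft has released public checkpoints on the Hugging Face hub). For a self-contained sketch I load a generic BERT checkpoint as a stand-in, so treat the model name below as a placeholder rather than the authors' exact setup; the prompt is also just an illustrative example:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Sketch of the text branch: tokenize a prompt and keep token-level features.
# "bert-base-uncased" is a stand-in for the CXR-BERT checkpoint used in the paper.
name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(name)
text_encoder = AutoModel.from_pretrained(name)

prompt = "Bilateral pulmonary infection, two infected areas, upper left lung and upper right lung."
tokens = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    out = text_encoder(**tokens)
text_feats = out.last_hidden_state   # shape (1, L, C): one C-dim feature per text token
print(text_feats.shape)
```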
GuideDecoder
Input: visual features and text features.
Output: the decoded visual features.
Before performing the multimodal interaction, the GuideDecoder first processes the input text features and visual features.
Step 1: Text
Input: the text features output by the text encoder.
Output: projected text features $f_t$ with fewer tokens, whose dimension is aligned with the image tokens.
The input text features first pass through a projection module (the Project block in the figure), which aligns the dimension of the text tokens with that of the image tokens and reduces the number of text tokens:
$$f_t = \sigma(\mathrm{Conv}(T\,W_T))$$
where $T$ denotes the input text tokens, $W_T$ is a learnable matrix, $\mathrm{Conv}(\cdot)$ is a $1 \times 1$ convolution, and $\sigma(\cdot)$ is the ReLU activation function.
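A minimal PyTorch sketch of such a projection module (the layer sizes, token counts, and the exact way the 1×1 convolution reduces the token count are my assumptions, not the released implementation):

```python
import torch
import torch.nn as nn

class TextProjector(nn.Module):
    """Sketch of the Project block: align the text-token dimension with the image
    tokens and reduce the number of text tokens. All sizes are illustrative."""
    def __init__(self, text_dim: int, vis_dim: int, n_text_tokens: int, n_out_tokens: int):
        super().__init__()
        self.w_t = nn.Linear(text_dim, vis_dim, bias=False)                 # learnable matrix W_T
        self.conv = nn.Conv1d(n_text_tokens, n_out_tokens, kernel_size=1)   # 1x1 conv over the token axis
        self.act = nn.ReLU(inplace=True)

    def forward(self, text_feats: torch.Tensor) -> torch.Tensor:
        # text_feats: (B, L, C_text) -> (B, L', C_vis)
        return self.act(self.conv(self.w_t(text_feats)))

# Example: 24 BERT tokens of dim 768 projected to 12 tokens of dim 384
proj = TextProjector(text_dim=768, vis_dim=384, n_text_tokens=24, n_out_tokens=12)
print(proj(torch.randn(2, 24, 768)).shape)   # torch.Size([2, 12, 384])
```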
Step 2: Image
Input: the visual features $f_v$ from the image encoder (flattened into image tokens).
Output: evolved visual features $f_v'$.
After adding a positional encoding, self-attention is used to enhance the visual information in the image, with a residual connection back to the input:
$$f_v' = f_v + \mathrm{MHSA}(\mathrm{LN}(f_v))$$
where $\mathrm{MHSA}(\cdot)$ is the Multi-Head Self-Attention layer and $\mathrm{LN}(\cdot)$ is Layer Normalization (the positional encoding is added to $f_v$ before the attention).
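A rough sketch of this self-attention step (pre-norm with a residual connection; using a learned positional embedding is my assumption):

```python
import torch
import torch.nn as nn

class VisualSelfAttention(nn.Module):
    """Sketch of Step 2: positional encoding + LayerNorm + MHSA + residual connection."""
    def __init__(self, dim: int, n_tokens: int, num_heads: int = 8):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, dim))   # learned PE (assumption)
        self.norm = nn.LayerNorm(dim)
        self.mhsa = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, f_v: torch.Tensor) -> torch.Tensor:
        x = f_v + self.pos_embed              # add positional encoding
        h = self.norm(x)
        attn_out, _ = self.mhsa(h, h, h)      # self-attention over the image tokens
        return f_v + attn_out                 # residual connection

# Example with 14x14 = 196 image tokens of dim 384
block = VisualSelfAttention(dim=384, n_tokens=14 * 14)
print(block(torch.randn(2, 196, 384)).shape)   # torch.Size([2, 196, 384])
```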
Step 3
Input: the evolved visual features $f_v'$ and the projected text features $f_t$.
Output: multimodal features $f_c$.
A multi-head cross-attention layer is adopted to propagate the fine-grained semantic information of the text into the evolved image features:
$$f_c = f_v' + \alpha \cdot \mathrm{MHCA}(\mathrm{LN}(f_v'),\, f_t)$$
where $\mathrm{MHCA}(\cdot)$ is Multi-Head Cross-Attention (queries from the visual tokens, keys and values from the text tokens) and $\alpha$ is a learnable parameter that controls the weight of the residual connection.
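And a corresponding sketch of the cross-attention step (again with my own naming and layer choices, not the released code):

```python
import torch
import torch.nn as nn

class TextGuidedCrossAttention(nn.Module):
    """Sketch of Step 3: visual tokens attend to text tokens; a learnable alpha
    weights the cross-attention output added through the residual connection."""
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mhca = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.alpha = nn.Parameter(torch.zeros(1))    # learnable residual weight

    def forward(self, f_v: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        q = self.norm(f_v)                           # queries: evolved visual tokens
        attn_out, _ = self.mhca(q, f_t, f_t)         # keys/values: projected text tokens
        return f_v + self.alpha * attn_out           # alpha-weighted residual connection

fuse = TextGuidedCrossAttention(dim=384)
f_c = fuse(torch.randn(2, 196, 384), torch.randn(2, 12, 384))
print(f_c.shape)   # torch.Size([2, 196, 384])
```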
Step 4
Input: the multimodal features $f_c$ (a token sequence).
Output: an upsampled 2D feature map.
The multimodal tokens are reshaped back into a 2D feature map and upsampled so that their resolution matches the low-level features coming from the skip connection.
Step 5
Input: the upsampled multimodal feature map $f_{up}$ and $f_s$, the low-level visual features obtained from the visual encoder via a skip connection.
Output: the decoded visual features $f_o$.
The two are concatenated along the channel dimension and processed by a convolutional layer with a ReLU activation:
$$f_o = \sigma(\mathrm{Conv}([\,f_{up},\, f_s\,]))$$
where $[\cdot\,,\cdot]$ denotes the concatenation operation along the channel dimension.
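Putting Steps 4 and 5 together, a minimal sketch might look like the following; the upsampling mode, kernel size, and channel sizes follow the ConvNeXt-Tiny dimensions used in the earlier snippets and are assumptions on my part:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecodeAndFuse(nn.Module):
    """Sketch of Steps 4-5: reshape tokens to 2D, upsample, concatenate the
    skip-connection feature, then conv + ReLU."""
    def __init__(self, in_dim: int, skip_dim: int, out_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(in_dim + skip_dim, out_dim, kernel_size=3, padding=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, f_c: torch.Tensor, f_s: torch.Tensor) -> torch.Tensor:
        b, n, c = f_c.shape
        h = w = int(n ** 0.5)                                   # assume a square token grid
        f_2d = f_c.transpose(1, 2).reshape(b, c, h, w)          # Step 4: tokens -> 2D map
        f_up = F.interpolate(f_2d, size=f_s.shape[-2:], mode="bilinear", align_corners=False)
        fused = torch.cat([f_up, f_s], dim=1)                   # Step 5: channel-wise concat
        return self.act(self.conv(fused))

dec = DecodeAndFuse(in_dim=384, skip_dim=192, out_dim=192)
out = dec(torch.randn(2, 196, 384), torch.randn(2, 192, 28, 28))
print(out.shape)   # torch.Size([2, 192, 28, 28])
```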
Experiments
Dataset
QaTa-COV19
We found some glaring errors (e.g., misspelled words, grammatical errors, ambiguities) in the extended text annotations. We have fixed these identified errors and contacted the LViT authors to release a new version of the dataset.
It consists of 9258 COVID-19 chest radiographs with pixel-level manual annotations of infected lung regions, of which 7145 are in the training set and 2113 are in the testing set.
Experimental settings
Data processing
The original QaTa-COV19 training set is further split into training and validation sets at a ratio of 80% / 20%. Therefore, the training set has 5716 samples, the validation set has 1429 samples, and the test set has 2113 samples. All images are cropped to 224 × 224, and the data are augmented using random scaling applied with 10% probability.
Hardware
We use PyTorch Lightning as the training and inference wrapper. All methods are trained on an NVIDIA Tesla V100 SXM3 GPU with 32 GB of VRAM.
Training details
We use Dice loss + cross-entropy loss as the loss function and train the network with the AdamW optimizer and a batch size of 32. We use a cosine annealing learning-rate schedule with an initial learning rate of 3e-4 and a minimum learning rate of 1e-6.
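For reference, a minimal sketch of this training setup; the Dice-loss implementation and the T_max value are my own assumptions, while the optimizer, batch size and learning-rate schedule follow the description above:

```python
import torch
import torch.nn as nn

def dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Simple binary Dice loss on sigmoid probabilities (illustrative implementation)."""
    probs = torch.sigmoid(logits)
    inter = (probs * target).sum(dim=(1, 2, 3))
    union = probs.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    return (1 - (2 * inter + eps) / (union + eps)).mean()

bce = nn.BCEWithLogitsLoss()   # cross-entropy term for binary masks

def total_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return dice_loss(logits, target) + bce(logits, target)

# Optimizer and cosine-annealing schedule as described above
# (`model` and `max_epochs` are placeholders):
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=1e-6)
```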
Evaluation metrics
Accuracy, the Dice coefficient, and the Jaccard coefficient. Both the Dice coefficient and the Jaccard coefficient measure the overlap between a given prediction mask and the ground truth (the intersection relative to their combined region), and the Dice coefficient better reflects segmentation performance on small objects.
(So I personally feel that reporting the Jaccard coefficient is not strictly necessary.)
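As a quick illustration of how the two overlap metrics relate (this is not the authors' evaluation code):

```python
import torch

def dice_and_jaccard(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
    """Dice = 2|A∩B| / (|A| + |B|), Jaccard (IoU) = |A∩B| / |A∪B|.
    The two are monotonically related: Dice = 2J / (1 + J)."""
    pred, gt = pred.bool(), gt.bool()
    inter = (pred & gt).sum().float()
    dice = (2 * inter + eps) / (pred.sum() + gt.sum() + eps)
    jaccard = (inter + eps) / ((pred | gt).sum() + eps)
    return dice.item(), jaccard.item()

pred = torch.zeros(224, 224); pred[50:150, 50:150] = 1   # dummy predicted mask
gt = torch.zeros(224, 224); gt[60:160, 60:160] = 1       # dummy ground-truth mask
print(dice_and_jaccard(pred, gt))
```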
Experimental results
Qualitative results are shown in Figure 2. Image-only unimodal methods are prone to over-segmentation.
The multimodal method, in contrast, uses the text prompt to indicate the specific location of the infected area, which makes the segmentation results more accurate.
Ablation experiment
As can be seen from Table 2, as the number of GuideDecoders used in the model increases, the segmentation performance of the model also improves. These results demonstrate the effectiveness of the GuideDecoder.
Extended study
Effect of text prompts with different granularity on segmentation performance
Each sample's text annotation is expanded into a three-part annotation containing location information at different granularities, as shown in the figure.
(Figure: text prompts at different granularities and the corresponding segmentation performance.)
The results in the table show that the performance of our method is driven by the granularity of the location information contained in the text prompts.
Our proposed method achieves better segmentation performance when given text prompts containing more detailed location information.
Meanwhile, we observe that performance is almost the same for two types of text prompts, i.e., Stage3 alone and Stage1+Stage2+Stage3. This means that the most detailed location information in a text prompt plays the most important role in improving segmentation performance. However, this does not mean that the coarser-grained location information in the text prompt contributes nothing. Even when the input text prompt contains only the coarsest location information (the Stage1+Stage2 entry in Table 3), our proposed method still achieves a 1.43% higher Dice score than the method without any text prompt.
Effect of training data size on segmentation performance
Our proposed method shows highly competitive performance even with a reduced amount of training data.
Using only a quarter of the training data, our proposed method outperforms UNet++, the best-performing unimodal model trained on the full dataset, by 2.69% in Dice score. This provides ample evidence for the superiority of multimodal methods: appropriate text prompts can significantly improve segmentation performance.
We observe that when the training data is reduced to 10%, our method starts to perform worse than UNet++ trained with all available data. Similar experiments can also be found in the LViT paper. Therefore, it can be argued that multimodal methods require only a small amount of data (less than 15% in our method) to achieve performance comparable to unimodal methods.