[Literature Reading] MCNet+: Mutual Consistency Learning for Semi-supervised Medical Image Segmentation

I have recently been studying semi-supervised learning and am recording the papers I have read here. If I have misunderstood anything, please point it out.
Article: Mutual Consistency Learning for Semi-supervised Medical Image Segmentation
Code: github


Motivation


The MC-Net+ model is motivated by an observation about deep models trained with limited annotations. Such models tend to output highly uncertain and easily misclassified predictions in the ambiguous regions of medical images (e.g., adhesive edges or thin branches). Since deep models can produce segmentation results with pixel/voxel-level uncertainty, the method uses this uncertainty to make effective use of unlabeled data and so further improve semi-supervised medical image segmentation. It is somewhat similar to UAMT, which uses the uncertainty of network predictions to guide the consistency loss and computes the consistency only in high-confidence regions. The two methods take opposite approaches, however. UAMT focuses on the high-confidence regions, whereas MCNet+ focuses on the low-confidence regions of unlabeled data, i.e., the regions that are hard to segment.
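The contrast with UAMT can be made concrete with a minimal NumPy sketch. It assumes binary segmentation and uses per-voxel entropy as the uncertainty measure. The probabilities and the 0.5-nat threshold are hypothetical. UAMT itself estimates uncertainty with Monte Carlo dropout on the teacher and ramps its threshold during training, and both details are omitted here. MC-Net+ applies no explicit mask; the complement mask only marks the voxels its decoder-disagreement loss is meant to emphasize.

```python
import numpy as np

def binary_entropy(p, eps=1e-8):
    """Per-voxel entropy (in nats) of a foreground probability map."""
    p = np.clip(p, eps, 1.0 - eps)
    return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))

# Hypothetical foreground probabilities for five voxels:
# two confident voxels (interior / background) and three ambiguous edge voxels.
prob = np.array([0.98, 0.03, 0.55, 0.45, 0.60])
unc = binary_entropy(prob)

# UAMT-style idea: keep only low-uncertainty voxels in the consistency loss.
threshold = 0.5  # hypothetical threshold in nats (the maximum is ln 2, about 0.693)
uamt_mask = unc < threshold

# The voxels MC-Net+ is meant to emphasize: the ambiguous complement.
hard_region = ~uamt_mask
```

With these values, `uamt_mask` keeps the first two voxels and `hard_region` selects the three ambiguous ones.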
Two key observations can be drawn from this experiment:
(1) Highly uncertain predictions appear mainly in a few challenging regions, and as the amount of training data increases, the segmentation improves mainly in these challenging regions;
(2) As more labeled data is used for training, the model outputs less ambiguous segmentation results.
Therefore, this paper hypothesizes that the generalization ability of a deep model should be highly correlated with the uncertainty of the model. These observations motivate this paper to explore model uncertainty to help the model generalize better on these hard-to-segment regions.


Method


MCNet+ uses one shared encoder and three decoders. The decoders have the same structure but use different upsampling methods (transposed convolution, linear interpolation, and nearest-neighbor interpolation). Each decoder's probability output is sharpened into a soft pseudo-label, which then serves as the target for the other two decoders when computing the consistency loss. In addition, for labeled data, the outputs of all three decoders are also supervised with the ground-truth label.
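The training objective described above can be sketched in NumPy for a single binary-segmentation sample. This is a simplified illustration, not the official implementation. Several details are assumptions here: the sharpening form P^(1/T) / (P^(1/T) + (1-P)^(1/T)), the temperature T = 0.1, MSE as the consistency distance, and Dice as the supervised loss. The loss weights and ramp-up schedule used in practice are left out. In a real framework the pseudo-labels would also be detached from the gradient.

```python
import numpy as np

def sharpen(p, T=0.1):
    """Soft pseudo-label: push a foreground probability toward 0 or 1.
    The functional form and T = 0.1 are assumptions of this sketch."""
    pt = p ** (1.0 / T)
    return pt / (pt + (1.0 - p) ** (1.0 / T))

def dice_loss(p, y, eps=1e-5):
    """Soft Dice loss between a probability map p and a binary label y."""
    inter = np.sum(p * y)
    return 1.0 - (2.0 * inter + eps) / (np.sum(p) + np.sum(y) + eps)

def mcnet_plus_losses(probs, label=None):
    """probs: list of foreground-probability maps, one per decoder.
    label: binary ground truth, or None for an unlabeled sample.
    Returns (supervised loss, mutual consistency loss), both unweighted."""
    n = len(probs)
    # Each decoder's output becomes a soft pseudo-label for the others
    # (in an autograd framework these would be detached).
    pseudo = [sharpen(p) for p in probs]
    # Mutual consistency: decoder i's output vs. every other decoder's
    # soft pseudo-label, measured with MSE.
    l_cons = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                l_cons += np.mean((probs[i] - pseudo[j]) ** 2)
    # Supervised term: only when a label exists, applied to every decoder.
    l_seg = 0.0
    if label is not None:
        l_seg = sum(dice_loss(p, label) for p in probs)
    return l_seg, l_cons
```

For an unlabeled sample, only `l_cons` contributes. If all decoders output the same hard 0/1 map, `l_cons` is zero, because sharpening leaves 0 and 1 unchanged.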


Thoughts


The strategy is actually quite simple, so how does MCNet+ focus on low-confidence regions?
When training a network, we want it to produce correct outputs and also to be "confident" in them, i.e., to have as little uncertainty as possible. A classic way to estimate model uncertainty is to train several copies of a model on the same dataset and use the disagreement among their outputs as the measure of uncertainty. In MCNet+, the disagreement among the three decoders' outputs plays the same role, and it appears mainly in hard-to-separate regions such as edges. Because the three decoders compute consistency losses against each other, the loss comes mostly from these low-confidence regions. Training therefore concentrates on the hard-to-segment areas and yields a model with low uncertainty.
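The idea that decoder disagreement acts as an uncertainty estimate can be illustrated with a small NumPy sketch. The mean absolute pairwise difference is used here for simplicity (the variance across decoders would serve equally well). The four-voxel example, with interior, background, edge and thin-branch voxels, is made up for illustration.

```python
import numpy as np

def decoder_discrepancy(probs):
    """Per-voxel disagreement among decoder outputs.
    probs: array-like of shape (n_decoders, ...) with foreground probabilities.
    Returns the mean absolute pairwise difference for each voxel."""
    probs = np.asarray(probs, dtype=float)
    n = probs.shape[0]
    total = np.zeros(probs.shape[1:])
    pairs = 0
    for i in range(n):
        for j in range(i + 1, n):
            total += np.abs(probs[i] - probs[j])
            pairs += 1
    return total / pairs

# Hypothetical outputs of three decoders for four voxels:
# interior, background, adhesive edge, thin branch.
probs = [
    [0.97, 0.02, 0.30, 0.60],
    [0.98, 0.01, 0.70, 0.20],
    [0.96, 0.03, 0.55, 0.45],
]
u = decoder_discrepancy(probs)
```

Here the interior and background voxels get a discrepancy of about 0.013, while the edge and thin-branch voxels get about 0.27. A consistency loss built on this disagreement is therefore dominated by the hard voxels.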

Why does lower model uncertainty lead to better segmentation?
As mentioned above, model uncertainty shows up as differences in how several trained models segment the edge regions. Constraining this uncertainty therefore amounts, to some extent, to averaging those models, which makes the final model's segmentation of edge regions more robust.

The article mentions that MCNet+ uses both consistency-based constraints and entropy-minimization-based constraints. Where does each of them appear?
Consistency constraints rest on the consistency (smoothness) assumption: small perturbations should not change the network's output. MCNet+ computes the consistency loss between one decoder's output and the other decoders' soft pseudo-labels, and the perturbation here is the different upsampling method used in each decoder. The entropy minimization constraint rests on the cluster assumption: samples of the same class should form compact clusters. In other words, predictions should have low entropy, which means the model should have low uncertainty. By forcing the outputs of different decoders to agree, MCNet+ in effect constrains the model's uncertainty in the edge regions. In addition, the soft pseudo-labels are produced by a sharpening function, which pushes each target toward 0 or 1 and thus toward lower entropy.
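The entropy side can be checked numerically: sharpening lowers the entropy of every non-ambiguous voxel, so training toward sharpened targets favors low-entropy predictions. This is a minimal NumPy sketch. The sharpening form and the temperature T = 0.1 are assumptions, and the probabilities are made up.

```python
import numpy as np

def sharpen(p, T=0.1):
    """Assumed sharpening function: P^(1/T) / (P^(1/T) + (1-P)^(1/T))."""
    pt = p ** (1.0 / T)
    return pt / (pt + (1.0 - p) ** (1.0 / T))

def binary_entropy(p, eps=1e-12):
    """Per-voxel entropy (in nats) of a foreground probability."""
    p = np.clip(p, eps, 1.0 - eps)
    return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))

# Hypothetical, moderately uncertain decoder outputs.
p = np.array([0.55, 0.65, 0.80, 0.30])
s = sharpen(p)

h_before = binary_entropy(p)
h_after = binary_entropy(s)  # lower for every voxel in this example
```

Only a perfectly ambiguous voxel (p = 0.5) is left unchanged by sharpening. That voxel carries no preference, so the mutual-consistency loss must resolve it by making the decoders agree.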

CPS (Cross Pseudo Supervision), a concurrent work, adopts a similar idea, but it uses two networks with the same architecture and different initializations. Why does this article use a different upsampling method in each decoder?
As the article explains, using a different upsampling method in each decoder increases the diversity within the model. Put simply, it is another way of introducing perturbation: CPS perturbs through different model initializations, while MCNet+ perturbs through different upsampling methods.
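A small 1-D NumPy sketch shows why different upsampling paths already act as a perturbation. The same coarse feature map is upsampled once by nearest-neighbor repetition and once by linear interpolation. The two results are identical in flat regions and differ only around the transition. The coarse values are hypothetical, and the linear version assumes an align-corners grid. The third path in MC-Net+ (a learnable transposed convolution) is omitted because its output depends on trained weights.

```python
import numpy as np

def upsample_nearest(x, factor):
    """Nearest-neighbor upsampling: repeat each sample `factor` times."""
    return np.repeat(x, factor)

def upsample_linear(x, factor):
    """Linear interpolation onto a grid `factor` times denser.
    Assumes an align-corners grid: first and last samples map exactly."""
    n = len(x)
    src = np.arange(n)
    dst = np.linspace(0, n - 1, n * factor)
    return np.interp(dst, src, x)

# Hypothetical coarse decoder output with one edge (0.1 -> 0.9).
low_res = np.array([0.1, 0.1, 0.9, 0.9])
a = upsample_nearest(low_res, 2)
b = upsample_linear(low_res, 2)
diff = np.abs(a - b)
```

In this example the two outputs differ by about 0.23 at the two samples next to the step and agree everywhere else. This mirrors how the decoders disagree mostly near object boundaries.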


Origin blog.csdn.net/Fyw_Fyw_/article/details/129355749