Medical Imaging Essay Study

Medical Image Segmentation

DoDNet: Learning to segment multi-organ and tumors from multiple partially labeled datasets (2021)

DoDNet is a single encoder-decoder network with a dynamic head for segmenting multiple organs and tumors from partially labeled abdominal CT scans. The authors also build a large-scale partially labeled dataset, MOTS, and conduct extensive experiments on it.

The results show that, benefiting from task encoding and dynamic filter learning, DoDNet not only achieves the best overall performance on seven organ and tumor segmentation tasks, but also has a higher inference speed than its competitors. Additionally, the value of DoDNet and the MOTS dataset is demonstrated by transferring weights pretrained on MOTS to downstream tasks with limited annotations. A by-product of this work, a pretrained 3D network, is thus also beneficial for other annotation-scarce 3D medical image segmentation tasks.

Problem addressed

Partially labeled medical image segmentation: multi-organ and tumor segmentation are ubiquitous but difficult problems in medical image analysis, especially in the absence of large-scale fully labeled datasets. While several partially labeled datasets are available, each is dedicated to the segmentation of a particular organ and/or tumor.

As a result, segmentation models are usually trained on a single partially labeled dataset and can therefore only segment one specific organ and its tumor, such as liver and liver tumor, or kidney and kidney tumor. Training a separate network for each task, however, wastes computing resources and scales poorly.

previous studies

To address this issue, several attempts have been made to explore multiple partially labeled datasets in a more efficient manner.

Chen et al. collected multiple partially labeled datasets from different medical domains and jointly trained on them a heterogeneous 3D network, with a specially designed task-shared encoder and task-specific decoders, for eight segmentation tasks.

Huang et al. proposed to co-train a pair of weight-averaged models for unified multi-organ segmentation on few-organ datasets.

Zhou et al. first estimated anatomical priors on abdominal organ sizes from a fully labeled dataset, and then used them to regularize the organ size distributions on several partially labeled datasets.

Fang et al. proposed the target adaptive loss (TAL) for segmentation networks trained on multiple partially labeled datasets, treating unlabeled voxels as background.

Shi et al. merge unlabeled organs with the background and impose an exclusion constraint on each voxel (i.e., each voxel belongs to either one organ or the background), jointly learning a segmentation model on one fully labeled dataset and several partially labeled datasets.

To learn multi-class segmentation from single-class datasets, Dmitriev et al. treat the segmentation task itself as a prior and incorporate it into the intermediate activation signals.

Innovation

  1. Formulating partial labeling as a multi-class segmentation task, with unlabeled organs treated as background, can be misleading, because organs left unlabeled in one dataset are in fact the foreground of another task. To address this issue, the partial-labeling problem is formulated as a single-class segmentation task that aims to segment each organ separately;

  2. Most previous methods adopt a multi-head structure, consisting of a shared backbone network and multiple segmentation heads for different tasks, where each head is either a decoder [3] or the last segmentation layer [9, 30]. In contrast, the proposed DoDNet is a single-head network, where the head is flexible and dynamic;

  3. DoDNet uses a dynamic segmentation head to solve the partial-labeling problem, instead of embedding the task prior into the encoder and decoder;

  4. Existing methods mostly focus on multi-organ segmentation, while DoDNet segments both organs and tumors, which is more challenging.

network design


There are three ways to perform m partially labeled segmentation tasks:

  • (a) Multi-network: train m networks on m partially labeled subsets respectively;

  • (b) Multi-head network: train a network consisting of a shared encoder and m task-specific decoders (heads), each performing a partially labeled segmentation task;

  • (c) The proposed DoDNet: it has an encoder-decoder, a task encoding module, a dynamic filter generation module, and a dynamic segmentation head. The kernels in the dynamic head are conditioned on both the input image and the assigned task (a toy 2D sketch of this idea follows the list).
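
To make the dynamic-head idea concrete, here is a minimal 2D PyTorch sketch. It is not the paper's implementation: DoDNet operates on 3D volumes and has its own controller design, so the class name, layer sizes, and the use of two 1x1 convolutions here are illustrative assumptions. A controller takes the globally pooled image feature plus a task one-hot vector and emits the kernel weights and biases of a small per-sample segmentation head.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicSegHead(nn.Module):
    # Sketch of dynamic filter generation: the head's conv kernels are predicted per sample,
    # conditioned on the image feature and the assigned task.
    def __init__(self, num_tasks, feat_ch=256, mid_ch=8, out_ch=2):
        super().__init__()
        self.num_tasks, self.mid_ch, self.out_ch = num_tasks, mid_ch, out_ch
        # parameters of two dynamic 1x1 conv layers (weights + biases)
        self.num_params = (feat_ch * mid_ch + mid_ch) + (mid_ch * out_ch + out_ch)
        self.controller = nn.Sequential(
            nn.Linear(feat_ch + num_tasks, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.num_params),
        )

    def forward(self, feat, task_id):
        # feat: (B, C, H, W) decoder output; task_id: (B,) long tensor of task indices
        b, c, h, w = feat.shape
        task_onehot = F.one_hot(task_id, self.num_tasks).float()     # task encoding
        context = F.adaptive_avg_pool2d(feat, 1).flatten(1)          # image-conditioned context
        params = self.controller(torch.cat([context, task_onehot], dim=1))
        outs = []
        for i in range(b):  # apply the per-sample dynamic 1x1 convolutions
            p = params[i]
            w1, p = p[:c * self.mid_ch].view(self.mid_ch, c, 1, 1), p[c * self.mid_ch:]
            b1, p = p[:self.mid_ch], p[self.mid_ch:]
            w2, p = p[:self.mid_ch * self.out_ch].view(self.out_ch, self.mid_ch, 1, 1), p[self.mid_ch * self.out_ch:]
            b2 = p
            y = F.relu(F.conv2d(feat[i:i + 1], w1, b1))
            outs.append(F.conv2d(y, w2, b2))
        return torch.cat(outs, dim=0)                                 # (B, out_ch, H, W), e.g. organ/tumor

Because the kernels are generated on the fly, a single shared encoder-decoder can serve all m partially labeled tasks, which matches the efficiency argument made above for DoDNet over the multi-network and multi-head alternatives.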

Code

Learning Calibrated Medical Image Segmentation via Multi-rater Agreement Modeling (2021)

In medical image analysis, multiple annotations are often collected, each from a different clinical expert or rater, in hopes of mitigating possible diagnostic errors.

Meanwhile, from the perspective of computer vision practitioners, it is common practice to adopt ground-truth labels obtained through majority voting or by simply preferring one rater.
However, this process ignores the rich information embedded in the agreement or disagreement among the original multi-rater annotations. To address this issue, we propose to explicitly model the multi-rater (dis-)agreement, referred to as MRNet, which has two main contributions.

  1. First, an expertise-aware inferring module (EIM) is designed to embed each rater's expertise level as prior knowledge when forming high-level semantic features.
  2. Second, the method is able to reconstruct the multi-rater gradings from coarse predictions and further exploits multi-rater (dis-)agreement cues to improve segmentation performance. To the best of our knowledge, this is the first work to produce calibrated predictions for medical image segmentation under different levels of expertise.
  • In plain terms: first roughly detect where the target region is and make a coarse prediction, then zoom in on that region for refined prediction.


Model description


MRNet framework:

  1. (a) Overview of the processing pipeline; the subsequent panels zoom in on its individual modules.

  2. (b) Expertise-aware inferring module (EIM); a toy sketch of this idea follows the list.

  3. (c) Multi-rater agreement modeling (MAM) consists of a multi-rater reconstruction module (MRM) and a multi-rater perception module (MPM).
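
The paper's exact EIM design is not reproduced in these notes, but the general idea of "embedding rater expertise as a prior" can be sketched as follows. Everything here, including the module name, the embedding, and the way it is fused with the features, is an illustrative assumption rather than MRNet's actual implementation.

import torch
import torch.nn as nn

class ExpertisePrior(nn.Module):
    # Illustrative only: fuse a per-rater "expertise" embedding with high-level semantic features.
    def __init__(self, num_raters, feat_ch=256, embed_dim=64):
        super().__init__()
        self.embed = nn.Embedding(num_raters, embed_dim)
        self.fuse = nn.Sequential(
            nn.Conv2d(feat_ch + embed_dim, feat_ch, kernel_size=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, feat, rater_id):
        # feat: (B, C, H, W) high-level features; rater_id: (B,) long tensor selecting the rater
        b, _, h, w = feat.shape
        prior = self.embed(rater_id).view(b, -1, 1, 1).expand(-1, -1, h, w)  # tile the prior spatially
        return self.fuse(torch.cat([feat, prior], dim=1))

Conditioning on rater_id lets one network produce different (calibrated) predictions for different expertise levels, which is the behavior described in contribution 1 above.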

MR Image Super-Resolution with Squeeze and Excitation Reasoning Attention Network (2021)

  • MR images generally share some common visual characteristics: repetitive patterns, relatively simple structures, and less informative backgrounds. MR images usually contain large background regions that carry much less information than the target structures (i.e., a lot of redundancy).

  • To address these issues, we propose the squeeze-and-excitation reasoning attention network (SERAN) for accurate MR image SR.

    • We squeeze attention from the global spatial information to obtain global descriptors. These global descriptors strengthen the network's ability to focus on the more informative regions and structures in MR images.
    • We further build relationships among the global descriptors, i.e., among the primitives, and apply a graph convolutional network (GCN) to reason over them, obtaining attention over the primitive relations.
    • The global descriptors are further refined by the learned attention. To fully exploit the aggregated information, the feature responses are adaptively recalibrated with learned adaptive attention vectors; these vectors select a subset of the global descriptors to complement each spatial location for accurate detail and texture reconstruction. A simplified PyTorch sketch of such a block follows this list.
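
The sketch below captures the mechanics described above in simplified 2D PyTorch: squeeze K global descriptors out of the feature map, reason over them, and redistribute them to every pixel through learned attention. The number of descriptors, the GloRe-style node/state 1x1 convolutions standing in for the paper's GCN reasoning, and the zero-initialized residual scale are all assumptions, not SERAN's exact SEAB.

import torch
import torch.nn as nn

class SqueezeExcitationAttentionSketch(nn.Module):
    def __init__(self, channels, num_descriptors=16):
        super().__init__()
        self.squeeze = nn.Conv2d(channels, num_descriptors, 1)           # spatial pooling weights per descriptor
        self.node_conv = nn.Conv1d(num_descriptors, num_descriptors, 1)  # message passing among descriptors
        self.state_conv = nn.Conv1d(channels, channels, 1)               # per-descriptor state update
        self.excite = nn.Conv2d(channels, num_descriptors, 1)            # per-pixel descriptor selection
        self.gamma = nn.Parameter(torch.zeros(1))                        # residual scale, starts at zero

    def forward(self, x):
        b, c, h, w = x.shape
        pool = torch.softmax(self.squeeze(x).flatten(2), dim=-1)         # (B, K, HW)
        v = torch.bmm(pool, x.flatten(2).transpose(1, 2))                # (B, K, C) global descriptors
        v = v + self.node_conv(v)                                        # relation reasoning across descriptors
        v = v + self.state_conv(v.transpose(1, 2)).transpose(1, 2)       # refine each descriptor
        select = torch.softmax(self.excite(x).flatten(2), dim=1)         # (B, K, HW) adaptive attention
        out = torch.bmm(v.transpose(1, 2), select).view(b, c, h, w)      # distribute descriptors to pixels
        return x + self.gamma * out

Because the block is residual and its scale starts at zero, it can be inserted into a pretrained SR network without disturbing its initial behavior, while still letting later convolutions see global context despite their limited receptive fields.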


A brief illustration of our squeeze-and-excitation reasoning attention mechanism.

Similar to [3], global features are first collected by bilinear pooling and then distributed to each spatial location by considering the corresponding local features. However, we additionally augment the global features with primitive relationship reasoning (PRR).

This formulation was chosen for two reasons.

First, using residual learning directly here will make the training process numerically unstable.

Second, residual connections allow us to plug SEAB into any pretrained network without affecting its initial behavior too much.

With the use of SEAB, subsequent convolutional layers can perceive the entire space even with a limited receptive field size.

SEAB allows the network to focus on more informative visual features and achieve better SR reconstruction quality from MR images.

How the GCN reasons over the descriptors to produce attention


Image Reconstruction with Adaptive Attention

After collecting the global feature descriptors, we distribute them to each position of the original features. This helps us better exploit the complex relationships captured by the computed second-order statistics and compensate for lost information, enabling better MR image reconstruction.

Each location of the original feature has its own specific need for the global descriptors. We learn attention vectors d_i and, according to d_i, adaptively assign the global descriptors V at each location. This means that each location can adaptively choose complementary visual proto-semantics.

  • Final results


Preservational Learning Improves Self-supervised Medical Image Models by Reconstructing Diverse Contexts

Preserving maximum information is one of the principles in designing self-supervised learning methods.

To achieve this goal, contrastive learning learns invariant representations by contrasting pairs of (medical) images, which can be seen as an implicit way of preserving maximal information.

However, we believe that relying solely on contrastive estimation for preservation is not optimal. We argue that explicitly retaining more information, in addition to the contrastive loss, is beneficial and complementary.

From this perspective, we introduce preservational learning, which reconstructs diverse image contexts so that more information is preserved in the learned representations.

An intuitive solution is to use the learned representations to reconstruct the original inputs, so that the representations preserve information closely related to the inputs. However, we find that directly adding a vanilla reconstruction branch that restores the original input does not significantly improve the learned representations. To address this, we introduce preservational contrastive representation learning, which uses the representations learned from the contrastive loss to reconstruct diverse contexts.

Combined with the contrastive loss, we propose preservational contrastive representation learning (PCRL) for learning self-supervised medical representations.
PCRL provides very competitive results under the pretraining-finetuning protocol, substantially outperforming both self-supervised and supervised counterparts on five classification/segmentation tasks.
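
As a rough illustration of how the contrastive and reconstruction terms might be combined, the sketch below pairs a standard InfoNCE loss with an L1 reconstruction loss. The weighting, the choice of L1, and the queue-based negatives are assumptions for illustration, not the paper's exact objective.

import torch
import torch.nn.functional as F

def info_nce(q, k, queue, temperature=0.07):
    # q, k: (B, D) embeddings of two views of the same image; queue: (K, D) negative embeddings
    q, k, queue = F.normalize(q, dim=1), F.normalize(k, dim=1), F.normalize(queue, dim=1)
    pos = (q * k).sum(dim=1, keepdim=True)            # (B, 1) positive logits
    neg = q @ queue.t()                               # (B, K) negative logits
    logits = torch.cat([pos, neg], dim=1) / temperature
    labels = torch.zeros(q.size(0), dtype=torch.long, device=q.device)  # positives sit at index 0
    return F.cross_entropy(logits, labels)

def preservational_loss(q, k, queue, recon, target, recon_weight=1.0):
    # implicit preservation (contrastive) + explicit preservation (context reconstruction)
    return info_nce(q, k, queue) + recon_weight * F.l1_loss(recon, target)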

The contributions of this paper can be summarized in three aspects:

  1. Preservational contrastive representation learning is introduced to encode more information into the representations learned from the contrastive loss by reconstructing diverse contexts.
  2. To restore diverse images, two modules are proposed, transformation-conditioned attention and cross-model mixup, which together build a triple-encoder, single-decoder self-supervised learning architecture.
  3. Extensive experiments and analyses show that the proposed PCRL has clear advantages on five classification/segmentation tasks, outperforming both self-supervised and supervised counterparts by significant margins.

previous studies

  1. Pretext-based methods learn representations by predicting properties such as image rotation, object color, the number of objects [25], and the applied transformation function [30]. Contrastive-estimation-based methods also exploit pretext tasks, learning invariant representations by contrasting image pairs. Recently, several studies have tried to remove negative pairs in contrastive learning. In contrast, our approach follows a different principle: the representations should describe their origin (i.e., the corresponding input image) as fully as possible.

  2. Self-supervised learning in medical image analysis. Prior to contrastive learning, solving jigsaw puzzles [54, 53, 35] and reconstructing corrupted images [9, 52] were the two main families of pretext-based methods for medical images. Besides, Xie et al. [44] introduced a triplet loss for self-supervised learning on nuclei images. Haghighi et al. [19] improved upon [52] by appending a classification branch that classifies high-level features into different anatomical patterns. For contrastive learning, Zhou et al. [51] applied a contrastive loss to 2D radiographs. Similar ideas also appear in few-shot [49] and semi-supervised learning [50]. [34] proposed 3D contrastive predictive coding using 3D medical images. Two works [16, 8] are most relevant to ours: [16] showed that reconstructing partial images has an effect similar to using a contrastive loss, and [8] introduces a denoising autoencoder to capture a latent spatial representation. However, neither method improves contrastive learning through context reconstruction, while our method succeeds in this aspect.

We incorporate diverse image reconstructions as pretext tasks into contrastive learning. The main motivation is to encode more information into the learned representations.

Specifically, we introduce transformation-conditioned attention and cross-model mixup to enrich the information carried by the representations. The first module embeds the transformation indicator vector vec(T) into the high-level feature map.

Given the embedded indicator vector, the network has to dynamically reconstruct different image targets while the input stays fixed.

  • A hybrid encoder is formed by mixing the feature maps of the normal encoder and the momentum encoder, and this hybrid encoder is asked to reconstruct the corresponding mixed image targets.

We show that both modules can help encode more information and produce stronger representations than using contrastive learning alone.

GO denotes a global operation that converts feature maps into feature vectors. The blue feature vectors come from the momentum encoder. vec(T) denotes the indicator vector of T, which contains a set of transformation functions. Each component of vec(T) is 1 or 0, indicating whether the corresponding transformation was applied or not.
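
The transformation-conditioned attention can be sketched as a small MLP that turns vec(T) into a sigmoid channel gate applied to the bottleneck feature map. The layer sizes mirror the aug_fc1/aug_fc2/sigmoid path in the code at the end of these notes; packaging it as a standalone module is my own simplification.

import torch
import torch.nn as nn

class TransformationConditionedAttention(nn.Module):
    # Maps a 6-dim transformation indicator vec(T) to a per-channel gate for the bottleneck feature.
    def __init__(self, num_transforms=6, channels=512, hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_transforms, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, feat, vec_t):
        # feat: (B, C, H, W) bottleneck feature; vec_t: (B, num_transforms) with 0/1 entries
        gate = self.mlp(vec_t).view(feat.size(0), -1, 1, 1)   # (B, C, 1, 1)
        return feat * gate                                    # channel-wise modulation (the ⊙ in the figure)

Because the gate depends on which transformations were applied, the decoder is pushed to reconstruct a different target for each indicator vector even though the input image is fixed.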

Model framework

insert image description here
PCRL employs a U-Net-like architecture to learn representations. For encoder and decoder, we plot their feature maps for better demonstration.

The hybrid encoder does not take an input image of its own, because its features are mixtures of the feature maps from the normal encoder and the momentum encoder. {C., F., R., I., O., B.} are abbreviations for random crop, random flip, random rotation, in-painting, out-painting, and Gaussian blur, respectively.

NCE stands for noise contrastive estimation. GO stands for global operation, which includes a global average pooling layer and fully connected layers. vec( ) denotes the indicator vector. T{o, m, h}( ) denote the sets of transformation functions for the different encoders. ⊙ denotes channel-wise multiplication. For simplicity, skip connections are not drawn.
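
A minimal sketch of the cross-model mixing that forms the hybrid branch: the multi-scale feature maps of the normal and momentum encoders are blended with a coefficient alpha, which matches the alpha-mixing of the bottleneck and skip connections in the ShuffleUnetDecoder code below. The helper name is illustrative.

import torch

def mix_feature_pyramids(feats_normal, feats_momentum, alpha=0.5):
    # feats_*: lists of same-shaped multi-scale feature maps from the two encoders
    # returns the "hybrid encoder" features that the hybrid decoding path reconstructs from
    return [alpha * f_n + (1.0 - alpha) * f_m
            for f_n, f_m in zip(feats_normal, feats_momentum)]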

  • Our transformation-conditioned attention module.


F. and R. represent flip and rotation, respectively. {x, y, z} denote the coordinate axes, and {0, 90°, 180°, 270°} the rotation angles. vec(T) denotes the indicator vector of T; its subscript is omitted for simplicity. ⊗ denotes the outer product and ⊙ denotes channel-wise multiplication. Note that the diagram illustrates the implementation for 3D inputs; for 2D inputs there is no F(z) entry in the indicator vector, and for both 2D and 3D inputs rotation is only applied in the xy plane.

  • Reconstructing diverse contexts on top of representations learned with a contrastive loss can greatly benefit medical image analysis. The approach shows positive results for self-supervised learning across a variety of medical tasks and datasets.

The method as described was not working for me, so let me see whether the code actually works:

# PCRL 2D model built on top of segmentation_models_pytorch (smp).
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
from segmentation_models_pytorch.base import modules as md


def initialize_decoder(module):
    for m in module.modules():

        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def initialize_head(module):
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class CenterBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = md.Conv2dReLU(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        super().__init__(conv1, conv2)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
            use_batchnorm=True,
            attention_type=None,
    ):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class ShuffleUnetDecoder(nn.Module):
    def __init__(
            self,
            # decoder,
            encoder_channels=512,
            n_class=3,
            decoder_channels=(256, 128, 64, 32, 16),
            n_blocks=5,
            use_batchnorm=True,
            center=False,
            attention_type=None

    ):
        super().__init__()
        # self.decoder = decoder
        # self.segmentation_head = segmentation_head
        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels
        # self.conv = nn.Conv2d(1024, 512, kernel_size=3, padding=1, stride=1)
        if center:
            self.center = CenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        initialize_decoder(self.blocks)
        # self.segmentation_head = SegmentationHead(16, 3)
        # initialize_head(self.segmentation_head)
        # self.segmentation_head = segmentation_head
        #
        # # combine decoder keyword arguments

    def forward(self, features1, features2, alpha, aug_tensor1, aug_tensor2, mixup=False):
        # features1: multi-scale feature maps from the normal (student) encoder
        # features2: multi-scale feature maps from the momentum encoder
        # alpha: mixing coefficient for the hybrid (cross-model mixed) decoding path
        # aug_tensor1/2: transformation-conditioned channel attentions, shape (B, 512, 1, 1)
        features1 = features1[1:]  # remove first skip with same spatial resolution
        features1 = features1[::-1]  # reverse channels to start from head of encoder
        features2 = features2[1:]
        features2 = features2[::-1]
        head1 = features1[0]
        skips1 = features1[1:]
        head2 = features2[0]
        skips2 = features2[1:]
        x1 = self.center(head1)
        x2 = self.center(head2)
        if not mixup:
            # modulate each bottleneck with its transformation-conditioned attention vector
            x1 = x1 * aug_tensor1
            x2 = x2 * aug_tensor2
        x3 = x1.clone()                      # x3 decodes the un-mixed (normal) path
        x1 = alpha * x1 + (1 - alpha) * x2   # x1 decodes the hybrid (cross-model mixed) path
        for i, decoder_block in enumerate(self.blocks):
            skip1 = skips1[i] if i < len(skips1) else None
            # normal path: decode with the un-mixed skip connections
            x3 = decoder_block(x3, skip1)

            skip2 = skips2[i] if i < len(skips2) else None
            # hybrid path: mix the skip connections with the same coefficient alpha
            skip = alpha * skip1 + (1 - alpha) * skip2 if i < len(skips2) else None
            x1 = decoder_block(x1, skip)

        # return (hybrid decoding, normal decoding); the segmentation head is applied by the caller
        return x1, x3

    def decoder_shuffle(self, x, shuffle_num):
        # Randomly swaps `shuffle_num` columns/rows of the feature map.
        # Not used by the current forward pass; the .cuda() calls assume GPU tensors.
        w = x.shape[2]
        h = x.shape[3]
        shuffle_col_index = torch.randperm(w)[:shuffle_num].cuda()
        shuffle_row_index = torch.randperm(h)[:shuffle_num].cuda()
        col_index = shuffle_col_index[torch.randperm(shuffle_col_index.shape[0])]
        row_index = shuffle_row_index[torch.randperm(shuffle_row_index.shape[0])]
        x = x.index_copy(2, col_index, x.index_select(2, shuffle_col_index))
        x = x.index_copy(3, row_index, x.index_select(3, shuffle_row_index))
        return x


class PCRLModel(nn.Module):
    def __init__(self, n_class=3, low_dim=128, student=False):
        super(PCRLModel, self).__init__()
        # 2D U-Net with a ResNet-18 encoder; the default decoder is replaced by the mixing decoder above
        self.model = smp.Unet('resnet18', in_channels=3, classes=n_class, encoder_weights=None)
        self.model.decoder = ShuffleUnetDecoder(self.model.encoder.out_channels)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # projection head for the contrastive (NCE) branch
        self.fc1 = nn.Linear(512, low_dim)
        self.relu = nn.ReLU(inplace=True)
        self.student = student
        self.fc2 = nn.Linear(low_dim, low_dim)
        # maps the 6-dim transformation indicator vec(T) to a 512-dim channel attention
        self.aug_fc1 = nn.Linear(6, 256)
        self.aug_fc2 = nn.Linear(256, 512)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, features_ema=None, alpha=None, aug_tensor1=None, aug_tensor2=None, mixup=False):
        b = x.shape[0]
        features = self.model.encoder(x)
        # global projection of the deepest feature map, used by the contrastive (NCE) branch
        feature = self.avg_pool(features[-1])
        feature = feature.view(b, -1)
        feature = self.fc1(feature)
        feature = self.relu(feature)
        feature = self.fc2(feature)
        if self.student:
            if not mixup:
                # turn the transformation indicator vectors vec(T) into sigmoid channel attentions
                aug_tensor1 = self.aug_fc1(aug_tensor1)
                aug_tensor1 = self.relu(aug_tensor1)
                aug_tensor1 = self.aug_fc2(aug_tensor1)
                aug_tensor2 = self.aug_fc1(aug_tensor2)
                aug_tensor2 = self.relu(aug_tensor2)
                aug_tensor2 = self.aug_fc2(aug_tensor2)
                aug_tensor1 = self.sigmoid(aug_tensor1)
                aug_tensor2 = self.sigmoid(aug_tensor2)
                aug_tensor1 = aug_tensor1.view(b, 512, 1, 1)
                aug_tensor2 = aug_tensor2.view(b, 512, 1, 1)
            # decode the hybrid (alpha-mixed) path and the normal path, then segment both
            decoder_output_alpha, decoder_output = self.model.decoder(features, features_ema, alpha, aug_tensor1,
                                                                      aug_tensor2, mixup)
            masks_alpha = self.model.segmentation_head(decoder_output_alpha)
            masks = self.model.segmentation_head(decoder_output)
            return feature, masks_alpha, masks
        # momentum/teacher mode: return the projection and the multi-scale encoder features
        return feature, features
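
As a hypothetical smoke test (assuming segmentation_models_pytorch is installed; shapes and values below are illustrative only), the model above could be exercised like this:

if __name__ == "__main__":
    student = PCRLModel(n_class=3, student=True)
    teacher = PCRLModel(n_class=3, student=False)      # plays the role of the momentum branch
    teacher.load_state_dict(student.state_dict())      # in PCRL the teacher would be an EMA copy

    x1 = torch.randn(2, 3, 224, 224)                   # student view
    x2 = torch.randn(2, 3, 224, 224)                   # momentum view
    aug1 = torch.randint(0, 2, (2, 6)).float()         # indicator vectors vec(T) for the two views
    aug2 = torch.randint(0, 2, (2, 6)).float()
    alpha = 0.6                                        # mixing coefficient for the hybrid path

    with torch.no_grad():
        _, features_ema = teacher(x2)                  # multi-scale features from the momentum branch

    feature, masks_alpha, masks = student(x1, features_ema, alpha, aug1, aug2, mixup=False)
    print(feature.shape, masks_alpha.shape, masks.shape)  # (2, 128), (2, 3, 224, 224), (2, 3, 224, 224)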


Origin blog.csdn.net/RandyHan/article/details/129739342