【arXiv2303】Learning with Explicit Shape Priors for Medical Image Segmentation

Learning with Explicit Shape Priors for Medical Image Segmentation, arXiv 2303

Interpretation: SPM: A plug-and-play shape prior module that can be easily embedded into any encoder-decoder architecture and significantly improves segmentation results! (qq.com)

Paper: https://arxiv.org/abs/2303.17967

Code: https://github.com/AlexYouXin/Explicit-Shape-Priors

Summary

UNet-based networks have gradually come to dominate medical image segmentation. However, convolutional neural networks (CNNs) face two limitations:

  • CNNs have a limited receptive field and cannot model long-range dependencies or global relationships between organs or tissues.
  • The segmentation mask relies heavily on the training of the final segmentation head.

Existing methods do not address both limitations well. Therefore, this paper proposes a novel shape prior module (SPM), which can introduce shape priors to improve the segmentation performance of UNet-based models. The explicit shape prior consists of a global shape prior and a local shape prior.

  • A global shape prior with a coarse shape representation provides the network with the ability to model global context.
  • Local shape priors with finer-grained shape information can serve as additional guidance to improve segmentation performance, thus alleviating the heavy reliance on learnable prototypes in the segmentation head.

To evaluate the effectiveness of SPM, experiments are conducted on three challenging public datasets, on which SPM performs strongly. Moreover, SPM generalizes well to both classical CNNs and recent Transformer-based backbones, and can serve as a plug-and-play structure for segmentation tasks on different datasets.

introduction

How can the limited receptive field of CNNs be overcome? This paper explores the impact of shape priors on segmentation performance.

In medical images, different organs or lesions usually have specific shapes and structures, and this shape and structure information is critical for segmentation models. Many previous works have therefore used shape priors to design segmentation models that produce masks consistent with anatomical shape. Introducing a shape prior helps the model consider and exploit the shape of the target object during segmentation, thereby improving segmentation performance.

To this end, the paper examines three families of segmentation models that use shape priors:

  • Atlas-based models
  • Statistical-based models
  • UNet-based models

The paper argues that the first two kinds of methods generalize poorly. UNet-based models generalize better, but they tend to rely on implicit shape priors, which lack interpretability and do not transfer well across organs with different shapes. The paper therefore proposes a new Shape Prior Module (SPM) that explicitly introduces shape priors to boost the segmentation performance of UNet-based models. (See the paper for the detailed analysis.)

The paper conducts experiments on three challenging public datasets to verify the effectiveness of SPM. SPM also exhibits strong generalization and can be used as a plug-and-play structure for different dataset segmentation tasks.

Implicit shape priors are usually built into the model indirectly, for example through a specific loss function or regularization term. They are hard to interpret because they act only through training rather than directly encoding the shape of the target object. For example, a UNet-based model can account for shape implicitly through the Dice loss, which rewards overlap between the predicted region and the target region and thereby pushes the model to attend more to the extent and contour of the target object.
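To make the implicit route concrete, here is a minimal NumPy sketch of a standard soft Dice loss for a single foreground class. It is a textbook formulation, not code from the paper; the function name soft_dice_loss and the smoothing term eps are illustrative assumptions.

```python
import numpy as np

def soft_dice_loss(prob, target, eps=1e-6):
    """Soft Dice loss for one foreground class.

    prob:   predicted foreground probabilities in [0, 1]
    target: binary ground-truth mask with the same shape
    eps:    smoothing term that avoids division by zero (assumed value)
    """
    intersection = np.sum(prob * target)
    total = np.sum(prob) + np.sum(target)
    return 1.0 - (2.0 * intersection + eps) / (total + eps)

target = np.array([[0, 1, 1, 0],
                   [0, 1, 1, 0]], dtype=float)
print(soft_dice_loss(target, target))        # perfect overlap -> 0.0
print(soft_dice_loss(1.0 - target, target))  # no overlap -> close to 1.0
```

The loss never sees an explicit shape description; shape enters only through which overlaps it rewards, which is why such priors are called implicit.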

In contrast, explicit shape priors directly provide shape prior information as input to the model. For example, in this paper, the authors propose a new shape prior module that explicitly takes shape prior information as input and uses this information to guide the model to better segment target objects. Such explicit shape priors can be better interpreted and tuned because they directly consider the shape and structure information of the target object. 

method

A Unified Framework for Explicit Shape Models

The paper introduces a learnable shape prior S into U-shaped neural networks. Specifically, S is fed into the network together with the image. The network outputs the predicted mask and attention maps generated from S, whose channels provide rich shape information about the ground-truth label regions. The explicit shape prior model can thus be viewed as a forward mapping F from an image (together with S) to a predicted mask (together with a refined shape prior).

Here, F denotes the forward pass during inference, and S denotes the continuous shape prior that builds the mapping between the image space I and the label space L. During training, S is updated as the image/ground-truth pairs change. Once training is complete, the learned shape prior is fixed, yet it still dynamically produces refined shape priors as the input patches change during inference. The refined shape prior serves as an attention map that localizes regions of interest and suppresses background regions. Furthermore, a small fraction of inaccurate ground-truth labels does not significantly affect the learning of S, which shows the robustness of this paradigm.
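The idea of a shape prior acting as an attention map that highlights the region of interest and suppresses the background can be illustrated with a generic spatial-gating sketch. This is not SPM's own mechanism (SPM uses the SUB and CUB blocks); the function name gate_with_prior and the sigmoid gating are assumptions chosen purely for illustration.

```python
import numpy as np

def gate_with_prior(feature, prior_logits):
    """Re-weight features with one channel of a shape prior (hypothetical helper).

    feature:      (C, H, W) feature map
    prior_logits: (H, W) logits of one class channel of the prior
    Returns the gated features and the attention map in (0, 1).
    """
    attn = 1.0 / (1.0 + np.exp(-prior_logits))  # sigmoid: high inside the ROI
    return feature * attn[None, :, :], attn      # broadcast over channels

# toy prior: strongly positive on a central 2x2 "organ", negative elsewhere
prior = np.full((4, 4), -6.0)
prior[1:3, 1:3] = 6.0
feature = np.ones((2, 4, 4))
gated, attn = gate_with_prior(feature, prior)
```

After gating, features inside the central block keep almost their full magnitude, while background features are scaled down by roughly two orders of magnitude.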

SPM (Shape Prior Module)

As shown in Figure 1, the model is a hierarchical U-shaped network consisting of a ResNet-like encoder, a ResBlock-based decoder, and the shape prior module (SPM). SPM strengthens the network's representational power by introducing a learnable shape prior that imposes an anatomical shape constraint on each class. Because SPM is plug-and-play, it can be inserted into other network architectures to improve segmentation performance.

As shown in Figure 2, SPM takes the original skip feature F_o and the original shape prior S_o as input and, after "feature purification", produces the corresponding enhanced skip feature F_e and enhanced shape prior S_e. With these enhanced features and priors, the model generates more accurate segmentation masks. Unlike DETR, SPM interacts with multi-scale features rather than only those from the deepest encoder layer, so the hierarchically encoded features gain more shape information from SPM before they pass through the skip connections. The enhanced shape prior consists of two parts:

  • The global shape prior, generated by a self-update block.
  • The local shape prior, generated by a cross-update block.

Self-update block (SUB): modeling long-range dependencies

The SUB operates on an explicit shape prior S_o that can localize object regions. Its size is N × (spatial dimensions), where N is the number of classes and the spatial dimensions depend on the patch size. To overcome the limited receptive field, this work models long-range dependencies within the shape prior. Specifically, the self-update block (SUB) models the relationships between classes and generates a global shape prior through interactions among the N channels. Inspired by self-attention, it builds a self-attention affinity map S_map among the N classes that describes the similarity and dependency between the channels of the shape prior. S_map then weights the value projection V_s, and the result passes through a multi-layer perceptron (MLP) and layer normalization to produce the global shape prior S_G.
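A minimal NumPy sketch of the SUB computation follows. It assumes a single attention head and omits the output projection, dropout, and learnable relative position bias that the PyTorch code in the key code section uses; ReLU stands in for GELU, and the random weights Wq, Wk, Wv, W1, W2 are illustrative placeholders. Each of the N class channels is treated as a token whose feature vector is its P flattened spatial positions, so S_map is N × N.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # normalize each row (one class channel) over its spatial positions
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_update(S, Wq, Wk, Wv, W1, W2):
    """One pre-norm Transformer block over classes, then a final LayerNorm.

    S: (N, P) shape prior with N classes and P flattened spatial positions.
    Returns the global prior S_G (N, P) and the class affinity map S_map (N, N).
    """
    h = layer_norm(S)
    Q, K, V = h @ Wq, h @ Wk, h @ Wv                  # each (N, P); V plays the role of V_s
    S_map = softmax(Q @ K.T / np.sqrt(S.shape[1]))   # (N, N) affinity between classes
    x = S + S_map @ V                                 # attention + residual
    x = x + np.maximum(layer_norm(x) @ W1, 0.0) @ W2  # MLP + residual (ReLU for brevity)
    return layer_norm(x), S_map

rng = np.random.default_rng(0)
N, P, hidden = 4, 64, 32
S_o = rng.standard_normal((N, P))
Wq, Wk, Wv = (rng.standard_normal((P, P)) / np.sqrt(P) for _ in range(3))
W1 = rng.standard_normal((P, hidden)) / np.sqrt(P)
W2 = rng.standard_normal((hidden, P)) / np.sqrt(hidden)
S_G, S_map = self_update(S_o, Wq, Wk, Wv, W1, W2)
```

Because attention runs over the N class channels rather than over pixels, the cost grows with N² instead of P², which keeps global interaction cheap for large 3D patches.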

Cross-update block (CUB): modeling local shape priors

The SUB brings global contextual information into the explicit shape prior, but not precise shape and contour information: because the SUB lacks an inductive bias, it cannot model local visual structures or localize objects at various scales.

To address this limitation, the paper proposes a cross-update block (CUB). Motivated by the inherent locality and scale-invariant inductive bias of convolution kernels, the convolution-based CUB injects an inductive bias into SPM to obtain more accurate local shape information. In addition, because convolutional features in the encoder have strong potential for localizing discriminative regions, the CUB lets the original skip feature F_o interact with the shape prior S_o.

Specifically,

  • First, compute the similarity map C_map between the feature F_o and the shape prior S_o, which measures the relationship between the C-channel feature map and the N-channel shape prior.
  • Then apply C_map to the transformed global shape prior S_G to refine F_o, yielding the enhanced skip feature F_e, which carries a more accurate shape prior and rich global texture.
  • Finally, generate the local shape prior S_L from F_e (downsampled to the resolution of the shape prior); it can model local visual structures such as edges and corners.
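The three CUB steps can be sketched in NumPy for a single sample. Assumptions: the prior is already resized to the feature resolution, and the convolutional block that maps F_e to N channels is replaced by a hypothetical 1×1×1 projection W_proj, with average pooling standing in for the trilinear downsampling used in the PyTorch code.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cross_update(F_o, S_G, W_proj, factor=2):
    """F_o: (C, D, H, W) skip feature; S_G: (N, D, H, W) global prior at the same
    resolution; W_proj: (N, C) hypothetical 1x1x1 projection standing in for a conv block."""
    C, N = F_o.shape[0], S_G.shape[0]
    f = F_o.reshape(C, -1)                          # (C, P)
    s = S_G.reshape(N, -1)                          # (N, P)
    C_map = softmax(f @ s.T / np.sqrt(N), axis=1)   # (C, N) channel-to-class similarity
    F_e = F_o + (C_map @ s).reshape(F_o.shape)      # refine F_o with S_G (residual)
    # local prior S_L: project F_e to N channels, then average-pool by `factor`
    proj = np.einsum("nc,cdhw->ndhw", W_proj, F_e)
    _, D, H, W = proj.shape
    S_L = proj.reshape(N, D // factor, factor, H // factor, factor,
                       W // factor, factor).mean(axis=(2, 4, 6))
    return F_e, C_map, S_L

rng = np.random.default_rng(1)
F_o = rng.standard_normal((8, 4, 4, 4))
S_G = rng.standard_normal((3, 4, 4, 4))
W_proj = rng.standard_normal((3, 8))
F_e, C_map, S_L = cross_update(F_o, S_G, W_proj)
```

In the PyTorch SPM.forward, the downsampled local prior is added to the SUB output, so the prior passed to the next stage combines S_L and S_G.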

In summary, the original shape prior is enhanced by introducing global and local features.

  • The global shape prior models inter-class relations with a coarse shape prior, using self-attention blocks that supply ample global texture information.
  • The local shape prior reveals finer-grained shape information by introducing a convolution-based inductive bias.
  • In addition, SPM further enhances the original skip features through their interaction with the global shape prior, which helps produce features with discriminative shape representations and global context, and hence more accurate prediction masks.

experiment

performance comparison

visual analysis

The figure above shows how skip features affect the explicit shape priors:

  • Case (a) shows explicit shape priors generated at different stages. The shape prior consists of N channel attention maps, where N is the number of segmentation classes, and each row shows the shape prior from one stage. As decoding proceeds top-down, the shape prior produces increasingly accurate activation maps for the ground-truth label regions; in particular, regions falsely activated at the first stage are suppressed at the second and third SPM stages. The visualizations also show a phenomenon called inverse activation, in which every region except the GT region is activated.
  • Case (b) shows a typical example from the last channel of the last-stage shape prior. The authors attribute this phenomenon to the global shape prior, which brings global context and rich texture information to the whole image, including background regions. In essence, the ROIs are localized by reverse attention, which highlights them with sharp outlines.
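Reverse attention is commonly written as 1 − sigmoid(activation): whatever a map activates strongly (here, the background) is suppressed, and the weakly activated region is highlighted. The sketch below applies that common formulation to a toy inverse-activated map; the paper is not claimed to use this exact formula, and the helper name reverse_attention and the 0.5 threshold are assumptions.

```python
import numpy as np

def reverse_attention(logits):
    """Common reverse-attention form: 1 - sigmoid(logits) (hypothetical helper)."""
    return 1.0 - 1.0 / (1.0 + np.exp(-logits))

act_logits = np.full((5, 5), 5.0)   # "inverse activation": everything is lit ...
act_logits[2, 1:4] = -5.0           # ... except a thin ROI strip
roi = reverse_attention(act_logits) > 0.5
```

Thresholding the reversed map recovers exactly the strip that the inverse-activated channel left dark.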

 

Decomposing the shape prior into its two components, the global shape prior from SUB and the local shape prior from CUB, Figure 7 shows that, thanks to the self-attention module, the global shape prior has a global receptive field and contains context and texture. However, SUB lacks the inductive bias needed to model local visual structure, so the global shape prior is responsible only for coarse localization of the GT regions. The local shape prior generated by CUB, in contrast, provides finer shape information by introducing convolution kernels with a local inductive bias.

key code

SUB and CUB

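# https://github.com/AlexYouXin/Explicit-Shape-Priors/blob/main/networks/ACDC/SPM.py
# Imports added so the listing is self-contained. `config` is assumed to be a
# ConfigDict-style object (e.g. ml_collections) providing: n_patches, n_classes,
# hidden_size, h, w, l, transformer.num_heads, transformer["attention_dropout_rate"],
# transformer["dropout_rate"].
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Dropout, LayerNorm, Linear, Softmax

# Activation lookup used by Mlp; defined here as an assumption (TransUNet-style table).
ACT2FN = {"gelu": F.gelu, "relu": F.relu}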

class self_update_block(nn.Module):
    def __init__(self, config):
        super(self_update_block, self).__init__()
        num_layers = 2
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.n_patches, eps=1e-6)
        for _ in range(num_layers):
            layer = Block(config)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, refined_shape_prior):
        for layer_block in self.layer:
            refined_shape_prior = layer_block(refined_shape_prior)

        encoded = self.encoder_norm(refined_shape_prior)
        
        return encoded

class cross_update_block(nn.Module):
    def __init__(self, n_class):
        super(cross_update_block, self).__init__()
        self.n_class = n_class
        self.softmax = Softmax(dim=-1)

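    def forward(self, refined_shape_prior, feature):
        # feature: (B, C, H, W, L); refined_shape_prior: (B, N, H, W, L) at the same resolution
        # C_map: similarity between the C feature channels and the N prior channels -> (B, C, N)
        class_feature = torch.matmul(feature.flatten(2), refined_shape_prior.flatten(2).transpose(-1, -2))
        # scale by sqrt(N), then normalize over the N classes
        class_feature = class_feature / math.sqrt(self.n_class)
        class_feature = self.softmax(class_feature)

        # weight the prior channels by C_map to obtain a C-channel map: (B, C, H, W, L)
        class_feature = torch.einsum("ijk,iklhw->ijlhw", class_feature, refined_shape_prior)
        # residual connection: enhanced skip feature F_e
        class_feature = feature + class_feature
        return class_feature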


        
class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        self.num_attention_heads = config.transformer.num_heads
        self.attention_head_size = int(config.n_patches / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.n_patches, config.n_patches)
        self.key = Linear(config.n_patches, config.n_patches)
        self.value = Linear(config.n_patches, config.n_patches)

        self.out = Linear(config.n_patches, config.n_patches)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)
        self.position_embeddings = nn.Parameter(torch.randn(1, self.num_attention_heads, config.n_classes, config.n_classes))

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # print(mixed_query_layer.shape)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        
        attention_scores = attention_scores + self.position_embeddings                        # RPE
        
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.n_patches, config.hidden_size)
        self.fc2 = Linear(config.hidden_size, config.n_patches)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x       
        
        

class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()

        self.attention_norm = LayerNorm(config.n_patches, eps=1e-6)
        self.ffn_norm = LayerNorm(config.n_patches, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config)

    def forward(self, x):
        h = x                                              
        x = self.attention_norm(x)                         
        x = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x




class SPM(nn.Module):
    def __init__(self, config, in_channel, scale):
        super(SPM, self).__init__()
        self.scale = scale
        self.SUB = self_update_block(config)
        self.CUB  = cross_update_block(config.n_classes)
        self.resblock1 = DecoderResBlock(in_channel, in_channel)
        self.resblock2 = DecoderResBlock(in_channel, in_channel)
        self.resblock3 = DecoderResBlock(in_channel, config.n_classes)

        self.h = config.h
        self.w = config.w
        self.l = config.l
        self.dim = in_channel
        
        
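    def forward(self, feature, refined_shape_prior):
        # refined_shape_prior: (B, N, h*w*l) at the base resolution (config.h, config.w, config.l)
        # feature: skip feature (B, C, h*s0, w*s1, l*s2), where self.scale = (s0, s1, s2)
        b, n_class, _ = refined_shape_prior.size()
        # SUB: global shape prior S_G (shape unchanged), kept for the residual below
        refined_shape_prior = self.SUB(refined_shape_prior)
        previous_class_center = refined_shape_prior
        # reshape to a volume and upsample to the skip-feature resolution
        refined_shape_prior = F.interpolate(refined_shape_prior.contiguous().view(b, n_class, self.h, self.w, self.l), scale_factor=self.scale, mode="trilinear")
        feature = self.resblock1(feature)
        feature = self.resblock2(feature)

        # CUB: enhanced skip feature F_e, (B, C, h*s0, w*s1, l*s2)
        class_feature = self.CUB(refined_shape_prior, feature)

        # local prior: project F_e to N channels, downsample back to (h, w, l), flatten,
        # then add S_G; result is (B, N, h*w*l)
        refined_shape_prior = F.interpolate(self.resblock3(class_feature), scale_factor=(1.0 / self.scale[0], 1.0 / self.scale[1], 1.0 / self.scale[2]), mode="trilinear").flatten(2) + previous_class_center

        return class_feature, refined_shape_prior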



class Conv3dbn(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )

        bn = nn.BatchNorm3d(out_channels)

        super(Conv3dbn, self).__init__(conv, bn)

class Conv3dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm3d(out_channels)

        super(Conv3dReLU, self).__init__(conv, bn, relu)

class DecoderResBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv3dReLU(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv3dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

        self.conv3 = Conv3dbn(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            use_batchnorm=use_batchnorm,
        )

        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, skip=None):

        feature_in = self.conv3(x)

        x = self.conv1(x)
        x = self.conv2(x)

        x = x + feature_in
        x = self.relu(x)
        # x = self.se_block(x)

        return x
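To check the tensor shapes in SPM.forward without PyTorch, here is a NumPy shape trace under stated assumptions: SUB is replaced by the identity, trilinear interpolation by nearest-neighbour repetition (upsampling) and strided subsampling (downsampling), and resblock3 by a hypothetical random 1×1×1 projection W_proj. Only the shapes are meaningful, not the values.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def spm_shape_trace(b=2, n_class=4, C=16, base=(4, 4, 4), scale=(2, 2, 2), seed=0):
    rng = np.random.default_rng(seed)
    h, w, l = base
    prior = rng.standard_normal((b, n_class, h * w * l))      # refined_shape_prior: (B, N, h*w*l)
    feature = rng.standard_normal((b, C, h * scale[0], w * scale[1], l * scale[2]))
    prior_g = prior                                            # SUB stand-in: shape unchanged
    vol = prior_g.reshape(b, n_class, h, w, l)
    for axis, s in zip((2, 3, 4), scale):                      # nearest-neighbour upsampling
        vol = np.repeat(vol, s, axis=axis)                     # (stand-in for trilinear)
    cmap = softmax(feature.reshape(b, C, -1)
                   @ vol.reshape(b, n_class, -1).transpose(0, 2, 1) / np.sqrt(n_class))  # (B, C, N)
    class_feature = feature + np.einsum("ijk,iklhw->ijlhw", cmap, vol)  # CUB output (B, C, ...)
    W_proj = rng.standard_normal((n_class, C))                 # resblock3 stand-in: C -> N
    local = np.einsum("nc,icdhw->indhw", W_proj, class_feature)
    local = local[:, :, ::scale[0], ::scale[1], ::scale[2]]    # strided downsampling stand-in
    new_prior = local.reshape(b, n_class, -1) + prior_g        # .flatten(2) + previous_class_center
    return class_feature.shape, new_prior.shape

print(spm_shape_trace())  # ((2, 16, 8, 8, 8), (2, 4, 64))
```

The trace shows the round trip the module relies on: the prior is upsampled to the skip-feature resolution for CUB and then returned at its base resolution, so the same prior shape can be passed from one SPM stage to the next.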

Origin blog.csdn.net/m0_61899108/article/details/131160069