【arXiv2303】Learning with Explicit Shape Priors for Medical Image Segmentation

Learning with Explicit Shape Priors for Medical Image Segmentation, aXiv2303

解读:SPM: 一种即插即用的形状先验模块,可轻松嵌入任意编解码架构,助力涨点并显著改善分割效果! (qq.com)

论文:https://arxiv.org/abs/2303.17967

代码:https://github.com/AlexYouXin/Explicit-Shape-Priors

摘要

基于UNet的网络在医学图像分割领域逐步占据主导地位。然而,卷积神经网络(CNNs)面临两个限制:

  • CNN感受野有限,无法对器官或组织的长期依赖或全局关系进行建模。
  • 分割掩码很大程度上依赖于最终分割头的训练。

现有的方法不能很好地同时解决这两个限制。因此,本文提出了一种新的形状先验模块(SPM),它可以引入形状先验来提高基于UNet的模型的分割性能。显式形状先验由全局形状先验和局部形状先验组成。

  • 具有粗略形状表示的全局形状先验为网络提供了对全局上下文建模的能力。
  • 局部形状先验具有更精细的形状信息,可以作为提高分割性能的额外指导,从而缓解对分割头中可学习原型的严重依赖。

为了评估SPM的有效性,在三个具有挑战性的公共数据集上进行了实验。SPM性能优异。此外,SPM在经典的细胞神经网络和最近的基于Transformer的主干上表现出了突出的泛化能力,可以作为不同数据集分割任务的即插即用结构。

引言

如何解决CNN感受野有限的问题呢?本文开始探索形状先验(shape priors)对分割性能的影响。

在医学图像中,不同的器官或病灶通常具有特定的形状和结构,这些形状和结构信息对于分割模型来说非常关键,因此先前的许多工作尝试利用形状先验来设计分割模型,以获得具有解剖形状信息的更好掩模(mask)。就是引入形状先验可以帮助分割模型在分割过程中更好地考虑和利用目标物体的形状信息,从而提高分割性能。

为此,本文集中探讨了三种带有形状先验的分割模型:

  • 基于图谱的模型(atlas-based models)
  • 基于统计的模型(statistical-based models)
  • 基于U-Net的模型(UNet-based models)

论文认为,前两种方法的泛化能力较差,而 UNet-based 模型由于相比于前两者泛化性能要好,但由于它是倾向于使用隐式形状先验,这在不同形状的器官上缺乏良好的可解释性和泛化能力。综上所述,本文提出了一种新的形状先验模块(Shape Prior Module, SPM),它可以显示地引入形状先验,以促进 UNet-based 模型的分割性能。(具体分析见论文)

论文在三个具有挑战性的公共数据集上进行实验,验证了SPM的有效性。SPM也表现出很强的泛化性,可作为不同数据集分割任务的即插即用结构。

来源:

SPM: 一种即插即用的形状先验模块,可轻松嵌入任意编解码架构,助力涨点并显著改善分割效果!

隐式形状先验通常是通过在模型中加入先验信息,例如特定的损失函数或正则化项来实现的。这些隐式的形状先验通常难以解释,因为它们是通过一些特殊的方式集成到模型中的,而不是直接考虑目标物体的形状信息。例如,在基于 UNet 的模型中,可以通过使用 Dice 损失函数来强制模型更加注重目标物体的轮廓信息,从而隐式地考虑了形状先验信息。

相反,显式形状先验则直接将形状先验信息作为输入提供给模型。例如,在本文中,作者提出了一个新的形状先验模块,它明确地将形状先验信息作为输入,并利用这些信息来引导模型更好地分割目标物体。这种显式的形状先验可以更好地解释和调整,因为它们直接考虑了目标物体的形状和结构信息。 

方法

显式形状模型的统一框架

将可学习的重复形状先验S引入U形神经网络。具体地,S被用作与图像组合的网络的输入。网络的输出是由S生成的预测掩码和注意力图。然后注意力图的通道可以提供真实标签区域的丰富形状信息。显式形状先验模型可以描述如下:

其中,F表示推理期间的前向传播,S表示构造图像空间I和标签空间L之间的映射的连续形状先验。这里,S在训练过程中随着图像GT对的变化而更新。一旦训练完成,可学习的形状先验就被固定,这可以随着推理过程中输入补丁的变化而动态地生成精细的形状先验。精细形状先验作为注意力图,可以定位感兴趣的区域,并抑制背景区域。此外,一小部分不准确的基本事实不会显著影响S的学习,显示了该范式的稳健性。

SPM( Shape Prior Module)

图1所示,本文模型是一个分层的U形网络,它由类ResNet编码器、基于Resblock的解码器和形状先验模块(SPM)组成。SPM通过引入可学习形状先验,为每个类别施加解剖形状约束来增强网络的表示能力。SPM是一个即插即用模块,可以灵活地插入其他网络结构,以提高分割性能。

图2所示,SPM的输入包括原始跳跃特征Fo和原始形状先验So,经过“特征提纯”后会生成对应的增强跳跃特征Fe和增强形状先验Se 。最终,通过这些增强后的特征和先验,模型会生成更加精准的分割掩膜。与DETR不同,SPM会与多尺度特征进行交互,而不仅仅是来自编码器最深层的特征。因此,在跳跃连接之前的分层编码特征在经过SPM处理后将获得更多的形状信息。增强形状先验由两个部分组成:

  • 全局形状先验,由自更新块(self-update)生成。
  • 局部形状先验,由交叉更新块(cross-update block)生成。

Self-update block (SUB):建模长期依赖关系

旨在引入能够定位目标区域的显式形状先验的基础上,形状先验的大小So是N×空间维度。N表示类的数量,空间维度与补丁大小有关。为了缓解感受野有限的缺点,本工作考虑了形状先验内的长程依赖性。具体而言,提出了自更新块(SUB)来对类之间的关系进行建模,并生成具有N个通道之间相互作用的全局形状先验。受自注意机制的启发,构建了N类之间的自注意Smap的亲和图,以描述形状先验的每个通道之间的相似性和依赖性关系。再采用Smap加权Vs,随后经过多层感知机MLP和层归一化处理,得到全局形状先验S_G

Cross-update block (CUB):对局部形状先验进行建模。

引入显式形状先验給SUB带来了全局上下文信息,但不具有精确的形状和轮廓信息。因为SUB缺乏归纳偏置,无法建模局部视觉结构和定位各种不同尺度的对象。

为了解决这个限制,论文提出交叉更新块CUB。受到卷积核固有的局部性和尺度不变性的归纳偏置的启发,基于卷积的 CUB 为 SPM 注入归纳偏置,以获得更精确的局部形状信息。此外,基于编码器中卷积特征具有定位区分性区域的显著潜力的事实,论文在原始跳跃特征Fo和形状先验So之间进行交互。

具体来说,

  • 先计算特征Fo和形状先验So之间的相似度图Cmap,用于评估C通道特征图和N通道形状先验之间的关系。
  • 再将Cmap作用于变换后的全局形状先验S_G来细化Fo,得到增强的跳接特征Fe,其具有更准确的形状先验和丰富的全局纹理。
  • 局部形状先验S_L由下采样的Fe生成,其具有对局部视觉结构(边缘或拐角)建模的特性。

综上所述,原始形状先验通过引入全局和局部特征进行增强。

  • 全局形状先验可以对类间关系进行建模,类间关系具有基于自注意块的具有足够全局纹理信息的粗糙形状先验。
  • 局部形状先验通过引入基于卷积的归纳偏置来显示更精细的形状信息。
  • 此外,SPM通过与全局形状先验的交互,进一步增强原始跳接特征,这将促进生成具有判别性形状表示和全局上下文的特征,从而获得更精确的预测掩码。

实验

性能比较 

 

 

 可视化分析

上图展示了跳跃特征对明确形状先验的影响。其中:

  • 案例(a)展示了从不同阶段生成的明确形状先验。具体来说,形状先验由 N 个通道注意力图组成,其中 N 是分割类别的数量,每行表示来自每个阶段的形状先验。可以发现,随着自上而下的过程,形状先验对于真实标签区域呈现出更准确的激活图。特别是,在第一阶段中错误激活的区域将在 SPM 的第二和第三阶段中被抑制。在可视化结果中,存在一种称为反向激活的现象,这意味着除了 GT 区域之外的所有区域都被激活。
  • 案例(b)中则展示了形状先验的最后一个阶段和最后一个通道的典型例子。作者声称,这种现象是由全局形状先验造成的,它为整个区域带来了全局上下文和丰富的纹理信息,甚至包括背景区域。实质上,通过反向注意力简单地定位 ROI,其中 ROI 用清晰的轮廓突出显示。

 

将形状先验分解为来自 SUB 和 CUB 的两个组成部分,即全局形状先验和局部形状先验:从图7可以观察到,得益于自注意力模块,全局形状先验具有全局的感受野,包含上下文和纹理。然而,SUB 的结构缺乏归纳偏差来模拟局部视觉结构。全局形状先验负责对 GT 区域进行粗定位。而由 CUB 生成的局部形状先验可以通过引入卷积核提供更精细的形状信息,这些卷积核具有局部归纳偏差。

 

 

 

 

关键代码

SUB和CUB

# https://github.com/AlexYouXin/Explicit-Shape-Priors/blob/main/networks/ACDC/SPM.py

class self_update_block(nn.Module):
    def __init__(self, config):
        super(self_update_block, self).__init__()
        num_layers = 2
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.n_patches, eps=1e-6)
        for _ in range(num_layers):
            layer = Block(config)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, refined_shape_prior):
        for layer_block in self.layer:
            refined_shape_prior = layer_block(refined_shape_prior)

        encoded = self.encoder_norm(refined_shape_prior)
        
        return encoded

class cross_update_block(nn.Module):
    def __init__(self, n_class):
        super(cross_update_block, self).__init__()
        self.n_class = n_class
        self.softmax = Softmax(dim=-1)

    def forward(self, refined_shape_prior, feature):
        class_feature = torch.matmul(feature.flatten(2), refined_shape_prior.flatten(2).transpose(-1, -2))
        # scale
        class_feature = class_feature / math.sqrt(self.n_class)
        class_feature = self.softmax(class_feature)

        class_feature = torch.einsum("ijk, iklhw->ijlhw", class_feature, refined_shape_prior)
        class_feature = feature + class_feature
        return class_feature


        
class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        self.num_attention_heads = config.transformer.num_heads
        self.attention_head_size = int(config.n_patches / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.n_patches, config.n_patches)
        self.key = Linear(config.n_patches, config.n_patches)
        self.value = Linear(config.n_patches, config.n_patches)

        self.out = Linear(config.n_patches, config.n_patches)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)
        self.position_embeddings = nn.Parameter(torch.randn(1, self.num_attention_heads, config.n_classes, config.n_classes))

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # print(mixed_query_layer.shape)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        
        attention_scores = attention_scores + self.position_embeddings                        # RPE
        
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.n_patches, config.hidden_size)
        self.fc2 = Linear(config.hidden_size, config.n_patches)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x       
        
        

class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()

        self.attention_norm = LayerNorm(config.n_patches, eps=1e-6)
        self.ffn_norm = LayerNorm(config.n_patches, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config)

    def forward(self, x):
        h = x                                              
        x = self.attention_norm(x)                         
        x = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x




class SPM(nn.Module):
    def __init__(self, config, in_channel, scale):
        super(SPM, self).__init__()
        self.scale = scale
        self.SUB = self_update_block(config)
        self.CUB  = cross_update_block(config.n_classes)
        self.resblock1 = DecoderResBlock(in_channel, in_channel)
        self.resblock2 = DecoderResBlock(in_channel, in_channel)
        self.resblock3 = DecoderResBlock(in_channel, config.n_classes)

        self.h = config.h
        self.w = config.w
        self.l = config.l
        self.dim = in_channel
        
        
    def forward(self, feature, refined_shape_prior):
        # print(refined_shape_prior.size())
        b, n_class, _ = refined_shape_prior.size()
        B = feature.size()[0]
        refined_shape_prior = self.SUB(refined_shape_prior)
        previous_class_center = refined_shape_prior
        refined_shape_prior = F.interpolate(refined_shape_prior.contiguous().view(b, n_class, self.h, self.w, self.l), scale_factor=self.scale, mode="trilinear")
        feature = self.resblock1(feature)
        feature = self.resblock2(feature)
        
        class_feature = self.CUB(refined_shape_prior, feature)
        
        # b * N * H/i * W/i * L/i
        refined_shape_prior = F.interpolate(self.resblock3(class_feature), scale_factor=(1.0 / self.scale[0], 1.0 / self.scale[1], 1.0 / self.scale[2]), mode="trilinear").flatten(2) + previous_class_center

        return class_feature, refined_shape_prior



class Conv3dbn(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )

        bn = nn.BatchNorm3d(out_channels)

        super(Conv3dbn, self).__init__(conv, bn)

class Conv3dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm3d(out_channels)

        super(Conv3dReLU, self).__init__(conv, bn, relu)

class DecoderResBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv3dReLU(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv3dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

        self.conv3 = Conv3dbn(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            use_batchnorm=use_batchnorm,
        )

        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, skip=None):

        feature_in = self.conv3(x)

        x = self.conv1(x)
        x = self.conv2(x)

        x = x + feature_in
        x = self.relu(x)
        # x = self.se_block(x)

        return x

猜你喜欢

转载自blog.csdn.net/m0_61899108/article/details/131160069