论文笔记-UNeXt: MLP-based Rapid Medical ImageSegmentation Network

论文:https://arxiv.org/abs/2203.04967

代码:https://github.com/jeya-maria-jose/UNeXt-pytorch

1.摘要

Unet网络及其扩展近年来是领先的医学图像分割方法。但是这些网络参数多,计算复杂,速度慢,不能应用于实时的快速图像分割。

本文总结:

(1)提出一种基于MLP(多层感知机)的图像分割网络Unext,即卷积+MLP的结构。

(2)提出一个标记MLP块tokenized MLP block),标记和投影卷积特征。

(3)为提高性能,提出输入mlp时shift输入的channel,为了学习local dependencies。

(3)网络中包括各级编码器和解码器之间的跳跃连接

(4)与目前最先进的医学图像分割架构相比,UNeXt的参数数量减少了72x,计算复杂度降低了68x,推理速度提高了10x,同时也获得了更好的分割性能,又快又好

2 引言

In summary, this paper makes the following contributions:
1) We propose UNeXt, the first convolutional MLP-based network for image segmentation.
2) We propose a novel tokenized MLP block with axial shifts to effiffifficiently learn a
good representation at the latent space.
3) We successfully improve the performance on medical image segmentation tasks while having less parameters, high inference speed, and low computational complexity.

3 UNeXt

网络设计

encoder-decoder 结构,有两个阶段:

 (1)卷积阶段

  (2) tokenized MLP阶段。

编码器,其中前3个块是卷积,下2个是tokenized MLP块。解码器有2个tokenized MLP块,后面跟着3个卷积块。每个编码器块减少特征分辨率2倍,每个解码器块增加特征分辨率2。跳跃连接也被应用在了编码器和解码器之间。

作者减少了通道数,C1 = 32, C2 = 64, C3 = 128, C4 = 160, and C5 = 256(Unet 通道数:64,128,256,512,1024)确实减少了很多计算量。

 3.1卷积阶段

每个卷积层都是一个卷积层(Unet是两个卷积层),BN,Relu组合。3*3卷积核,stride=1,padding=1。编码器的conv块使用带有池窗口2×2的max-pooling层,而解码器的conv块使用双线性插值层对特征图进行上采样。我们使用双线性插值而不是转置卷积,因为转置卷积基本上是可学习的上采样,会导致产生更多可学习的参数。

3.2 Shifted MLP

在shifted MLP中,在tokenize之前,首先移动conv features通道的轴。这有助于MLP只关注conv特征的某些位置,从而诱导块的位置。与Swin transformer类似,在swin中引入基于窗口的注意力,以向完全全局的模型添加更多的局部性。由于Tokenized MLP块有2个mlp,我们在一个块中跨越宽度移动特征,在另一个块中跨越高度移动特征,就像轴向注意力中一样。我们对这些特征做了h个划分,并根据指定的轴通过j个位置移动它们。这有助于我们创建随机窗口,引入沿轴线的局部性。

如下图:灰色是特征块的位置,白色是移动之后的padding。

 3.3 Tokenized MLP Stage

 在Tokenized MLP块中,

  • 首先shift features并投影到token中:首先用3x3conv把特征投影到E维,其中E是embadding维度(token的数量),它是一个超参数。然后我们将这些token传递给一个shifted MLP(跨越width)。
  • 接着,特征通过 DW-Conv传递,之后使用了GELU激活层;
  • 为什么使用DWConv?

    1) 它有助于编码MLP特征的位置信息。并且实际性能优于标准的位置编码技术。当测试或者训练分辨率不相同时,像ViT中的位置编码技术需要插值,通常会导致性能下降。

    2) DWConv使用更少的参数,因此提高了效率。

  • 接着,通过另一个shifted MLP(跨越height)传递特征,该mlp把特征的尺寸从H转换为了O。我们在这里使用一个残差连接,并将原始标记添加为残差。
  • 然后我们利用layer norm(LN),并将输出特征传递到下一个块。LN比BN更可取,因为它更有意义的是沿着token进行规范化,而不是在Tokenized MLP块的整个批处理中进行规范化。

Tokenized block的计算

4 实验

4.1 对比实验

 4.2 消融实验

5 总结

在自己最近做的划痕检测的项目上试了一下,亲测又快又好,512*512的图在1080Ti上跑大约6ms,下边是在自己数据集上的效果图

猜你喜欢

转载自blog.csdn.net/Bolly_He/article/details/124021535