基于pytorch编写Unet进行细胞结构分割

论文地址:U-Net: Convolutional Networks for Biomedical Image Segmentation
代码地址:https://github.com/laisimiao/Unet
Unet想必大家都听过它的大名,在医学图像分割方向非常著名的一篇论文,我也是怀着好奇心,以其为我想对分割任务有个初步认识为动机,实现了一下论文中说的那个挑战赛(训练数据少,训练时间短):ISBI Challenge: Segmentation of neuronal structures in EM stacks,关于这个数据集:论文中是这样说的:

The training data is a set of 30 images (512x512 pixels) from serial section transmission electron microscopy of the Drosophila first instar larva ventral nerve cord (VNC).
大意就是:训练数据有30张,分辨率为512x512,这些图片是果蝇的电镜图。

看下面这个动图,就可以了解要做一个什么任务:给一张细胞结构图,我们要把他互相分割开来。
在这里插入图片描述
其实我理解的分割任务就是对每一个像素点分类,所以它的标签也是和输入图像的分辨率也是一样的,对于这里的果蝇细胞图,就是分成两类就好了:细胞膜处为黑色,细胞内为白色,所以我们预测结果的feature map形状可以为2x512x512(对于VOC和COCO多分类的也是一样的,可以预测CxHxW形状的feature map,C为目标类别数,H和W分别为输入图像的高和宽),其中第一个通道为细胞膜处的mask,第二个通道为细胞内的mask,后来发现对于这种二分类的问题其实只输出一个通道也行,即1x512x512,因为最后除了黑就是白,而且原图很接近灰度图,转化为单通道后没有多大区别,这样输入输出都是单通道的,计算pixel-wise的logistic loss也很方便代码实现。
放一张Unet网络结构图,用pytorch还是比较容易实现的,只不过我这里实现的contracting path和expansive path的分辨率对应是一样的,而不是像论文图中的从contracting path crop一块与expansive path concat的:
Unet网络结构图
写一点理解和没太清楚的地方:

  • 这里网络训练出来以后就像是一个纹理检测器,输入一张测试图片后,遇到细胞边界处它的激活值就大,经过sigmoid函数之后就趋向于1,也就是黑色;其他的地方激活值就低,经过sigmoid函数之后就趋向于0,也就是白色。
  • 对于.tif格式的图片,读入好像只能通过Image和TIFF(网上看来的),但是我一读入这个细胞图片后,Image的mode就是L,也就是灰度图,所以我就不知道原来是3通道的还是1通道的,但是这个对于后面不影响

具体实现可以去看代码,提供了keras2.2.4和pytorch1.0两个版本,实际训练的时候发现keras版的当训练的batch size=2时,GPU显存占用就达到了10GB左右,但是pytorch里面只用了4GB左右,因为pytorch版本的decode部分是通过上采样而不是转置卷积实现的,这里放两张效果图:
这是keras版的结果图
这是pytorch版的结果图
附一个评价指标dice coefficient的链接:医学图像分割之 Dice Loss

猜你喜欢

转载自blog.csdn.net/laizi_laizi/article/details/103863756