图像分割U-Net网络【眼睛闭合度】

版权声明:我是南七小僧,微信: to_my_love ,2020年硕士毕业,寻找 自然语言处理,图像处理,软件开发等相关工作,欢迎交流思想碰撞。 https://blog.csdn.net/qq_25439417/article/details/86553597

昨天去听了医学影像的年会,看了一些别人做的癌症检测。

对之前做的眼部闭合度算法,有了新的想法,之前我通过MRCNN做目标检测+掩膜标注,效果较好,但是速度有点慢,我就在考虑能不能用一个小网络做到MRCNN,于是我把MRCNN里的ResNet101修改了一个更小的网络【自己搭建的】,效果挺好的,但是依然很慢,我怀疑可能和MRCNN本身的流程复杂有关系,先要经过RPN提出区域,再对区域进行分类和BBOX、MASK检测,导致比较慢。

现在决定往图像语义分割方向去尝试,对眼部图像做到自动分割,提取眼睛区域。

最近在研究全卷积神经网络在图像分割方面的应用,因为自己是做医学图像处理方面的工作,所以就把一个基于FCN(全卷积神经网络)的神经网络用 keras 实现了,并且用了一个医学图像的数据集进行了图像分割。

全卷积神经网络


大名鼎鼎的FCN就不多做介绍了,这里有一篇很好的博文 http://www.cnblogs.com/gujianhan/p/6030639.html。
不过还是建议把论文读一下,这样才能加深理解。

医学图像分割框架


医学图像分割主要有两种框架,一个是基于CNN的,另一个就是基于FCN的。

基于CNN 的框架


这个想法也很简单,就是对图像的每一个像素点进行分类,在每一个像素点上取一个patch,当做一幅图像,输入神经网络进行训练,举个例子:

这是一篇发表在NIPS上的论文Ciresan D, Giusti A, Gambardella L M, et al. Deep neural networks segment neuronal membranes in electron microscopy images[C]//Advances in neural information processing systems. 2012: 2843-2851.

这是一个二分类问题,把图像中所有label为0的点作为负样本,所有label为1的点作为正样本。

这种网络显然有两个缺点:

冗余太大,由于每个像素点都需要取一个patch,那么相邻的两个像素点的patch相似度是非常高的,这就导致了非常多的冗余,导致网络训练很慢。
感受野和定位精度不可兼得,当感受野选取比较大的时候,后面对应的pooling层的降维倍数就会增大,这样就会导致定位精度降低,但是如果感受野比较小,那么分类精度就会降低。


基于FCN框架


在医学图像处理领域,有一个应用很广泛的网络结构----U-net ,网络结构如下:

可以看出来,就是一个全卷积神经网络,输入和输出都是图像,没有全连接层。较浅的高分辨率层用来解决像素定位的问题,较深的层用来解决像素分类的问题。

conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(inputs)
conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)


conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(pool1)
conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)


conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(pool2)
conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)


conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(pool3)
conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(conv4)
# drop4 = Dropout(0.5)(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(pool4)
conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(conv5)
#drop5 = Dropout(0.5)(conv5)

up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(UpSampling2D(size = (2,2))(conv5))
merge6 = merge([conv4,up6], mode = 'concat', concat_axis = 3)
conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(merge6)
conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(conv6)

up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(UpSampling2D(size = (2,2))(conv6))
merge7 = merge([conv3,up7], mode = 'concat', concat_axis = 3)
conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(merge7)
conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(conv7)

up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
merge8 = merge([conv2,up8], mode = 'concat', concat_axis = 3)
conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
merge9 = merge([conv1,up9], mode = 'concat', concat_axis = 3)
conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

model = Model(input = inputs, output = conv10)

model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

猜你喜欢

转载自blog.csdn.net/qq_25439417/article/details/86553597