unet 网络结构

unet 是15年提出的用于解决医学图像分割问题。unet有两部分组成。左边部分可以看出是特征提取网络,用于提取图像的抽象特征。右边可以看作是特征融合操作。与传统的FCN相比,unet使用是使用特征拼接实现特征的融合。unet 通过特征融合操作,实现了浅层的低分辨率(越底层的信息含有越多的细节信息)和深层的高分辨率信息(深层信息含有更多的抽象特征)的融合,充分了利用了图像的上下文信息,使用对称的U型结构使得特征融合的更加彻底。

上图是unet 的网络结构图。其中蓝色方框代表的是特征图。可以看到,左边部分首先进行两层卷积然后进行下采样来提取特征。右边,通过上采样操作后与相应的左边的特征图进行拼接操作。 

from torch import nn
import torch
from torch.nn import functional as F


class Conv_Block(nn.Module):  # 卷积
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            ####填充的方式,填充的大小,padding_mode 设置填充的方式   ###这里卷积图片的大小没有发生改变
            nn.Conv2d(in_channel, out_channel, 3, 1, 1, padding_mode='reflect',
                      bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect',
                      bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layer(x)


class DownSample(nn.Module):  # 下采样  使用卷积步长为2进行下采样
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, 3, 2, 1, padding_mode='reflect',
                      bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()

        )   ###下采样 通道不变,图像大小减半

    def forward(self, x):
        return self.layer(x)


class UpSample(nn.Module):  # 上采样(最邻近插值法)
    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel // 2, 1, 1)   ###上采样 这里首先运用1*1卷积进行降维

    def forward(self, x, feature_map):
        up = F.interpolate(x, scale_factor=2, mode='nearest') ###上采样插值
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)  ###s上采样 首先将x上采样,通道减半
    ###所以上采样,图像大小增加,通道减半,下采样,图像大小减半,通道增加


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)  ###通道增加
        self.d2 = DownSample(128)  ##下采样  
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)
        
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)   
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)

        self.out = nn.Conv2d(64, 3, 3, 1, 1)
        self.Th = nn.Sigmoid()

    def forward(self, x):
        R1 = self.c1(x)   ###通道数64
        R2 = self.c2(self.d1(R1))  ###  下采样 图片大小减半,通道数增加 128
        R3 = self.c3(self.d2(R2))  ### 下采样      256
        R4 = self.c4(self.d3(R3))   ###    512unet 经过四次上采样, 四次下采样,得到五个不同分辨率的图像
        R5 = self.c5(self.d4(R4))   ### 1024

        O1 = self.c6(self.u1(R5, R4))  ###首先将R5上采样然后与R4进行特征融合  512
        O2 = self.c7(self.u2(O1, R3))  ##          256
        O3 = self.c8(self.u3(O2, R2))   
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))


if __name__ == "__main__":
    x = torch.randn(2, 3, 256, 256)
    net = UNet()
    print(net(x).shape)

猜你喜欢

转载自blog.csdn.net/qq_40107571/article/details/127150532