UNet代码详解

UNet代码详解

第一步，导入所需的库：

import torch 
import torch.nn as nn
import torch.nn.functional as F

创建一个卷积Block类

class UNetConvBlock(nn.Module):
    """Convolution unit used at every UNet stage: (Conv3x3 -> ReLU [-> BN]) x 2.

    Args:
        in_chans: Number of channels of the input feature map.
        out_chans: Number of channels produced by both convolutions.
        padding: Padding flag; each Conv2d gets ``padding=int(padding)``.
            ``True`` -> 1 pixel, which keeps height/width unchanged;
            ``False`` -> 0 (unpadded), so each conv shrinks height and
            width by 2 (by 4 for the whole block).
        batch_norm: If True, a BatchNorm2d is appended after each ReLU.
    """

    def __init__(self, in_chans, out_chans, padding, batch_norm):
        super().__init__()
        pad = int(padding)  # bool flag -> 0 or 1 pixel of zero padding
        layers = []
        # The two stages are identical except for the first conv's input width.
        for stage_in in (in_chans, out_chans):
            layers.append(nn.Conv2d(stage_in, out_chans, kernel_size=3, padding=pad))
            layers.append(nn.ReLU())
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_chans))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the two conv stages and return the result."""
        return self.block(x)

这里实现的是 UNet 每个 stage 中的卷积 block：两个 3×3 卷积，每个卷积后接 ReLU（可选地再接 BatchNorm）。当 padding=False 时卷积不补零，每个卷积会使特征图的高和宽各减少 2。如下图：

 创建上采样的Block

class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the cropped skip feature, convolve.

    Args:
        in_chans: Channels of the incoming lower-resolution feature map. This
            is also the input width of the conv block, because ``forward``
            concatenates the upsampled map (``out_chans`` channels) with the
            skip feature. That assumes the skip feature carries
            ``in_chans - out_chans`` channels (``out_chans`` when built by
            ``UNet``, where ``in_chans == 2 * out_chans``).
        out_chans: Channels produced by the upsampling step and by the block.
        up_mode: ``'upconv'`` for a 2x2, stride-2 transposed convolution, or
            ``'upsample'`` for 2x bilinear upsampling followed by a 1x1 conv.
        padding: Padding flag forwarded to ``UNetConvBlock``.
        batch_norm: BatchNorm flag forwarded to ``UNetConvBlock``.

    Raises:
        ValueError: If ``up_mode`` is neither ``'upconv'`` nor ``'upsample'``.
    """

    def __init__(self, in_chans, out_chans, up_mode, padding, batch_norm):
        super().__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                # align_corners=False is PyTorch's effective default; stated explicitly.
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_chans, out_chans, kernel_size=1),
            )
        else:
            # Fail early instead of leaving self.up undefined until forward().
            raise ValueError(
                f"up_mode must be 'upconv' or 'upsample', got {up_mode!r}"
            )
        self.conv_block = UNetConvBlock(in_chans, out_chans, padding, batch_norm)

上采样有两种方式：转置卷积（up_mode='upconv'）和双线性插值后接 1×1 卷积（up_mode='upsample'），可以通过 up_mode 参数选择使用哪一种。

    def centre_crop(self, layer, target_size):
        """Cut the central window of size ``target_size`` out of ``layer``.

        Args:
            layer: 4-D tensor indexed as (batch, channels, height, width).
            target_size: Sequence whose first two entries are the desired
                (height, width), e.g. ``tensor.shape[2:]``.

        Returns:
            A view of ``layer`` restricted to the centred window. When the
            size difference is odd, the extra row/column is dropped from the
            bottom/right because the offsets use floor division.
        """
        _, _, height, width = layer.shape
        top = (height - target_size[0]) // 2
        left = (width - target_size[1]) // 2
        rows = slice(top, top + target_size[0])
        cols = slice(left, left + target_size[1])
        return layer[:, :, rows, cols]

这里实现的是中心裁剪（centre crop）操作。注意 skip connection 两侧的特征图大小是不一样的：在默认不补零（padding=False）的情况下，Encoder 侧的特征图比上采样后的 Decoder 特征图更大。根据论文描述，我们需要将 Encoder 部分的特征图从中心裁剪到 Decoder 部分的大小，再沿通道维拼接，如图所示：

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.centre_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)
        return out

 创建UNet

class UNet(nn.Module):
    """Encoder-decoder UNet with skip connections for per-pixel prediction.

    Args:
        in_channels: Channels of the input image.
        n_classes: Channels of the output map produced by the final 1x1 conv.
        depth: Number of encoder stages; the decoder has ``depth - 1`` stages.
        wf: Width factor; encoder stage ``i`` has ``2 ** (wf + i)`` channels.
        padding: Padding flag for every ``UNetConvBlock``. With ``False``
            the convolutions are unpadded, so skip features are centre-cropped
            and the output is spatially smaller than the input.
        batch_norm: Add BatchNorm2d after each ReLU in the conv blocks.
        up_mode: ``'upconv'`` (transposed conv) or ``'upsample'``
            (bilinear upsampling + 1x1 conv).
    """

    def __init__(
        self,
        in_channels=1,
        n_classes=2,
        depth=5,
        wf=6,
        padding=False,
        batch_norm=False,
        up_mode='upconv',
    ):
        super().__init__()
        # Kept as an assert (AssertionError) for compatibility with existing callers.
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth

        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            stage_channels = 2 ** (wf + i)  # channel count doubles at each stage
            self.down_path.append(
                UNetConvBlock(prev_channels, stage_channels, padding, batch_norm)
            )
            prev_channels = stage_channels

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            stage_channels = 2 ** (wf + i)  # mirror the encoder, deepest first
            self.up_path.append(
                UNetUpBlock(prev_channels, stage_channels, up_mode, padding, batch_norm)
            )
            prev_channels = stage_channels

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the 1x1 head."""
        skips = []
        last_index = len(self.down_path) - 1
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != last_index:
                # Every stage but the bottleneck feeds a skip connection
                # and is downsampled 2x.
                skips.append(x)
                x = F.max_pool2d(x, 2)
        # Decoder stages run deepest-first, so they consume skips in reverse order.
        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)
        return self.last(x)

        
发布了85 篇原创文章 · 获赞 17 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/lun55423/article/details/104976561
今日推荐