Python:U-Net 网络结构的 PyTorch 简单实现 + torchsummary 可视化(可以直接运行)

Unet的网络结构:

根据该结构,用 PyTorch 实现 U-Net:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import torch.utils.data as Data 

import random

# Seed every random number generator this script touches so that repeated
# runs construct the model with the same initial weights.
seed = 2019
random.seed(seed)  # Python random module.
np.random.seed(seed)  # Numpy module.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.

# Turn off cuDNN benchmarking and request deterministic cuDNN behaviour.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Convolution factory
def default_conv(in_channels,out_channels,kernel_size,bias=True):
    """Return an unpadded ``nn.Conv2d`` layer.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        bias: Whether the layer gets a bias term.

    Returns:
        A new ``nn.Conv2d`` created with ``padding=0``.
    """
    conv_options = dict(kernel_size=kernel_size, padding=0, bias=bias)
    return nn.Conv2d(in_channels, out_channels, **conv_options)
## ReLU factory
def default_relu():
    """Return a fresh in-place ``nn.ReLU`` module."""
    activation = nn.ReLU(inplace=True)
    return activation

class Up_Sample(nn.Module):
    """Decoder step: upsample, halve the channels, then append a cropped skip tensor.

    The incoming tensor is upsampled by 2 (nearest neighbour), passed through
    a 1x1 convolution that halves its channel count, and activated. The skip
    tensor from the encoder is centre-cropped to the same spatial size and
    concatenated along the channel dimension.

    Args:
        in_channels: Channel count of the tensor coming from the level below.
        conv: Factory called as ``conv(in_channels, out_channels, kernel_size)``.
        relu: Zero-argument factory returning the activation module.
    """

    def __init__(self,in_channels,conv=default_conv,relu=default_relu):
        super(Up_Sample,self).__init__()

        up1 = nn.Upsample(scale_factor=2,mode='nearest')
        up2 = conv(in_channels,in_channels//2,1)
        self.module_up = nn.Sequential(up1,up2,relu())

    def forward(self,input_down,input_left):
        """Upsample ``input_down`` and concatenate the centre crop of ``input_left``.

        Args:
            input_down: Tensor from the lower level, indexed as (N, C, H, W).
            input_left: Encoder skip tensor; expected to be at least as large
                as the upsampled tensor in both spatial dimensions.

        Returns:
            The upsampled tensor followed by the cropped skip tensor,
            concatenated on dimension 1.
        """
        x = self.module_up(input_down)
        out_h, out_w = x.shape[2], x.shape[3]
        # Height and width get their own offset and extent, so non-square
        # feature maps are cropped correctly as well.
        top = (input_left.shape[2] - out_h) // 2
        left = (input_left.shape[3] - out_w) // 2
        skip = input_left[:, :, top:top + out_h, left:left + out_w]
        return torch.cat((x, skip), 1)
            
        
class Unet(nn.Module):
    """U-Net with four encoder levels, a bottom stage and four decoder levels.

    Every stage is ``conv3x3 -> ReLU -> conv3x3`` using unpadded convolutions,
    so each stage shrinks the spatial size by 4 pixels and the output is
    smaller than the input (a 572x572 input yields a 388x388 output).
    Note that the second convolution of a stage is not followed by an
    activation of its own in this implementation.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final 1x1 convolution.
        conv: Factory called as ``conv(in_channels, out_channels, kernel_size)``.
        relu: Zero-argument factory returning the activation module.
        n_feats: Channel count of the first level; doubled at each deeper level.
    """

    def __init__(self,in_channels,out_channels,conv=default_conv,relu=default_relu,n_feats=64):
        super(Unet,self).__init__()

        def double_conv(c_in, c_out):
            # One stage: conv -> ReLU -> conv, all with 3x3 kernels.
            return nn.Sequential(conv(c_in, c_out, 3), relu(), conv(c_out, c_out, 3))

        # Stages are created in the same order as before (encoder, bottom,
        # decoder, tail, up) so seeded weight initialisation is unchanged.
        self.left1 = double_conv(in_channels, n_feats)
        self.left2 = double_conv(n_feats, 2*n_feats)
        self.left3 = double_conv(2*n_feats, 4*n_feats)
        self.left4 = double_conv(4*n_feats, 8*n_feats)
        self.bottom = double_conv(8*n_feats, 16*n_feats)

        # Decoder stages take the concatenated (upsampled + skip) tensor,
        # hence twice the channels they output.
        self.right1 = double_conv(2*n_feats, n_feats)
        self.right2 = double_conv(4*n_feats, 2*n_feats)
        self.right3 = double_conv(8*n_feats, 4*n_feats)
        self.right4 = double_conv(16*n_feats, 8*n_feats)

        self.tail = conv(n_feats,out_channels,1)

        # A 2x2 window is needed for real max pooling; kernel_size=1 with
        # stride 2 only keeps every second pixel. Odd-sized maps are floored.
        self.down = nn.ModuleList(
            [nn.MaxPool2d(kernel_size=2, stride=2) for _ in range(4)]
        )

        # up[i] consumes (2 ** (i + 1)) * n_feats channels, so up[3] sits
        # directly above the bottom stage. The modules are only ever indexed
        # (Up_Sample.forward takes two tensors), hence a ModuleList.
        # NOTE: Up_Sample is built with its own default conv/relu factories,
        # not the ones passed to this constructor.
        self.up = nn.ModuleList(
            [Up_Sample(in_channels=(2**(level+1))*n_feats) for level in range(4)]
        )

    def forward(self,x):
        """Run the encoder, the bottom stage and the decoder with skip connections.

        Args:
            x: Input tensor indexed as (N, in_channels, H, W).

        Returns:
            Tensor with ``out_channels`` channels and a reduced spatial size.
        """
        # Encoder: keep each stage output for the matching skip connection.
        x1 = self.left1(x)
        x2 = self.left2(self.down[0](x1))
        x3 = self.left3(self.down[1](x2))
        x4 = self.left4(self.down[2](x3))

        x_b = self.bottom(self.down[3](x4))

        # Decoder: upsample, merge with the cropped skip tensor, convolve.
        y3 = self.right4(self.up[3](x_b,x4))
        y2 = self.right3(self.up[2](y3,x3))
        y1 = self.right2(self.up[1](y2,x2))
        y = self.right1(self.up[0](y1,x1))

        return self.tail(y)
    
def main():
    """Build a 1-channel-in, 2-channel-out U-Net and print its torchsummary table.

    The model is placed on the GPU when CUDA is available and on the CPU
    otherwise, so the script also runs on machines without a GPU.
    """
    from torchsummary import summary

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Unet(in_channels=1,out_channels=2).to(device)

    # summary() builds its dummy input on the named device, so the device
    # string must match where the model lives.
    summary(model, (1, 572, 572), device=device)

if __name__=='__main__':
    main()

打印模型(输入尺寸为 1×572×572 时 torchsummary 的输出):

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
              ReLU-2         [-1, 64, 570, 570]               0
            Conv2d-3         [-1, 64, 568, 568]          36,928
         MaxPool2d-4         [-1, 64, 284, 284]               0
            Conv2d-5        [-1, 128, 282, 282]          73,856
              ReLU-6        [-1, 128, 282, 282]               0
            Conv2d-7        [-1, 128, 280, 280]         147,584
         MaxPool2d-8        [-1, 128, 140, 140]               0
            Conv2d-9        [-1, 256, 138, 138]         295,168
             ReLU-10        [-1, 256, 138, 138]               0
           Conv2d-11        [-1, 256, 136, 136]         590,080
        MaxPool2d-12          [-1, 256, 68, 68]               0
           Conv2d-13          [-1, 512, 66, 66]       1,180,160
             ReLU-14          [-1, 512, 66, 66]               0
           Conv2d-15          [-1, 512, 64, 64]       2,359,808
        MaxPool2d-16          [-1, 512, 32, 32]               0
           Conv2d-17         [-1, 1024, 30, 30]       4,719,616
             ReLU-18         [-1, 1024, 30, 30]               0
           Conv2d-19         [-1, 1024, 28, 28]       9,438,208
         Upsample-20         [-1, 1024, 56, 56]               0
           Conv2d-21          [-1, 512, 56, 56]         524,800
             ReLU-22          [-1, 512, 56, 56]               0
        Up_Sample-23         [-1, 1024, 56, 56]               0
           Conv2d-24          [-1, 512, 54, 54]       4,719,104
             ReLU-25          [-1, 512, 54, 54]               0
           Conv2d-26          [-1, 512, 52, 52]       2,359,808
         Upsample-27        [-1, 512, 104, 104]               0
           Conv2d-28        [-1, 256, 104, 104]         131,328
             ReLU-29        [-1, 256, 104, 104]               0
        Up_Sample-30        [-1, 512, 104, 104]               0
           Conv2d-31        [-1, 256, 102, 102]       1,179,904
             ReLU-32        [-1, 256, 102, 102]               0
           Conv2d-33        [-1, 256, 100, 100]         590,080
         Upsample-34        [-1, 256, 200, 200]               0
           Conv2d-35        [-1, 128, 200, 200]          32,896
             ReLU-36        [-1, 128, 200, 200]               0
        Up_Sample-37        [-1, 256, 200, 200]               0
           Conv2d-38        [-1, 128, 198, 198]         295,040
             ReLU-39        [-1, 128, 198, 198]               0
           Conv2d-40        [-1, 128, 196, 196]         147,584
         Upsample-41        [-1, 128, 392, 392]               0
           Conv2d-42         [-1, 64, 392, 392]           8,256
             ReLU-43         [-1, 64, 392, 392]               0
        Up_Sample-44        [-1, 128, 392, 392]               0
           Conv2d-45         [-1, 64, 390, 390]          73,792
             ReLU-46         [-1, 64, 390, 390]               0
           Conv2d-47         [-1, 64, 388, 388]          36,928
           Conv2d-48          [-1, 2, 388, 388]             130
================================================================
Total params: 28,941,698
Trainable params: 28,941,698
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 2275.74
Params size (MB): 110.40
Estimated Total Size (MB): 2387.39
----------------------------------------------------------------

发布了10 篇原创文章 · 获赞 9 · 访问量 586

猜你喜欢

转载自blog.csdn.net/qq_36937684/article/details/105204248