Pytorch-HED fine-tune实现

PyTorch HED (VGG16-HED and Res34-HED)

Python 3; PyTorch 0.4

基于 fine-tune 的 VGG16 或 ResNet34 构建 HED 网络。此外,本人在 PyTorch 已经实现并训练好的网络基础上,构建了属于自己简约风格的网络,例如 HED 网络。

'''  VGG16 network
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): ReLU(inplace)
  (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace)
  (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (25): ReLU(inplace)
  (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (27): ReLU(inplace)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace)
  (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
'''

import torch
import torchvision
from torch import nn
# Input size: [256, 256]
# HED based on VGG16

class HED_vgg16(nn.Module):
    """HED-style network built on the torchvision VGG16 feature extractor.

    The five VGG16 convolution stages are reused as the encoder. After each
    stage a 1x1 convolution followed by ReLU reduces the features to a single
    channel; that side output is bilinearly upsampled by a fixed factor
    (1, 2, 4, 8, 16 for stages 1-5), and the five maps are concatenated and
    fused by a final 1x1 convolution.

    The model returns the raw output of the fusion convolution (no
    activation); apply softmax and derive predictions outside the model.

    Args:
        num_filters: Base channel width. The side-output convolutions expect
            ``num_filters * (2, 4, 8, 16, 16)`` input channels, which matches
            the VGG16 stage widths (64, 128, 256, 512, 512) only for the
            default value of 32.
        pretrained: Forwarded to ``torchvision.models.vgg16``.
        class_number: Number of output channels of the fusion convolution.

    Note:
        Side outputs are upsampled by fixed scale factors, so the input
        height and width need to be divisible by 16 for the concatenation to
        line up. The size comments below assume a 256x256 input.
    """

    def __init__(self, num_filters=32, pretrained=False, class_number=2):
        super().__init__()
        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features

        # One shared 2x2 pooling layer is applied between stages; the slices
        # below therefore skip the encoder's own MaxPool2d entries
        # (indices 4, 9, 16, 23 and 30).
        self.pool = nn.MaxPool2d(2, 2)

        # Stage 1: 64 channels, 256x256 -- no upsampling needed.
        self.conv1 = self.encoder[0:4]
        self.score1 = nn.Sequential(
            nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU()
        )

        # Stage 2: 128 channels, 128x128 -> upsampled x2.
        self.conv2 = self.encoder[5:9]
        self.d_conv2 = nn.Sequential(
            nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU()
        )
        self.score2 = nn.UpsamplingBilinear2d(scale_factor=2)

        # Stage 3: 256 channels, 64x64 -> upsampled x4.
        self.conv3 = self.encoder[10:16]
        self.d_conv3 = nn.Sequential(
            nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU()
        )
        self.score3 = nn.UpsamplingBilinear2d(scale_factor=4)

        # Stage 4: 512 channels, 32x32 -> upsampled x8.
        self.conv4 = self.encoder[17:23]
        self.d_conv4 = nn.Sequential(
            nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU()
        )
        self.score4 = nn.UpsamplingBilinear2d(scale_factor=8)

        # Stage 5: 512 channels, 16x16 -> upsampled x16.
        self.conv5 = self.encoder[24:30]
        self.d_conv5 = nn.Sequential(
            nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU()
        )
        self.score5 = nn.UpsamplingBilinear2d(scale_factor=16)

        # Fusion of the five single-channel side outputs; no ReLU here.
        self.score = nn.Conv2d(5, class_number, 1, 1)

    def forward(self, x):
        """Run the encoder stages and fuse the five upsampled side outputs.

        Args:
            x: Input batch; the first encoder convolution takes 3 channels.

        Returns:
            The output of the fusion convolution, with ``class_number``
            channels.
        """
        x = self.conv1(x)
        s1 = self.score1(x)
        x = self.pool(x)

        x = self.conv2(x)
        s2 = self.score2(self.d_conv2(x))
        x = self.pool(x)

        x = self.conv3(x)
        s3 = self.score3(self.d_conv3(x))
        x = self.pool(x)

        # Stage 4 has to run conv4: conv3 expects the stage-2 channel count
        # and cannot consume the stage-3 output it would receive here.
        x = self.conv4(x)
        s4 = self.score4(self.d_conv4(x))
        x = self.pool(x)

        x = self.conv5(x)
        s5 = self.score5(self.d_conv5(x))

        # `dim` is the documented keyword of torch.cat.
        return self.score(torch.cat([s1, s2, s3, s4, s5], dim=1))
''' you need to write softmax after model and predict output by yourself '''
# Build the VGG16-based HED model with its default arguments, then print the
# module structure followed by the names of its state-dict entries.
hed1 = HED_vgg16()
print(hed1, hed1.state_dict().keys(), sep="\n")

######################################################################
#import torch
#import torchvision
#from torch import nn
# HED based on ResNet34

class HED_res34(nn.Module):
    """HED-style network built on the torchvision ResNet34 backbone.

    Side outputs are taken after the stem (``conv1``/``bn1``/``relu``) and
    after each of ``layer1`` .. ``layer4``. Each one is reduced to a single
    channel by a 1x1 convolution followed by ReLU, bilinearly upsampled by a
    fixed factor (2, 4, 8, 16, 32), and the five maps are concatenated and
    fused by a final 1x1 convolution.

    The model returns the raw output of the fusion convolution (no
    activation); the original note says the loss function supplies softmax.

    Args:
        num_filters: Base channel width. The side-output convolutions expect
            ``num_filters * (2, 2, 4, 8, 16)`` input channels; presumably only
            the default of 32 matches the ResNet34 widths -- verify before
            changing it.
        pretrained: Forwarded to ``torchvision.models.resnet34``.
        class_number: Number of output channels of the fusion convolution.

    Note:
        Upsampling uses fixed scale factors and the size comments below
        assume a 256x256 input; other sizes presumably need height and width
        divisible by 32 for the concatenation to line up.
    """

    def __init__(self, num_filters=32, pretrained=False, class_number=2):
        super().__init__()
        # The whole backbone is kept as an attribute; forward() only uses its
        # conv1, bn1, relu and layer1..layer4 parts.
        self.encoder = torchvision.models.resnet34(pretrained=pretrained)

        # Pooling applied between the stem and layer1.
        self.pool = nn.MaxPool2d(3, 2, 1)

        # Stem: 128x128 -> upsampled x2.
        self.start = nn.Sequential(
            self.encoder.conv1, self.encoder.bn1, self.encoder.relu
        )
        self.d_convs = nn.Sequential(
            nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU()
        )
        self.scores = nn.UpsamplingBilinear2d(scale_factor=2)

        # layer1: 64x64 -> upsampled x4.
        self.layer1 = self.encoder.layer1
        self.d_conv1 = nn.Sequential(
            nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU()
        )
        self.score1 = nn.UpsamplingBilinear2d(scale_factor=4)

        # layer2: 32x32 -> upsampled x8.
        self.layer2 = self.encoder.layer2
        self.d_conv2 = nn.Sequential(
            nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU()
        )
        self.score2 = nn.UpsamplingBilinear2d(scale_factor=8)

        # layer3: 16x16 -> upsampled x16.
        self.layer3 = self.encoder.layer3
        self.d_conv3 = nn.Sequential(
            nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU()
        )
        self.score3 = nn.UpsamplingBilinear2d(scale_factor=16)

        # layer4: 8x8 -> upsampled x32.
        self.layer4 = self.encoder.layer4
        self.d_conv4 = nn.Sequential(
            nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU()
        )
        self.score4 = nn.UpsamplingBilinear2d(scale_factor=32)

        # Fusion of the five single-channel side outputs; no ReLU here.
        self.score = nn.Conv2d(5, class_number, 1, 1)

    def forward(self, x):
        """Run the backbone and fuse the five upsampled side outputs.

        Args:
            x: Input batch fed to the ResNet34 stem.

        Returns:
            The output of the fusion convolution, with ``class_number``
            channels.
        """
        x = self.start(x)
        ss = self.scores(self.d_convs(x))
        x = self.pool(x)

        x = self.layer1(x)
        s1 = self.score1(self.d_conv1(x))

        x = self.layer2(x)
        s2 = self.score2(self.d_conv2(x))

        x = self.layer3(x)
        s3 = self.score3(self.d_conv3(x))

        x = self.layer4(x)
        s4 = self.score4(self.d_conv4(x))

        # The stem side output goes last; this channel order is kept so that
        # fusion weights trained with the original layout remain valid.
        # `dim` is the documented keyword of torch.cat.
        return self.score(torch.cat([s1, s2, s3, s4, ss], dim=1))
        
# Build the ResNet34-based HED model with its default arguments, then print
# the module structure followed by the names of its state-dict entries.
hed2 = HED_res34()
print(hed2, hed2.state_dict().keys(), sep="\n")

猜你喜欢

转载自blog.csdn.net/leilei18a/article/details/80339518