PSPNet Network Structure Implementation

Paper: Pyramid Scene Parsing Network


Overall structure diagram

Detailed structure diagram (MobileNetV2 backbone)

Backbone Network Construction

import math

import torch.nn as nn
from torch.hub import load_state_dict_from_url

BatchNorm2d = nn.BatchNorm2d

def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

class InvertedResidual(nn.Module):  # inverted residual block
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]  # only strides 1 and 2 are valid

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = (self.stride == 1 and inp == oup)  # True only when stride is 1 and the channel counts match
        if expand_ratio == 1:    # with expand_ratio 1 the leading 1x1 expansion conv-BN-ReLU6 is skipped
            self.conv = nn.Sequential(
                #----------------------------------------------------#
                #   3x3 depthwise convolution for feature extraction
                #----------------------------------------------------#
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                #----------------------------------------------------#
                #   1x1 pointwise convolution to reduce the channel count
                #----------------------------------------------------#
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )
        else:   # with expand_ratio > 1 the leading 1x1 expansion conv-BN-ReLU6 is applied first
            self.conv = nn.Sequential(
                #----------------------------------------------------#
                #   1x1 convolution to expand the channel count
                #----------------------------------------------------#
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                #----------------------------------------------------#
                #   3x3 depthwise convolution for feature extraction
                #----------------------------------------------------#
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                #----------------------------------------------------#
                #   1x1 pointwise convolution to reduce the channel count
                #----------------------------------------------------#
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:  # add the residual shortcut when available
            return x + self.conv(x)
        else:                     # no shortcut otherwise
            return self.conv(x)

class MobileNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        interverted_residual_setting = [
            # shapes below assume a 473x473x3 input image
            # seven stages in total
            # t: expand_ratio; c: output channels of the stage
            # n: number of InvertedResidual blocks in the stage; s: stride of the first block in the stage
            # t, c, n, s
            # 237,237,32 -> 237,237,16
            [1, 16, 1, 1],
            # 237,237,16 -> 119,119,24
            [6, 24, 2, 2],
            # 119,119,24 -> 60,60,32
            [6, 32, 3, 2],
            # 60,60,32 -> 30,30,64
            [6, 64, 4, 2],
            # 30,30,64 -> 30,30,96
            [6, 96, 3, 1],
            # 30,30,96 -> 15,15,160
            [6, 160, 3, 2],
            # 15,15,160 -> 15,15,320
            [6, 320, 1, 1],
        ]
        
        assert input_size % 32 == 0  # the classifier's input size must be a multiple of 32 (defaults to 224; the 473x473 shapes above apply when the features are reused for segmentation)
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        # 473,473,3 -> 237,237,32: stride=2 here, so the input image is immediately downsampled in height and width
        self.features = [conv_bn(3, input_channel, 2)]  # features collects every layer of the network
        
        # loop over the settings above to build the MobileNetV2 body
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):  # the first InvertedResidual of each stage uses stride s; the rest use stride 1
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel  # each block's output channels feed the next block
                
        # final 1x1 conv that closes out the MobileNetV2 body
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*self.features)

        # classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)  # global average pooling over width and height
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()



def mobilenetv2(pretrained=False, **kwargs):
    model = MobileNetV2(n_class=1000, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(
            'https://github.com/bubbliiiing/pspnet-pytorch/releases/download/v1.0/mobilenet_v2.pth.tar',
            model_dir="./model_data"
        )
        model.load_state_dict(state_dict, strict=False)
    return model
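
A quick sanity check of the backbone (my own sketch, not part of the original code): build the classifier without pretrained weights and push a dummy batch through it to confirm the shapes.

import torch

# hypothetical smoke test for the MobileNetV2 classifier above
model = mobilenetv2(pretrained=False)
x = torch.randn(1, 3, 224, 224)
feat = model.features(x)   # 224 / 32 = 7, so the final feature map is 7x7
logits = model(x)          # global average pooling + linear head
print(feat.shape)          # torch.Size([1, 1280, 7, 7])
print(logits.shape)        # torch.Size([1, 1000])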

Full Network Construction

import torch
import torch.nn.functional as F
from torch import nn

from nets.mobilenetv2 import mobilenetv2
from nets.resnet import resnet50


class Resnet(nn.Module):
    def __init__(self, dilate_scale=8, pretrained=True):
        super(Resnet, self).__init__()
        from functools import partial
        model = resnet50(pretrained)

        #--------------------------------------------------------------------------------------------#
        #   adjust conv strides and dilation rates according to the downsampling factor
        #   with downsample_factor=16 we end up with two feature maps of shape 30,30,1024 and 30,30,2048
        #--------------------------------------------------------------------------------------------#
        if dilate_scale == 8:
            model.layer3.apply(partial(self._nostride_dilate, dilate=2))
            model.layer4.apply(partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            model.layer4.apply(partial(self._nostride_dilate, dilate=2))

        # the repo's resnet50 uses a three-conv deep stem packed into model.conv1 (an nn.Sequential)
        self.conv1 = model.conv1[0]
        self.bn1 = model.conv1[1]
        self.relu1 = model.conv1[2]
        self.conv2 = model.conv1[3]
        self.bn2 = model.conv1[4]
        self.relu2 = model.conv1[5]
        self.conv3 = model.conv1[6]
        self.bn3 = model.bn1
        self.relu3 = model.relu
        self.maxpool = model.maxpool
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # convert stride-2 convs to stride 1 and compensate with dilation,
            # keeping the spatial resolution while enlarging the receptive field
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x_aux = self.layer3(x)
        x = self.layer4(x_aux)
        return x_aux, x
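
As an aside, the stride-to-dilation trick in _nostride_dilate is easy to verify in isolation. Below is a self-contained re-implementation of the same logic (my own illustrative sketch, using an isinstance check instead of the class-name lookup) applied to a toy stride-2 convolution:

import torch
from torch import nn
from functools import partial

def nostride_dilate_demo(m, dilate):
    # same conversion as _nostride_dilate above, for a standalone demo
    if isinstance(m, nn.Conv2d):
        if m.stride == (2, 2):
            m.stride = (1, 1)                        # remove the downsampling
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)            # stride-1 convs get the full dilation
            m.padding = (dilate, dilate)

block = nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=False))
block.apply(partial(nostride_dilate_demo, dilate=2))
out = block(torch.randn(1, 64, 30, 30))
print(out.shape)  # torch.Size([1, 64, 30, 30]) - resolution kept instead of halved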

class MobileNetV2(nn.Module):
    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial
        
        model = mobilenetv2(pretrained)
        # model.features has 19 modules: one conv_bn, then 17 InvertedResidual blocks, then a final conv_1x1_bn
        self.features = model.features[:-1]
        # [:-1] drops the last module, so self.features keeps 18 modules: the conv_bn plus the 17 InvertedResidual blocks
        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]  # indices of the downsampling InvertedResidual blocks in self.features (the initial conv_bn downsampling is not counted)

        #--------------------------------------------------------------------------------------------#
        #   adjust conv strides and dilation rates according to the downsampling factor
        #   with downsample_factor=16 we end up with two feature maps of shape 30,30,320 and 30,30,96
        #--------------------------------------------------------------------------------------------#
        if downsample_factor == 8:   # only three downsamplings are allowed, so the blocks at indices 7 and 14 have their strides removed
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                # modules 7 through 13 of self.features get dilation 2, i.e. rows 4 and 5 of interverted_residual_setting
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                # modules 14 through 17 of self.features get dilation 4, i.e. the last two rows of interverted_residual_setting
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:   # four downsamplings are allowed, so only the block at index 14 needs its stride removed
            # modules 14 through 17 of self.features get dilation 2, i.e. the last two rows of interverted_residual_setting
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
        
    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # same stride-to-dilation conversion as in the Resnet wrapper above
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        x_aux = self.features[:14](x)  # output of the third-to-last row of interverted_residual_setting: 30,30,96
        x = self.features[14:](x_aux)  # output of the last row of interverted_residual_setting: 30,30,320
        return x_aux, x
 
class _PSPModule(nn.Module):
    def __init__(self, in_channels, pool_sizes, norm_layer):
        super(_PSPModule, self).__init__()
        out_channels = in_channels // len(pool_sizes)
        #-----------------------------------------------------#
        #   average pooling over regions of different sizes
        #   30,30,320 + 30,30,80 + 30,30,80 + 30,30,80 + 30,30,80 = 30,30,640 after concatenation
        #-----------------------------------------------------#
        # pool_sizes = [1, 2, 3, 6]
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, pool_size, norm_layer) for pool_size in pool_sizes])
        
        # 30,30,640 -> 30,30,80
        # final conv; len(pool_sizes) = 4, so out_channels = in_channels // 4
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels + (out_channels * len(pool_sizes)), out_channels, kernel_size=3, padding=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz, norm_layer):
        # adaptive average pooling first, then conv-BN-ReLU
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)  # bin_sz is 1, 2, 3 or 6
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = norm_layer(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)  # stored in self.stages
    
    def forward(self, features):
        # height and width of the incoming feature map
        h, w = features.size()[2], features.size()[3]
        # start the pyramids list with the original feature map
        pyramids = [features]
        # run features through each stage, upsample back to (h, w) with bilinear interpolation,
        # and extend the pyramids list with the four resulting maps
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages])
        # concatenate the five feature maps and pass them through the bottleneck
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output


class PSPNet(nn.Module):
    def __init__(self, num_classes, downsample_factor, backbone="resnet50", pretrained=True, aux_branch=True):
        super(PSPNet, self).__init__()
        norm_layer = nn.BatchNorm2d
        if backbone=="resnet50":
            self.backbone = Resnet(downsample_factor, pretrained)
            aux_channel = 1024
            out_channel = 2048
        elif backbone=="mobilenet":
            #----------------------------------#
            #   obtain two feature maps:
            #   x_aux, the auxiliary branch  [30,30,96]
            #   x, the main trunk            [30,30,320]
            #----------------------------------#
            self.backbone = MobileNetV2(downsample_factor, pretrained)
            aux_channel = 96
            out_channel = 320
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, resnet50.'.format(backbone))

        #--------------------------------------------------------------#
        #   PSP module: pooling over sub-regions of the feature map
        #   the map is split into 1x1, 2x2, 3x3 and 6x6 regions
        #   30,30,320 -> 30,30,80 -> 30,30,num_classes (21 for VOC)
        #--------------------------------------------------------------#
        self.master_branch = nn.Sequential(
            _PSPModule(out_channel, pool_sizes=[1, 2, 3, 6], norm_layer=norm_layer),
            nn.Conv2d(out_channel//4, num_classes, kernel_size=1)
        )

        self.aux_branch = aux_branch

        if self.aux_branch:
            #---------------------------------------------------#
            #   auxiliary prediction head
            #   30,30,96 -> 30,30,40 -> 30,30,num_classes
            #---------------------------------------------------#
            self.auxiliary_branch = nn.Sequential(
                nn.Conv2d(aux_channel, out_channel//8, kernel_size=3, padding=1, bias=False),
                norm_layer(out_channel//8),
                nn.ReLU(inplace=True),
                nn.Dropout2d(0.1),
                nn.Conv2d(out_channel//8, num_classes, kernel_size=1)
            )

        self.initialize_weights(self.master_branch)

    def forward(self, x):
        input_size = (x.size()[2], x.size()[3])
        x_aux, x = self.backbone(x)
        output = self.master_branch(x)
        output = F.interpolate(output, size=input_size, mode='bilinear', align_corners=True)
        if self.aux_branch:
            output_aux = self.auxiliary_branch(x_aux)
            output_aux = F.interpolate(output_aux, size=input_size, mode='bilinear', align_corners=True)
            return output_aux, output
        else:
            return output

    def initialize_weights(self, *models):
        for model in models:
            for m in model.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1.)
                    m.bias.data.fill_(1e-4)
                elif isinstance(m, nn.Linear):
                    m.weight.data.normal_(0.0, 0.0001)
                    m.bias.data.zero_()
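
Finally, a minimal end-to-end usage sketch of my own (assuming the repo's nets package is importable): construct PSPNet with the MobileNetV2 backbone and verify that both heads recover the input resolution, matching the shape comments above.

import torch

# hypothetical smoke test; pretrained=False avoids the weight download
net = PSPNet(num_classes=21, downsample_factor=16, backbone="mobilenet",
             pretrained=False, aux_branch=True)
net.eval()
x = torch.randn(2, 3, 473, 473)
with torch.no_grad():
    out_aux, out = net(x)
print(out_aux.shape)  # torch.Size([2, 21, 473, 473]) - auxiliary head
print(out.shape)      # torch.Size([2, 21, 473, 473]) - main head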

Reference

[Building your own PSPNet semantic segmentation platform in PyTorch (Bubbliiiing deep learning tutorial)] https://www.bilibili.com/video/BV1zt4y1q7HH?p=5&share_source=copy_web&vd_source=95705b32f23f70b32dfa1721628d5874

Reproduced from blog.csdn.net/m0_56247038/article/details/127103177