ResNeXt模型——pytorch实现

论文传送门:Aggregated Residual Transformations for Deep Neural Networks
前置文章:ResNet模型——pytorch实现

ResNeXt的改进:

在ResNet网络的基础上,对Residual结构进行了改进。主要改进是将输入特征分成C个组,进行C个Conv paths计算得到C个输出,最后输出由C个输出相加(再加上shortcut x)。实现方法是将Residual结构中的第二个卷积替换为Group Convolution(组卷积)。

改进的Residual结构:

Residual结构
注意到,以下三种结构在数学上等价,则仅需将第二个Conv替换为Group Conv。

等价结构
对于ResNet18、34所使用的Residual结构,进行多paths结构的改进没有太大意义,其数学上等价于增加卷积核个数(输出通道数)。
等价结构

Group Convolution:

标准卷积是针对输入特征图的全部通道进行计算,即每个卷积核的长度与特征图的通道数保持一致。
组卷积是指将输入特征图在通道维度进行分组,对每组特征图单独进行标准卷积,然后将各组计算的结果拼接在一起,得到最终输出。
可以发现,组卷积相比于标准卷积参数量更小。
注意,当组卷积的组数等于输入特征图的通道数时,此时组卷积就是深度可分离卷积(Depthwise Separable Convolution);当组卷积的组数等于1时,此时组卷积就是标准卷积。

ResNeXt50的结构:

整体结构与ResNet50保持一致,其中的Residual结构中,卷积的中间通道数变为原来的2倍,第二个卷积使用Group Convolution,C表示组数。
ResNeXt结构

import torch
import torch.nn as nn


class BasicBlock(nn.Module):  # 定义残差块,resnet18、resnet34使用此残差块
    expansion = 1  # 残差操作维度变化倍数

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):  # 初始化方法
        super(BasicBlock, self).__init__()  # 继承初始化方法

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride,
                               padding=1, bias=False)  # conv操作
        self.bn1 = nn.BatchNorm2d(num_features=out_channel)  # bn操作
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1,
                               bias=False)  # conv操作
        self.bn2 = nn.BatchNorm2d(num_features=out_channel)  # bn操作
        self.relu = nn.ReLU(inplace=True)  # relu激活函数
        self.downsample = downsample  # 是否下采样

    def forward(self, x):  # 前传函数
        identity = x  # 原始x
        if self.downsample:  # 如果下采样
            identity = self.downsample(x)  # 残差边存在conv操作,x-->x'

        x = self.conv1(x)  # conv操作
        x = self.bn1(x)  # bn操作
        x = self.relu(x)  # relu激活函数
        x = self.conv2(x)  # conv操作
        x = self.bn2(x)  # bn操作

        x += identity  # F(x)+x/x'
        x = self.relu(x)  # relu激活函数

        return x


class Bottleneck(nn.Module):  # 定义残差块,renet50、resnet101、resnet152使用此残差块
    expansion = 4  # 残差操作维度变化倍数

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, groups=1, base_width=64):  # 初始化方法
        super(Bottleneck, self).__init__()  # 继承初始化方法
        width = int(in_channel * (base_width / 64.0)) * groups  # F(x)第二个卷积的通道数
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, kernel_size=1, stride=1,
                               bias=False)  # conv操作
        self.bn1 = nn.BatchNorm2d(num_features=width)  # bn操作
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups, kernel_size=3, stride=stride,
                               padding=1, bias=False)  # conv操作,若为ResNeXt网络,则这里为group conv操作
        self.bn2 = nn.BatchNorm2d(num_features=width)  # bn操作
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion, kernel_size=1, stride=1,
                               bias=False)  # conv操作
        self.bn3 = nn.BatchNorm2d(num_features=out_channel * self.expansion)  # bn操作

        self.relu = nn.ReLU(inplace=True)  # relu激活函数
        self.downsample = downsample  # 是否下采样

    def forward(self, x):  # 前传函数
        identity = x  # 原始x
        if self.downsample:  # 如果下采样
            identity = self.downsample(x)  # 残差边存在conv操作,x-->x'

        x = self.conv1(x)  # conv操作
        x = self.bn1(x)  # bn操作
        x = self.relu(x)  # relu激活函数

        x = self.conv2(x)  # conv操作
        x = self.bn2(x)  # bn操作
        x = self.relu(x)  # relu激活函数

        x = self.conv3(x)  # conv操作
        x = self.bn3(x)  # bn操作

        x += identity  # F(x)+x/x'
        x = self.relu(x)  # relu激活函数

        return x


class ResNet(nn.Module):  # 定义resnet模型
    def __init__(self, block, layers, num_classes, in_channel=3, channel=64, groups=1, base_width=64):  # 初始化方法
        super(ResNet, self).__init__()  # 继承初始化方法
        self.in_channel = in_channel  # 输入通道数,应为图片通道数
        self.channel = channel  # 第一次conv输出通道数
        self.groups = groups  # 分组卷积组数,ResNet默认为1,即不采用组卷积,ResNeXt不为1,采用组卷积
        self.base_width = base_width  # 每组通道数(第一个Block)

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=channel, kernel_size=7, stride=2, padding=3,
                               bias=False)  # conv操作
        self.bn1 = nn.BatchNorm2d(self.channel)  # bn操作

        self.relu = nn.ReLU(inplace=True)  # relu激活函数
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # maxpool操作

        self.layer1 = self._make_layer(block, 64, layers[0])  # 第一块残差集合,由基本的残差块组成
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # 第二块残差集合,由基本的残差块组成
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # 第三块残差集合,由基本的残差块组成
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 第四块残差集合,由基本的残差块组成

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))  # avgpool操作
        self.fc = nn.Linear(512 * block.expansion, num_classes)  # linear映射

        for m in self.modules():  # 遍历模型结构
            if isinstance(m, nn.Conv2d):  # 如果当前结构是卷积操作
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")  # 使用kaiming初始化方法

    def _make_layer(self, block, channel, blocks, stride=1):  # 定义函数,用于生成模型结构
        downsample = None  # 默认不对原始x进行操作

        if stride != 1 or self.channel != channel * block.expansion:  # 如果卷积步长不为1或卷积前后通道数不一致,则需要对原始x进行操作
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.channel, out_channels=channel * block.expansion, kernel_size=1,  # conv操作
                          stride=stride, bias=False),
                nn.BatchNorm2d(num_features=channel * block.expansion)  # bn操作
            )
        layers = []  # 列表用于存放模型结构

        layers.append(block(self.channel, channel, downsample=downsample, stride=stride, groups=self.groups,
                            base_width=self.base_width))  # 模型追加block结构
        self.channel = channel * block.expansion  # 通道数转换为卷积后输出通道数
        for _ in range(1, blocks):  # 进行blocks次循环
            layers.append(block(self.channel, channel, groups=self.groups, base_width=self.base_width))  # 模型追加block结构
        return nn.Sequential(*layers)  # 返回模型结构

    def forward(self, x):  # 前传函数
        x = self.conv1(x)  # conv操作
        x = self.bn1(x)  # bn操作
        x = self.relu(x)  # relu激活函数
        x = self.maxpool(x)  # maxpool激活函数

        x = self.layer1(x)  # 第一块残差集合
        x = self.layer2(x)  # 第二块残差集合
        x = self.layer3(x)  # 第三块残差集合
        x = self.layer4(x)  # 第四块残差集合

        x = self.avgpool(x)  # avgpool操作
        x = torch.flatten(x, 1)  # 将多维特征映射成一维特征向量
        x = self.fc(x)  # linear映射

        return x


def resnet18(num_classes=2):  # 定义函数,生成resnet18模型
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)  # resnet18使用BasicBlock基础残差块,四块残差集合使用的残差块数量为2,2,2,2


def resnet34(num_classes=2):  # 定义函数,生成resnet34模型
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)  # resnet18使用BasicBlock基础残差块,四块残差集合使用的残差块数量为3,4,6,3


def resnet50(num_classes=2):  # 定义函数,生成resnet50模型
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)  # resnet18使用Bottleneck基础残差块,四块残差集合使用的残差块数量为3,4,6,3


def resnet101(num_classes=2):  # 定义函数,生成resnet101模型
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes)  # resnet18使用Bottleneck基础残差块,四块残差集合使用的残差块数量为3, 4, 23, 3


def resnet152(num_classes=2):  # 定义函数,生成resnet152模型
    return ResNet(Bottleneck, [3, 8, 36, 3],
                  num_classes=num_classes)  # resnet18使用Bottleneck基础残差块,四块残差集合使用的残差块数量为3, 8, 36, 3


def resnext50_32x4d(num_classes=2):  # 定义函数,生成resnext50模型
    return ResNet(Bottleneck, [3, 4, 6, 3],  # resnext50使用Bottleneck基础残差块,四块残差集合使用的残差块数量为3, 4, 6, 3
                  num_classes=num_classes,
                  groups=32,  # 组卷积组数为32
                  base_width=4)  # 每组的通道数为4(第一个Block)


def resnext101_32x8d(num_classes=2):  # 定义函数,生成resnext101模型
    return ResNet(Bottleneck, [3, 4, 23, 3],  # resnext101使用Bottleneck基础残差块,四块残差集合使用的残差块数量为3, 4, 23, 3
                  num_classes=num_classes,
                  groups=32,  # 组卷积组数为32
                  base_width=8)  # 每组的通道数为8(第一个Block)

猜你喜欢

转载自blog.csdn.net/Peach_____/article/details/128808810
今日推荐