Reproducing the ResNet18 Network Architecture in Code


  • More learning projects and roadmaps: https://github.com/xiaoaleiBLUE


Preface

This post reproduces the ResNet18 network architecture in code (using the PyTorch framework), sketches the architecture, and compares the result against a ResNet18 implementation written by someone else. The comparison shows that the readability of my own code still has room for improvement.


I. ResNet18 network architecture

(Figure: ResNet18 network architecture diagram)
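In short, ResNet18 stacks a 7×7 stem convolution, a max-pool, four stages of two basic blocks each (64, 128, 256, and 512 channels, with the first block of stages 2 to 4 downsampling), global average pooling, and a fully connected classifier. Without the figure, the same structure can also be printed from the torchvision reference implementation (a minimal sketch, assuming torchvision is installed):

import torchvision

# prints conv1, bn1, relu, maxpool, layer1-layer4 (two BasicBlocks each), avgpool, fc
print(torchvision.models.resnet18())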

II. Detailed analysis of the architecture

1. Architecture analysis


  • In fact, the whole network is just one small block being reused over and over; only the input and output sizes differ from stage to stage.
  • The minimal unit without a residual connection
  • The minimal unit with a residual connection (a sketch of the difference follows this list)
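Since the two figures are not reproduced here, the following minimal sketch illustrates the same idea (the 64-channel, 56×56 feature-map size is an assumption chosen to match the later examples): the first unit is just two stacked 3×3 convolutions, and the second adds the input back onto the output, which is the residual connection.

import torch
import torch.nn as nn

# the plain unit: two 3x3 conv -> BN -> ReLU stages, no shortcut
plain_unit = nn.Sequential(
    nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(),
    nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(),
)

x = torch.randn(1, 64, 56, 56)

y_plain = plain_unit(x)         # without a residual connection: y = F(x)
y_residual = plain_unit(x) + x  # with a residual connection: y = F(x) + x

print(y_plain.shape, y_residual.shape)   # both torch.Size([1, 64, 56, 56])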

2. Writing my own residual block

import torch
import torch.nn as nn
from torchsummary import summary   # summary() below is assumed to come from the torchsummary package


class Resblock(nn.Module):
    def __init__(self, down_sample, in_channels, out_channels):
        super(Resblock, self).__init__()

        # whether this block halves the spatial size (and therefore needs a projection shortcut)
        self.down_sample = down_sample

        # main-path conv used when no downsampling is needed: 3x3, stride 1
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.bn_1 = nn.BatchNorm2d(out_channels)
        self.relu_1 = nn.ReLU()

        # main-path conv used when downsampling is needed: 3x3, stride 2
        self.conv_2 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.bn_2 = nn.BatchNorm2d(out_channels)
        self.relu_2 = nn.ReLU()

        # 1x1 projection shortcut with stride 2, so the shortcut matches the main path's output shape
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 2, 0)
        self.shortcut_bn = nn.BatchNorm2d(out_channels)

        # second conv of the block; it takes the previous conv's output, so its input channels are out_channels
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

        self.relu = nn.ReLU()

    def forward(self, x):
        if self.down_sample:
            # down-sampling path: projection shortcut plus a stride-2 conv on the main branch
            shortcut = self.shortcut_bn(self.shortcut(x))
            x = self.relu_2(self.bn_2(self.conv_2(x)))
            x = self.relu2(self.bn2(self.conv2(x)))

            x = x + shortcut

        else:
            # identity shortcut: the input is added back unchanged
            shortcut = x
            x = self.relu_1(self.bn_1(self.conv_1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))

            x = x + shortcut

        x = self.relu(x)
        return x


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# instantiate Resblock (down-sampling variant, 64 -> 64 channels)
resblock = Resblock(True, 64, 64).to(device)

# print the model summary: layers, output shapes, and parameter counts
summary(resblock, (64, 56, 56))
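As a quick check on why the 1×1 projection shortcut exists, here is a minimal sketch using the resblock and device defined above: a stride-2 3×3 convolution and a stride-2 1×1 convolution shrink the feature map to the same spatial size, so the two branches can be added, and the block as a whole halves the spatial resolution.

x = torch.randn(1, 64, 56, 56).to(device)

print(resblock.conv_2(x).shape)     # torch.Size([1, 64, 28, 28]) -- main branch after the stride-2 conv
print(resblock.shortcut(x).shape)   # torch.Size([1, 64, 28, 28]) -- shortcut branch, same shape, so they can be added
print(resblock(x).shape)            # torch.Size([1, 64, 28, 28]) -- the whole block halves H and W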

3. A residual block implementation found online

class ResBlock(nn.Module):
    def __init__(self, down_sample, in_channels, out_channels):
        super(ResBlock, self).__init__()

        self.down_sample = down_sample

        if self.down_sample:
            # downsampling needed: stride-2 conv on the main path and a 1x1 projection shortcut
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)

            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, 2, 0),
                nn.BatchNorm2d(out_channels)
            )

        else:
            # no downsampling needed
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

            # leave the input untouched: an empty nn.Sequential() is an identity mapping, so the shortcut branch does nothing
            self.shortcut = nn.Sequential()

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

        self.relu3 = nn.ReLU()

    def forward(self, x):

        # compute the shortcut branch
        shortcut = self.shortcut(x)

        # conv1
        x = self.conv1(x)
        x = self.relu1(self.bn1(x))

        # conv2
        x = self.conv2(x)
        x = self.relu2(self.bn2(x))

        # residual connection: add the shortcut to the main path
        x = x + shortcut

        # final activation
        x = self.relu3(x)

        return x


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# instantiate ResBlock (down-sampling variant, 64 -> 64 channels)
resblock = ResBlock(True, 64, 64).to(device)

# print the model summary: layers, output shapes, and parameter counts
summary(resblock, (64, 56, 56))

4. Comparison

  • First compare the parameter counts: pushing a random tensor through both blocks and inspecting the model summaries gives identical results. Looking more closely at the two implementations, however, the version found online reuses the shared layers much better, whereas my version duplicates the same modules instead of reusing them, so there is still plenty of room for improvement (a short comparison sketch follows).
    (Figure: model summary output for the two residual blocks)
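To make the comparison concrete, here is a minimal sketch, assuming the Resblock and ResBlock classes above are both defined. The two blocks are interchangeable at the tensor level, but the hand-written Resblock also constructs the conv_1/bn_1 branch it never uses when down_sample=True, so its raw parameter count is larger even though the layers that actually run are equivalent.

x = torch.randn(1, 64, 56, 56)

my_block = Resblock(True, 64, 64)
other_block = ResBlock(True, 64, 64)

# both blocks produce the same output shape
print(my_block(x).shape, other_block(x).shape)   # torch.Size([1, 64, 28, 28]) for both

# raw parameter counts: Resblock also registers the unused conv_1/bn_1 branch
print(sum(p.numel() for p in my_block.parameters()))     # larger: includes the branch that is never used
print(sum(p.numel() for p in other_block.parameters()))  # only the layers built for this configuration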

III. Implementing the full network

class Resnet18(nn.Module):
    """
    搭建一个简单的残差网络: RESNET 18
    输入: 224*224*3
    输出: 1000类
    """
    def __init__(self, num_classes):
        super(Resnet18, self).__init__()
        # Layer 0: 7x7 stem conv, BN, ReLU, then 3x3 max-pool (224 -> 56)
        self.layer_0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Layer 1
        self.layer_1 = nn.Sequential(
            # no downsampling
            ResBlock(False, 64, 64),

            # no downsampling
            ResBlock(False, 64, 64)
        )

        self.layer_2 = nn.Sequential(
            # downsampling: 56 -> 28, channels 64 -> 128
            ResBlock(True, 64, 128),

            # no downsampling
            ResBlock(False, 128, 128)
        )

        self.layer_3 = nn.Sequential(
            # downsampling: 28 -> 14, channels 128 -> 256
            ResBlock(True, 128, 256),

            # no downsampling
            ResBlock(False, 256, 256)
        )

        self.layer_4 = nn.Sequential(
            # downsampling: 14 -> 7, channels 256 -> 512
            ResBlock(True, 256, 512),

            # no downsampling
            ResBlock(False, 512, 512)
        )

        # AdaptiveAvgPool2d: global average pooling down to 1x1
        self.app = nn.AdaptiveAvgPool2d(1)

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):

        x = self.layer_0(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.app(x)
        x = self.flatten(x)
        x = self.linear(x)

        return x

  • Visualize the model parameters:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# instantiate Resnet18 with 10 output classes
resnet = Resnet18(10).to(device)

# print the model summary: layers, output shapes, and parameter counts
summary(resnet, (3, 224, 224))
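As a final sanity check, using the resnet and device defined above, a random 224×224 input produces one logit per class:

x = torch.randn(1, 3, 224, 224).to(device)
print(resnet(x).shape)   # torch.Size([1, 10]) -- one logit per class for num_classes=10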


Reprinted from blog.csdn.net/m0_60890175/article/details/130377937