Reproducing the ResNet18 network structure in code

  • More learning projects: https://github.com/xiaoaleiBLUE


Foreword

This article reproduces the ResNet18 network structure in code (PyTorch framework), sketches the network structure, and compares the result with a ResNet18 implementation written by someone else. The comparison shows that the readability of my own code still needs improvement.


1. Resnet18 network structure

(figure: ResNet18 network structure diagram)

2. Specific analysis of the structure

1. Architecture Analysis


  • Looking at the diagrams, the same basic module is reused throughout the network, but its input and output sizes differ from stage to stage (see the shape sketch after this list).
  • A minimal unit without a residual connection
  • A minimal unit with a residual connection
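
Before wrapping this up as a class, a quick shape sketch makes the downsampling case concrete. The numbers below are assumptions for illustration only (a 56×56 feature map, 64 input channels, 128 output channels, as in layer_2 of ResNet18): a 3×3 stride-2 convolution on the main branch and a 1×1 stride-2 convolution on the shortcut produce the same output shape, so the two branches can be added.

import torch
import torch.nn as nn

# assumed example input: batch of 1, 64 channels, 56x56 feature map
x = torch.randn(1, 64, 56, 56)

# main branch: 3x3 conv, stride 2 -> halves the spatial size, changes channels
main = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

# shortcut branch: 1x1 conv, stride 2 -> matches the main branch's output shape
shortcut = nn.Conv2d(64, 128, kernel_size=1, stride=2, padding=0)

print(main(x).shape)      # torch.Size([1, 128, 28, 28])
print(shortcut(x).shape)  # torch.Size([1, 128, 28, 28])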

2. Writing my own residual unit

import torch
import torch.nn as nn
from torchsummary import summary


class Resblock(nn.Module):
    def __init__(self, down_sample, in_channels, out_channels):
        super(Resblock, self).__init__()

        # whether this block halves the spatial size (and changes channels)
        self.down_sample = down_sample

        # 3x3 conv, stride 1: used when no downsampling is needed
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.bn_1 = nn.BatchNorm2d(out_channels)
        self.relu_1 = nn.ReLU()

        # 3x3 conv, stride 2: used when downsampling is needed
        self.conv_2 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.bn_2 = nn.BatchNorm2d(out_channels)
        self.relu_2 = nn.ReLU()

        # 1x1 conv, stride 2: projects the shortcut to the downsampled shape
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 2, 0)
        self.shortcut_bn = nn.BatchNorm2d(out_channels)

        # second 3x3 conv of the block: its input already has out_channels channels
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

        # activation applied after the residual addition
        self.relu = nn.ReLU()

    def forward(self, x):
        if self.down_sample:
            # project the shortcut so its shape matches the downsampled main branch
            shortcut = self.shortcut_bn(self.shortcut(x))
            x = self.relu_2(self.bn_2(self.conv_2(x)))
            x = self.relu2(self.bn2(self.conv2(x)))

            # residual connection
            x = x + shortcut

        else:
            # identity shortcut: input and output shapes already match
            shortcut = x
            x = self.relu_1(self.bn_1(self.conv_1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))

            # residual connection
            x = x + shortcut

        x = self.relu(x)
        return x


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# instantiate a Resblock (down_sample=True, 64 -> 64 channels)
resblock = Resblock(True, 64, 64).to(device)

# print the layer summary and parameter counts
summary(resblock, (64, 56, 56))
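
As a quick sanity check (a sketch reusing the Resblock instance created above), a random tensor can be passed through the block to confirm that the downsampling version halves the spatial size:

# assumed example input matching the summary() call above: 64 channels, 56x56
x = torch.randn(1, 64, 56, 56).to(device)

# down_sample=True, 64 -> 64 channels: the spatial size should drop to 28x28
out = resblock(x)
print(out.shape)   # expected: torch.Size([1, 64, 28, 28])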

3. A residual unit written by someone else (found online)

class ResBlock(nn.Module):
    def __init__(self, down_sample, in_channels, out_channels):
        super(ResBlock, self).__init__()

        self.down_sample = down_sample

        if self.down_sample:
            # downsampling needed: the first conv uses stride 2
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)

            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, 2, 0),
                nn.BatchNorm2d(out_channels)
            )

        else:
            # no downsampling needed: the first conv uses stride 1
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

            # leave the input untouched: an empty nn.Sequential() acts as an
            # identity, i.e. the shortcut branch does nothing to the input
            self.shortcut = nn.Sequential()

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

        self.relu3 = nn.ReLU()

    def forward(self, x):

        # run the input through the shortcut branch (identity or 1x1 projection)
        shortcut = self.shortcut(x)

        # conv1
        x = self.conv1(x)
        x = self.relu1(self.bn1(x))

        # conv2
        x = self.conv2(x)
        x = self.relu2(self.bn2(x))

        # residual connection
        x = x + shortcut

        # activation after the addition
        x = self.relu3(x)

        return x


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# instantiate a ResBlock (down_sample=True, 64 -> 64 channels)
resblock = ResBlock(True, 64, 64).to(device)

# print the layer summary and parameter counts
summary(resblock, (64, 56, 56))
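
To see what the down_sample flag changes in this version, the shortcut branch can be inspected directly. This is just a sketch reusing the ResBlock class and the resblock instance defined above; the plain_block name is mine:

# with down_sample=False the shortcut is an empty nn.Sequential, i.e. an identity
plain_block = ResBlock(False, 64, 64).to(device)
print(plain_block.shortcut)     # Sequential()

x = torch.randn(1, 64, 56, 56).to(device)
print(plain_block(x).shape)     # torch.Size([1, 64, 56, 56]) -- shape preserved
print(resblock(x).shape)        # torch.Size([1, 64, 28, 28])  -- downsampled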

4. Comparison

  • First, look at the number of model parameters: after feeding in a random tensor and printing the model structure, the two implementations report the same totals. Comparing the two versions of the residual-unit code more carefully, however, the other implementation reuses its modules well, while my own version duplicates identical modules instead of reusing them, so there is still plenty of room for improvement (the parameter counts can also be checked directly in code, as sketched below).
    (figure: torchsummary output for the two residual blocks)
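
The parameter totals can also be compared in code rather than read off the two summary tables. This is a sketch assuming both classes above are defined, and the variable names are mine; note that iterating over parameters() counts every registered layer, including conv_1/bn_1, which the first version creates but never calls when down_sample=True, so the raw totals can differ from the summary output, which only traces layers that actually run in forward:

mine = Resblock(True, 64, 64)
theirs = ResBlock(True, 64, 64)

# count all registered trainable parameters of each block
n_mine = sum(p.numel() for p in mine.parameters() if p.requires_grad)
n_theirs = sum(p.numel() for p in theirs.parameters() if p.requires_grad)

print(n_mine, n_theirs)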

3. Implementation of the entire network

class Resnet18(nn.Module):
    """
    搭建一个简单的残差网络: RESNET 18
    输入: 224*224*3
    输出: 1000类
    """
    def __init__(self, num_classes):
        super(Resnet18, self).__init__()
        # Layer 0
        self.layer_0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Layer 1
        self.layer_1 = nn.Sequential(
            # no downsampling needed
            ResBlock(False, 64, 64),

            # no downsampling needed
            ResBlock(False, 64, 64)
        )

        self.layer_2 = nn.Sequential(
            # downsample
            ResBlock(True, 64, 128),

            # no downsampling needed
            ResBlock(False, 128, 128)
        )

        self.layer_3 = nn.Sequential(
            # downsample
            ResBlock(True, 128, 256),

            # no downsampling needed
            ResBlock(False, 256, 256)
        )

        self.layer_4 = nn.Sequential(
            # downsample
            ResBlock(True, 256, 512),

            # no downsampling needed
            ResBlock(False, 512, 512)
        )

        # AdaptiveAvgPool2d
        self.app = nn.AdaptiveAvgPool2d(1)

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):

        x = self.layer_0(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.app(x)
        x = self.flatten(x)
        x = self.linear(x)

        return x
  • Visualize model parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# instantiate Resnet18 with 10 output classes
resnet = Resnet18(10).to(device)

# print the layer summary and parameter counts
summary(resnet, (3, 224, 224))
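
Finally, a forward-pass sanity check (a sketch using the Resnet18 instance above) confirms that a random 224×224 image produces one logit per class:

# assumed example input: a batch of 2 RGB images, 224x224
x = torch.randn(2, 3, 224, 224).to(device)

logits = resnet(x)
print(logits.shape)   # expected: torch.Size([2, 10]) for num_classes=10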

Origin: blog.csdn.net/m0_60890175/article/details/130377937