[Image Classification] [Deep Learning] [Pytorch Version] Detailed Explanation of Inception-ResNet Model Algorithm



Preface

GoogLeNet (Inception-ResNet) was proposed by Christian Szegedy and colleagues at Google in "Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning" [AAAI-2017] [Paper Address]. Inspired by the strong performance of ResNet [Reference] in very deep networks, the paper adds residual connections to the Inception structure, yielding two Inception-ResNet versions of the network. The residual connection replaces the pooling branch of the original Inception block, and the concatenation becomes a summation, which speeds up the training of Inception networks.

Because InceptionV4, Inception-ResNet-v1, and Inception-ResNet-v2 come from the same paper, many readers mistakenly assume that InceptionV4 combines the Inception module with residual learning. In fact, InceptionV4 does not use residual learning at all; it largely continues the design of Inception v2/v3. Only Inception-ResNet-v1 and Inception-ResNet-v2 combine the Inception module with residual learning.


Inception-ResNet explanation

The core idea of Inception-ResNet is to combine the Inception module with the ResNet module so as to exploit the advantages of both. The Inception module captures multi-scale features by running convolution kernels of different sizes in parallel, while residual connections alleviate the vanishing- and exploding-gradient problems in deep networks, helping deep models train better. Inception-ResNet uses Inception modules similar to those of InceptionV4 [Reference] and introduces ResNet's residual connection into them: each module consists of an Inception-style main branch and an identity shortcut branch, whose outputs are summed. This design allows the model to learn better feature representations and to propagate gradients more efficiently during training.
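
As a minimal sketch (a toy two-branch example, not the paper's exact block), the residual Inception pattern looks like this: branch outputs are concatenated, projected back to the input width by a 1×1 convolution, and then added to the shortcut instead of being concatenated with a pooling branch.

import torch
import torch.nn as nn

# Toy Inception-residual block: concat branches -> 1x1 projection -> residual sum
# (assumes an even channel count so the two half-width branches concatenate back)
class ToyInceptionResidual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.branch1 = nn.Conv2d(channels, channels // 2, kernel_size=1)
        self.branch2 = nn.Sequential(
            nn.Conv2d(channels, channels // 2, kernel_size=1),
            nn.Conv2d(channels // 2, channels // 2, kernel_size=3, padding=1)
        )
        # 1x1 convolution restores the input width so shapes match for the sum
        self.project = nn.Conv2d(channels, channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branches = torch.cat([self.branch1(x), self.branch2(x)], dim=1)
        return self.relu(x + self.project(branches))  # summation replaces concatenation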

Inception-ResNet-V1

Inception-ResNet-v1: a structure whose computational cost matches that of InceptionV3 [Reference].

  1. Stem structure: the Stem of Inception-ResNet-V1 is similar to the layers that precede the Inception structure groups in the earlier InceptionV3 network.

    Convolutions not marked with V use "SAME" padding, so the output spatial size matches the input; convolutions marked with V use "VALID" padding, so the output size depends on the kernel size and stride.

  2. Inception-resnet-A structure: a variant of the Inception-A structure in the InceptionV4 network. The purpose of the 1×1 convolution is to keep the feature maps of the main branch and the shortcut branch exactly the same shape.

    In the Inception-resnet structure, the residual connection replaces the pooling layer of the original Inception block, and the residual addition replaces the original concatenation.

  3. Inception-resnet-B structure: a variant of the Inception-B structure in the InceptionV4 network. The purpose of the 1×1 convolution is to keep the feature maps of the main branch and the shortcut branch exactly the same shape.

  4. Inception-resnet-C structure: a variant of the Inception-C structure in the InceptionV4 network. The purpose of the 1×1 convolution is to keep the feature maps of the main branch and the shortcut branch exactly the same shape.

  5. Reduction-A structure: consistent with the Reduction-A structure in the InceptionV4 network; the difference lies in the number of convolution kernels.

    k and l denote numbers of convolution kernels; the k and l of the Reduction-A structure differ between the network variants.

  6. Reduction-B structure: see the redutionB code below.

Inception-ResNet-V2

Inception-ResNet-v2: a structure whose computational cost matches that of InceptionV4, but which trains faster than pure Inception-v4.
The overall framework of Inception-ResNet-v2 is the same as that of Inception-ResNet-v1. Apart from the stem, which Inception-ResNet-v2 shares with InceptionV4, its structures mirror those of Inception-ResNet-v1; the difference is that Inception-ResNet-v2 uses more convolution kernels in each block.

  1. Stem structure: The stem structure of Inception-ResNet-v2 is the same as that of Inception V4.
  2. Inception-resnet-A structure: a variant of the Inception-A structure in the InceptionV4 network. The purpose of the 1×1 convolution is to keep the feature maps of the main branch and the shortcut branch exactly the same shape.
  3. Inception-resnet-B structure: a variant of the Inception-B structure in the InceptionV4 network. The purpose of the 1×1 convolution is to keep the feature maps of the main branch and the shortcut branch exactly the same shape.
  4. Inception-resnet-C structure: a variant of the Inception-C structure in the InceptionV4 network. The purpose of the 1×1 convolution is to keep the feature maps of the main branch and the shortcut branch exactly the same shape.
  5. Reduction-A structure: consistent with the Reduction-A structure in the InceptionV4 network; the difference lies in the number of convolution kernels.

    k and l denote numbers of convolution kernels; the k and l of the Reduction-A structure differ between the network variants.

  6. Reduction-B structure: see the redutionB code below.

Scaling of the Residuals

If the number of convolution kernels in a single layer is too large (more than 1000), the residual network becomes unstable and the network "dies" early in training: after tens of thousands of iterations, the layers before the average pooling layer begin to output only zeros. This cannot be avoided by lowering the learning rate or adding extra BN layers. Scaling down the output of the residual branch before adding it to the shortcut branch stabilizes training.

Typically, the residual scaling factor is set between 0.1 and 0.3. Although scaling is not strictly necessary and does not seem to affect final accuracy, it benefits training stability, as in the sketch below.
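
In code this amounts to one extra multiplication before the element-wise sum. A minimal sketch, where the residual argument stands for the output of an Inception-resnet block's 1×1 projection and the factor 0.2 is an assumed value from the 0.1–0.3 range:

import torch

# Sketch: scale the residual branch output before adding the shortcut branch.
def scaled_residual(shortcut, residual, scale=0.2):  # 0.2 chosen from the 0.1-0.3 range
    return torch.relu(shortcut + scale * residual)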

The overall model structure of Inception-ResNet

The following figure is a detailed schematic diagram of the Inception-ResNet-V1 model structure given in the original paper:

The following figure is a detailed schematic diagram of the Inception-ResNet-V2 model structure given in the original paper:

Readers should note that some of the channel numbers marked for Inception-ResNet-V2 in the original paper are wrong and cannot be made to match when writing the code.

The overall structure of the two versions is the same, but the specific Stem, Inception blocks, and Reduction blocks are slightly different.
For image classification, Inception-ResNet-V1 and Inception-ResNet-V2 are both divided into two parts. The backbone is composed mainly of the Stem module, the Inception-resnet modules, and the pooling (aggregation) layers; the classifier consists of fully connected layers.


GoogLeNet (Inception-ResNet) PyTorch code

Inception-ResNet-V1

Convolutional layer group: Convolutional layer + BN layer + activation function

# Convolution group: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

Stem module: Convolutional layer group + pooling layer

# Stem:BasicConv2d+MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()

        # conv3x3(32 stride2 valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3*3(32 valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3*3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)

        # maxpool3*3(stride2 valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)

        # conv1*1(80)
        self.conv5 = BasicConv2d(64, 80, kernel_size=1)
        # conv3*3(192 valid)
        self.conv6 = BasicConv2d(80, 192, kernel_size=3)

        # conv3*3(256 stride2 valid)
        self.conv7 = BasicConv2d(192, 256, kernel_size=3, stride=2)

    def forward(self, x):
        x = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x = self.conv7(self.conv6(self.conv5(x)))
        return x
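
A quick sanity check of the stem (assuming the BasicConv2d and Stem classes above and the standard 299×299 input): the output should be the 35×35×256 feature map consumed by the Inception-resnet-A stage.

import torch

stem = Stem(3)
out = stem(torch.randn(1, 3, 299, 299))
print(out.shape)  # expected: torch.Size([1, 256, 35, 35])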

Inception_ResNet-A module: Convolutional layer group + residual connection

# Inception_ResNet_A: BasicConv2d + residual connection
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(32)+conv3*3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1*1(32)+conv3*3(32)+conv3*3(32)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1*1(256)
        self.conv = BasicConv2d(ch1x1+ch3x3+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)
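
With the configuration used in the full Inception-ResNet-V1 model below (32-channel branches, a 256-channel 1×1 projection, scale 0.17), the block preserves the input shape, which is exactly what the residual addition requires:

import torch

block = Inception_ResNet_A(256, 32, 32, 32, 32, 32, 32, 256, scale=0.17)
out = block(torch.randn(1, 256, 35, 35))
print(out.shape)  # expected: torch.Size([1, 256, 35, 35])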

Inception_ResNet-B module: Convolutional layer group + residual connection

# Inception_ResNet_B: BasicConv2d + residual connection
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(128)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(128)+conv1*7(128)+conv7*1(128)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1*1(896)
        self.conv = BasicConv2d(ch1x1+ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)
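
The same shape-preserving property holds on the 17×17 grid, using the V1 configuration from the full model below:

import torch

block = Inception_ResNet_B(896, 128, 128, 128, 128, 896, scale=0.10)
out = block(torch.randn(1, 896, 17, 17))
print(out.shape)  # expected: torch.Size([1, 896, 17, 17])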

Inception_ResNet-C module: Convolutional layer group + residual connection

# Inception_ResNet_C: BasicConv2d + residual connection
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext,  scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply the final ReLU activation
        self.activation = activation
        # conv1*1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(192)+conv1*3(192)+conv3*1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1*1(1792)
        self.conv = BasicConv2d(ch1x1+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res
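
The activation flag exists because the last Inception-resnet-C block in the model below omits the final ReLU, leaving a linear residual output before the global pooling:

import torch

last_block = Inception_ResNet_C(1792, 192, 192, 192, 192, 1792, activation=False)
out = last_block(torch.randn(1, 1792, 8, 8))  # no ReLU after the residual sum
print(out.shape)  # expected: torch.Size([1, 1792, 8, 8])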

reductionA module: Convolutional layer group + pooling layer

# redutionA:BasicConv2d+MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3*3(n stride2 valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1*1(k)+conv3*3(l)+conv3*3(m stride2 valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3*3(stride2 valid)
        self.branch3 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2))

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate branch outputs
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)
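
For Inception-ResNet-V1 the paper's values are k=192, l=192, m=256, n=384, so the module reduces the 35×35×256 feature map to 17×17 with 384 + 256 + 256 = 896 channels:

import torch

reduction = redutionA(256, k=192, l=192, m=256, n=384)
out = reduction(torch.randn(1, 256, 35, 35))
print(out.shape)  # expected: torch.Size([1, 896, 17, 17])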

reductionB module: Convolutional layer group + pooling layer

# redutionB:BasicConv2d+MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1*1(256)+conv3x3(384 stride2 valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1*1(256)+conv3x3(256 stride2 valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1*1(256)+conv3x3(256)+conv3x3(256 stride2 valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3*3(stride2 valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)
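
With the V1 configuration from the full model below, the four branches reduce 17×17×896 to 8×8 with 384 + 256 + 256 + 896 = 1792 channels:

import torch

reduction = redutionB(896, 256, 384, 256, 256, 256)
out = reduction(torch.randn(1, 896, 17, 17))
print(out.shape)  # expected: torch.Size([1, 1792, 8, 8])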

Inception-ResNet-V2

Except for the Stem, the modules of Inception-ResNet-V2 are structurally consistent with those of Inception-ResNet-V1; only the channel numbers differ.
Convolutional layer group: Convolutional layer + BN layer + activation function

# Convolution group: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

Stem module: Convolutional layer group + pooling layer

# Stem:BasicConv2d+MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()
        # conv3*3(32 stride2 valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3*3(32 valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3*3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        # maxpool3*3(stride2 valid) & conv3*3(96 stride2 valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv4 = BasicConv2d(64, 96, kernel_size=3, stride=2)

        # conv1*1(64)+conv3*3(96 valid)
        self.conv5_1_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_1_2 = BasicConv2d(64, 96, kernel_size=3)
        # conv1*1(64)+conv7*1(64)+conv1*7(64)+conv3*3(96 valid)
        self.conv5_2_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_2_2 = BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0))
        self.conv5_2_3 = BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3))
        self.conv5_2_4 = BasicConv2d(64, 96, kernel_size=3)

        # conv3*3(192 valid) & maxpool3*3(stride2 valid)
        self.conv6 = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool6 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x1_1 = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x1_2 = self.conv4(self.conv3(self.conv2(self.conv1(x))))
        x1 = torch.cat([x1_1, x1_2], 1)

        x2_1 = self.conv5_1_2(self.conv5_1_1(x1))
        x2_2 = self.conv5_2_4(self.conv5_2_3(self.conv5_2_2(self.conv5_2_1(x1))))
        x2 = torch.cat([x2_1, x2_2], 1)

        x3_1 = self.conv6(x2)
        x3_2 = self.maxpool6(x2)
        x3 = torch.cat([x3_1, x3_2], 1)
        return x3
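
A sanity check of the V2 stem (assuming the classes above and a 299×299 input): the three concatenations yield 160, 192, and finally 384 channels, giving the 35×35×384 feature map for the Inception-resnet-A stage.

import torch

stem = Stem(3)
out = stem(torch.randn(1, 3, 299, 299))
print(out.shape)  # expected: torch.Size([1, 384, 35, 35])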

Inception_ResNet-A module: Convolutional layer group + residual connection

# Inception_ResNet_A: BasicConv2d + residual connection
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(32)+conv3*3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1*1(32)+conv3*3(48)+conv3*3(64)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1*1(384)
        self.conv = BasicConv2d(ch1x1+ch3x3+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)

Inception_ResNet-B module: Convolutional layer group + residual connection

# Inception_ResNet_B: BasicConv2d + residual connection
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(128)+conv1*7(160)+conv7*1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1*1(1154)
        self.conv = BasicConv2d(ch1x1+ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)

Inception_ResNet-C module: Convolutional layer group + residual connection

# Inception_ResNet_C: BasicConv2d + residual connection
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext,  scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply the final ReLU activation
        self.activation = activation
        # conv1*1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(192)+conv1*3(224)+conv3*1(256)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1*1(2048)
        self.conv = BasicConv2d(ch1x1+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res

reductionA module: Convolutional layer group + pooling layer

# redutionA:BasicConv2d+MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3*3(n stride2 valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1*1(k)+conv3*3(l)+conv3*3(m stride2 valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3*3(stride2 valid)
        self.branch3 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2))

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate branch outputs
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)
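
For Inception-ResNet-V2 the paper's Reduction-A values are k=256, l=256, m=384, n=384, reducing 35×35×384 to 17×17 with 384 + 384 + 384 = 1152 channels:

import torch

reduction = redutionA(384, k=256, l=256, m=384, n=384)
out = reduction(torch.randn(1, 384, 35, 35))
print(out.shape)  # expected: torch.Size([1, 1152, 17, 17])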

reductionB module: Convolutional layer group + pooling layer

# redutionB:BasicConv2d+MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1*1(256)+conv3x3(384 stride2 valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1*1(256)+conv3x3(288 stride2 valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1*1(256)+conv3x3(288)+conv3x3(320 stride2 valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3*3(stride2 valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)
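
With the V2 configuration from the full model below, the four branches reduce 17×17×1152 to 8×8 with 384 + 288 + 320 + 1152 = 2144 channels:

import torch

reduction = redutionB(1152, 256, 384, 288, 288, 320)
out = reduction(torch.randn(1, 1152, 17, 17))
print(out.shape)  # expected: torch.Size([1, 2144, 8, 8])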

Complete code

The input image size of Inception-ResNet is 299×299

Inception-ResNet-V1

import torch
import torch.nn as nn
from torchsummary import summary

# Convolution group: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# Stem:BasicConv2d+MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()

        # conv3x3(32 stride2 valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3*3(32 valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3*3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)

        # maxpool3*3(stride2 valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)

        # conv1*1(80)
        self.conv5 = BasicConv2d(64, 80, kernel_size=1)
        # conv3*3(192 valid)
        self.conv6 = BasicConv2d(80, 192, kernel_size=3)

        # conv3*3(256 stride2 valid)
        self.conv7 = BasicConv2d(192, 256, kernel_size=3, stride=2)

    def forward(self, x):
        x = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x = self.conv7(self.conv6(self.conv5(x)))
        return x

# Inception_ResNet_A: BasicConv2d + residual connection
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(32)+conv3*3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1*1(32)+conv3*3(32)+conv3*3(32)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1*1(256)
        self.conv = BasicConv2d(ch1x1+ch3x3+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)

# Inception_ResNet_B: BasicConv2d + residual connection
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(128)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(128)+conv1*7(128)+conv7*1(128)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1*1(896)
        self.conv = BasicConv2d(ch1x1+ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)

# Inception_ResNet_C: BasicConv2d + residual connection
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext,  scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply the final ReLU activation
        self.activation = activation
        # conv1*1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(192)+conv1*3(192)+conv3*1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1*1(1792)
        self.conv = BasicConv2d(ch1x1+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res

# redutionA:BasicConv2d+MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3*3(n stride2 valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1*1(k)+conv3*3(l)+conv3*3(m stride2 valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3*3(stride2 valid)
        self.branch3 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2))

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate branch outputs
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)

# redutionB:BasicConv2d+MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1*1(256)+conv3x3(384 stride2 valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1*1(256)+conv3x3(256 stride2 valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1*1(256)+conv3x3(256)+conv3x3(256 stride2 valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3*3(stride2 valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)

class Inception_ResNetv1(nn.Module):
    def __init__(self, num_classes = 1000, k=192, l=192, m=256, n=384):
        super(Inception_ResNetv1, self).__init__()
        blocks = []
        blocks.append(Stem(3))
        for i in range(5):
            blocks.append(Inception_ResNet_A(256,32, 32, 32, 32, 32, 32, 256, 0.17))
        blocks.append(redutionA(256, k, l, m, n))
        for i in range(10):
            blocks.append(Inception_ResNet_B(896, 128, 128, 128, 128, 896, 0.10))
        blocks.append(redutionB(896,256, 384, 256, 256, 256))
        for i in range(4):
            blocks.append(Inception_ResNet_C(1792,192, 192, 192, 192, 1792, 0.20))
        blocks.append(Inception_ResNet_C(1792, 192, 192, 192, 192, 1792, activation=False))
        self.features = nn.Sequential(*blocks)
        self.conv = BasicConv2d(1792, 1536, 1)
        self.global_average_pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)  # the paper keeps activations with probability 0.8, i.e. drop rate 0.2
        self.linear = nn.Linear(1536, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.global_average_pooling(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.linear(x)
        return x

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Inception_ResNetv1().to(device)
    summary(model, input_size=(3, 299, 299))

summary prints the network structure and parameter counts, making it easy to inspect the built network.

Inception-ResNet-V2

import torch
import torch.nn as nn
from torchsummary import summary

# Convolution group: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# Stem:BasicConv2d+MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()
        # conv3*3(32 stride2 valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3*3(32 valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3*3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        # maxpool3*3(stride2 valid) & conv3*3(96 stride2 valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv4 = BasicConv2d(64, 96, kernel_size=3, stride=2)

        # conv1*1(64)+conv3*3(96 valid)
        self.conv5_1_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_1_2 = BasicConv2d(64, 96, kernel_size=3)
        # conv1*1(64)+conv7*1(64)+conv1*7(64)+conv3*3(96 valid)
        self.conv5_2_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_2_2 = BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0))
        self.conv5_2_3 = BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3))
        self.conv5_2_4 = BasicConv2d(64, 96, kernel_size=3)

        # conv3*3(192 valid) & maxpool3*3(stride2 valid)
        self.conv6 = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool6 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x1_1 = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x1_2 = self.conv4(self.conv3(self.conv2(self.conv1(x))))
        x1 = torch.cat([x1_1, x1_2], 1)

        x2_1 = self.conv5_1_2(self.conv5_1_1(x1))
        x2_2 = self.conv5_2_4(self.conv5_2_3(self.conv5_2_2(self.conv5_2_1(x1))))
        x2 = torch.cat([x2_1, x2_2], 1)

        x3_1 = self.conv6(x2)
        x3_2 = self.maxpool6(x2)
        x3 = torch.cat([x3_1, x3_2], 1)
        return x3

# Inception_ResNet_A: BasicConv2d + residual connection
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(32)+conv3*3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1*1(32)+conv3*3(48)+conv3*3(64)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1*1(384)
        self.conv = BasicConv2d(ch1x1+ch3x3+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)

# Inception_ResNet_B: BasicConv2d + residual connection
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1*1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(128)+conv1*7(160)+conv7*1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1*1(1154)
        self.conv = BasicConv2d(ch1x1+ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        return self.relu(x + self.scale * x_res)

# Inception_ResNet_C: BasicConv2d + residual connection
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext,  scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply the final ReLU activation
        self.activation = activation
        # conv1*1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1*1(192)+conv1*3(224)+conv3*1(256)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1*1(2048)
        self.conv = BasicConv2d(ch1x1+ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate branch outputs
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res

# redutionA:BasicConv2d+MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3*3(n stride2 valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1*1(k)+conv3*3(l)+conv3*3(m stride2 valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3*3(stride2 valid)
        self.branch3 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2))

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate branch outputs
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)

# redutionB:BasicConv2d+MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1*1(256)+conv3x3(384 stride2 valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1*1(256)+conv3x3(288 stride2 valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1*1(256)+conv3x3(288)+conv3x3(320 stride2 valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3*3(stride2 valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)

class Inception_ResNetv2(nn.Module):
    def __init__(self, num_classes = 1000, k=256, l=256, m=384, n=384):
        super(Inception_ResNetv2, self).__init__()
        blocks = []
        blocks.append(Stem(3))
        for i in range(5):
            blocks.append(Inception_ResNet_A(384,32, 32, 32, 32, 48, 64, 384, 0.17))
        blocks.append(redutionA(384, k, l, m, n))
        for i in range(10):
            blocks.append(Inception_ResNet_B(1152, 192, 128, 160, 192, 1152, 0.10))
        blocks.append(redutionB(1152, 256, 384, 288, 288, 320))
        for i in range(4):
            blocks.append(Inception_ResNet_C(2144,192, 192, 224, 256, 2144, 0.20))
        blocks.append(Inception_ResNet_C(2144, 192, 192, 224, 256, 2144, activation=False))
        self.features = nn.Sequential(*blocks)
        self.conv = BasicConv2d(2144, 1536, 1)
        self.global_average_pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)  # the paper keeps activations with probability 0.8, i.e. drop rate 0.2
        self.linear = nn.Linear(1536, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.global_average_pooling(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.linear(x)
        return x

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Inception_ResNetv2().to(device)
    summary(model, input_size=(3, 299, 299))

summary prints the network structure and parameter counts, making it easy to inspect the built network.


Summary

This article introduces, as simply and thoroughly as possible, how Inception-ResNet combines Inception and ResNet and how the combined blocks work, and explains the structure of the Inception-ResNet model together with its PyTorch code.

Origin: blog.csdn.net/yangyu0515/article/details/134513904