PyTorch - Use the ResNet that ships with torchvision as the backbone of a network

When building your own neural network in PyTorch, you often want to use one of the models built into torchvision.models as the backbone for feature extraction, and then build a more complex network on top of it.

Taking the built-in ResNet18 as an example, the following sample code shows how to use it as a backbone layer:

# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision

class Resnet18Backbone(nn.Module):
    def __init__(self):
        super(Resnet18Backbone, self).__init__()

        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential()

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)

        return x

With the above code, if the input Tensor has dimensions [1,3,244,244], the output Tensor of forward has dimensions [1,512,1,1]. If we need the output Tensor to have dimensions [1,512], the two trailing dimensions of size 1 must be squeezed out. The modified code is shown below:

# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision

class Resnet18Backbone(nn.Module):
    def __init__(self):
        super(Resnet18Backbone, self).__init__()

        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential()

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = x.squeeze(2).squeeze(2)

        return x

With the above code, Resnet18Backbone can be used as a layer in a network. Here, the Adaptive Average Pooling layer of ResNet is used as the output layer of the backbone. If we only need the last convolutional stage before it as the output layer, refer to the code below.

For example, if we want to use ResNet's Adaptive Average Pooling layer as the output layer of the backbone, we can write:

# backbone
# Fragment: assumes backbone_name is defined and that torch.nn (as nn)
# and torchvision are imported as in the code above.
# [:-1] drops only the final fc layer, so avgpool is the last module.
if backbone_name == 'resnet_18':
    resnet_net = torchvision.models.resnet18(pretrained=True)
    modules = list(resnet_net.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.out_channels = 512
elif backbone_name == 'resnet_34':
    resnet_net = torchvision.models.resnet34(pretrained=True)
    modules = list(resnet_net.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.out_channels = 512
elif backbone_name == 'resnet_50':
    resnet_net = torchvision.models.resnet50(pretrained=True)
    modules = list(resnet_net.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.out_channels = 2048
elif backbone_name == 'resnet_101':
    resnet_net = torchvision.models.resnet101(pretrained=True)
    modules = list(resnet_net.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.out_channels = 2048
elif backbone_name == 'resnet_152':
    resnet_net = torchvision.models.resnet152(pretrained=True)
    modules = list(resnet_net.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.out_channels = 2048
elif backbone_name == 'resnet_50_modified_stride_1':
    # resnet50 here is presumably a custom variant defined elsewhere (not shown)
    resnet_net = resnet50(pretrained=True)
    modules = list(resnet_net.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.out_channels = 2048
elif backbone_name == 'resnext101_32x8d':
    resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)
    modules = list(resnet_net.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.out_channels = 2048

If we only need the last convolutional stage (layer4) as the output of the backbone, dropping both the average pooling and fc layers, we can write:

# backbone
# Fragment: assumes backbone_name is defined and that torch.nn (as nn)
# and torchvision are imported as in the code above.
# [:-2] drops both avgpool and fc, so layer4 is the last module.
if backbone_name == 'resnet_18':
    resnet_net = torchvision.models.resnet18(pretrained=True)
    modules = list(resnet_net.children())[:-2]
    backbone = nn.Sequential(*modules)

elif backbone_name == 'resnet_34':
    resnet_net = torchvision.models.resnet34(pretrained=True)
    modules = list(resnet_net.children())[:-2]
    backbone = nn.Sequential(*modules)

elif backbone_name == 'resnet_50':
    resnet_net = torchvision.models.resnet50(pretrained=True)
    modules = list(resnet_net.children())[:-2]
    backbone = nn.Sequential(*modules)

elif backbone_name == 'resnet_101':
    resnet_net = torchvision.models.resnet101(pretrained=True)
    modules = list(resnet_net.children())[:-2]
    backbone = nn.Sequential(*modules)

elif backbone_name == 'resnet_152':
    resnet_net = torchvision.models.resnet152(pretrained=True)
    modules = list(resnet_net.children())[:-2]
    backbone = nn.Sequential(*modules)

elif backbone_name == 'resnet_50_modified_stride_1':
    # resnet50 here is presumably a custom variant defined elsewhere (not shown)
    resnet_net = resnet50(pretrained=True)
    modules = list(resnet_net.children())[:-2]
    backbone = nn.Sequential(*modules)

elif backbone_name == 'resnext101_32x8d':
    resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)
    modules = list(resnet_net.children())[:-2]
    backbone = nn.Sequential(*modules)
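Without the average pooling layer, the backbone output keeps a spatial grid whose size depends on the input size. The sketch below estimates that size with the standard convolution output formula. The kernel, stride and padding values are what I understand the stock torchvision ResNet to use (7/2/3 for conv1, 3/2/1 for maxpool and for the downsampling step in layer2 to layer4); they are assumptions to check against your installed version, and they would not apply to the modified-stride variant. The function names are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_feature_size(size):
    """Estimate the side length of the layer4 feature map for a square input."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 keeps the resolution; layer2, layer3 and layer4 each halve it.
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


print(resnet_feature_size(224))  # 7
print(resnet_feature_size(244))  # 8
```

Under these assumptions, a [1,3,244,244] input would give a [1,512,8,8] output for the resnet_18 branch and a [1,2048,8,8] output for the resnet_50 branch.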

If you are interested, you can visit my personal website: https://www.stubbornhuang.com/

Origin blog.csdn.net/HW140701/article/details/128623871