Pytorch — используйте Resnet, поставляемый с pytorch, в качестве основы сети.

При использовании Pytorch для создания собственной среды нейронной сети вам часто необходимо использовать torchvision.modelsмодель, построенную в Pytorch, в качестве основы для извлечения признаков, а затем строить на этой основе более сложную сеть.

Здесь в качестве примера мы возьмем встроенный Resnet18 в Pytorch. Как использовать его в качестве уровня Backbone? См. следующий пример кода.

# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision

class Resnet18Backbone(nn.Module):
    def __init__(self):
        super(Resnet18Backbone, self).__init__()

        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential()

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)

        return x

Используя приведенный выше код, если размерность входного тензора равна [1,3,244,244], fowwardразмерность выходного тензора равна [1,512,1,1]. Если нам нужно, чтобы размерность выходного тензора была [1,512], squeezeнеобходим соответствующий размер.Измененный код показан ниже.

# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision

class Resnet18Backbone(nn.Module):
    def __init__(self):
        super(Resnet18Backbone, self).__init__()

        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential()

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = x.squeeze(2).squeeze(2)

        return x

Что ж, приведенный выше код Resnet18Backboneможно использовать в качестве слоя в сети. Здесь слой адаптивного среднего пула ResNet используется в качестве выходного слоя магистральной сети. Если нам нужен только предыдущий слой свертки в качестве выходного слоя, вы можете обратиться к следующему коду.

Например, если мы хотим использовать адаптивный средний пул ResNet18 в качестве выходного уровня магистрали, мы можем написать так:

# backbone
        if backbone_name == 'resnet_18':
            resnet_net = torchvision.models.resnet18(pretrained=True)
            modules = list(resnet_net.children())[:-1]
            backbone = nn.Sequential(*modules)
            backbone.out_channels = 512
        elif backbone_name == 'resnet_34':
            resnet_net = torchvision.models.resnet34(pretrained=True)
            modules = list(resnet_net.children())[:-1]
            backbone = nn.Sequential(*modules)
            backbone.out_channels = 512
        elif backbone_name == 'resnet_50':
            resnet_net = torchvision.models.resnet50(pretrained=True)
            modules = list(resnet_net.children())[:-1]
            backbone = nn.Sequential(*modules)
            backbone.out_channels = 2048
        elif backbone_name == 'resnet_101':
            resnet_net = torchvision.models.resnet101(pretrained=True)
            modules = list(resnet_net.children())[:-1]
            backbone = nn.Sequential(*modules)
            backbone.out_channels = 2048
        elif backbone_name == 'resnet_152':
            resnet_net = torchvision.models.resnet152(pretrained=True)
            modules = list(resnet_net.children())[:-1]
            backbone = nn.Sequential(*modules)
            backbone.out_channels = 2048
        elif backbone_name == 'resnet_50_modified_stride_1':
            resnet_net = resnet50(pretrained=True)
            modules = list(resnet_net.children())[:-1]
            backbone = nn.Sequential(*modules)
            backbone.out_channels = 2048
        elif backbone_name == 'resnext101_32x8d':
            resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)
            modules = list(resnet_net.children())[:-1]
            backbone = nn.Sequential(*modules)
            backbone.out_channels = 2048

Если нам нужен только предыдущий сверточный слой в качестве основы, мы можем написать так

# backbone
        if backbone_name == 'resnet_18':
            resnet_net = torchvision.models.resnet18(pretrained=True)
            modules = list(resnet_net.children())[:-2]
            backbone = nn.Sequential(*modules)

        elif backbone_name == 'resnet_34':
            resnet_net = torchvision.models.resnet34(pretrained=True)
            modules = list(resnet_net.children())[:-2]
            backbone = nn.Sequential(*modules)

        elif backbone_name == 'resnet_50':
            resnet_net = torchvision.models.resnet50(pretrained=True)
            modules = list(resnet_net.children())[:-2]
            backbone = nn.Sequential(*modules)

        elif backbone_name == 'resnet_101':
            resnet_net = torchvision.models.resnet101(pretrained=True)
            modules = list(resnet_net.children())[:-2]
            backbone = nn.Sequential(*modules)

        elif backbone_name == 'resnet_152':
            resnet_net = torchvision.models.resnet152(pretrained=True)
            modules = list(resnet_net.children())[:-2]
            backbone = nn.Sequential(*modules)

        elif backbone_name == 'resnet_50_modified_stride_1':
            resnet_net = resnet50(pretrained=True)
            modules = list(resnet_net.children())[:-2]
            backbone = nn.Sequential(*modules)

        elif backbone_name == 'resnext101_32x8d':
            resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)
            modules = list(resnet_net.children())[:-2]
            backbone = nn.Sequential(*modules)

Справочная ссылка

Если вам интересно, вы можете посетить мой личный сайт: https://www.stubbornhuang.com/

Supongo que te gusta

Origin blog.csdn.net/HW140701/article/details/128623871
Recomendado
Clasificación