Visualización del proceso de extracción de características de la red Resnet

Al entrenar una red con imágenes, a menudo queremos ver qué mapas de características extrae cada capa durante el proceso. Después de mucho buscar, el esfuerzo valió la pena: encontré un código que lo hace y lo modifiqué:

Código de ResNet:

import math

import torch
import torch.nn as nn

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # Newer torchvision releases no longer ship torchvision.models.utils;
    # torch.hub exposes the same function.
    from torch.hub import load_state_dict_from_url
# Checkpoint download URLs keyed by ResNet variant name.  The resnetXX()
# factory functions in this file look their entry up when pretrained=True.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

class GhostModule(nn.Module):
    """Convolution block that derives part of its output from a cheap grouped conv.

    A "primary" convolution produces ``ceil(oup / ratio)`` channels.  A grouped
    convolution (one group per primary channel) then derives ``ratio - 1``
    additional maps from each of them.  Both results are concatenated along
    the channel axis and trimmed to exactly ``oup`` channels.

    Args:
        inp: number of input channels.
        oup: number of channels returned by ``forward``.
        kernel_size: kernel size of the primary convolution; its padding is
            ``kernel_size // 2``.
        ratio: divisor used to size the primary convolution.
        dw_size: kernel size of the grouped convolution; its padding is
            ``dw_size // 2``.
        stride: stride of the primary convolution.
        relu: when False, each ReLU is replaced by an empty ``nn.Sequential``.
    """

    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super().__init__()
        self.oup = oup
        primary_channels = math.ceil(oup / ratio)
        ghost_channels = primary_channels * (ratio - 1)

        def activation():
            # An empty Sequential passes its input through unchanged.
            return nn.ReLU(inplace=True) if relu else nn.Sequential()

        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, primary_channels, kernel_size, stride=stride,
                      padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(primary_channels),
            activation(),
        )

        self.cheap_operation = nn.Sequential(
            nn.Conv2d(primary_channels, ghost_channels, dw_size, stride=1,
                      padding=dw_size // 2, groups=primary_channels, bias=False),
            nn.BatchNorm2d(ghost_channels),
            activation(),
        )

    def forward(self, x):
        """Return the first ``oup`` channels of [primary maps, cheap maps]."""
        primary = self.primary_conv(x)
        ghost = self.cheap_operation(primary)
        stacked = torch.cat((primary, ghost), dim=1)
        return stacked[:, :self.oup]


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before it is added
            to the main branch; ``None`` adds the input unchanged.
        groups: must be 1.
        base_width: must be 64.
        dilation: must not exceed 1.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # When stride != 1, conv1 and the downsample module both shrink the input.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is evaluated after the main branch, as in the original order.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual block: 1x1 reduce, 3x3, then 1x1 expand to ``planes * expansion``.

    Args:
        inplanes: channels of the block input.
        planes: base channel count; the block outputs ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input before it is added
            to the main branch; ``None`` adds the input unchanged.
        groups: groups of the 3x3 convolution.
        base_width: scales the inner width as ``planes * base_width / 64``.
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        out_planes = planes * self.expansion
        width = int(planes * (base_width / 64.)) * groups

        # When stride != 1, conv2 and the downsample module both shrink the input.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)

        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)

        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is evaluated after the main branch, as in the original order.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet classifier assembled from ``BasicBlock`` or ``Bottleneck`` stages.

    Layout: 7x7 stem convolution -> max-pool -> four residual stages ->
    global average pool -> fully connected classifier.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); must
            expose an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        num_classes: size of the classifier output.
        zero_init_residual: when True, the last normalisation layer of every
            residual branch starts with weight 0, so each block initially
            acts as an identity mapping.
        groups: forwarded to every block.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: three booleans, one for each of stages
            2-4; True replaces that stage's stride-2 with dilation.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 items.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count of the tensor that enters the next block.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]

        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.block = block
        self.groups = groups
        self.base_width = width_per_group

        # The shape notes below assume a 224x224 input and Bottleneck blocks
        # (expansion = 4); BasicBlock networks keep 64/128/256/512 channels.

        # 224,224,3 -> 112,112,64
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # 112,112,64 -> 56,56,64
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 56,56,64 -> 56,56,256
        self.layer1 = self._make_layer(block, 64, layers[0])

        # 56,56,256 -> 28,28,512
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])

        # 28,28,512 -> 14,14,1024
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])

        # 14,14,1024 -> 7,7,2048
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        # 7,7,2048 -> 2048
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # 2048 -> num_classes
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            # Zero the last norm layer of each residual branch.  Both block
            # types are handled (bn3 for Bottleneck, bn2 for BasicBlock), and
            # norm layers without an affine weight are skipped.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    last_norm = m.bn3
                elif isinstance(m, BasicBlock):
                    last_norm = m.bn2
                else:
                    continue
                if getattr(last_norm, 'weight', None) is not None:
                    nn.init.constant_(last_norm.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block may change resolution/channels and gets a 1x1
        projection on its shortcut when needed; the remaining blocks keep
        the shape.  Updates ``self.inplanes`` and, if ``dilate`` is set,
        ``self.dilation``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation so the resolution is preserved.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the identity matches the block output shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Conv_block: first block of the stage, carries stride and projection.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            # identity_block: shape-preserving blocks.
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run stem, the four stages, pooling and the classifier on ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def freeze_backbone(self):
        """Set ``requires_grad = False`` on the stem and all four stages (not ``fc``)."""
        backbone = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
        for module in backbone:
            for param in module.parameters():
                param.requires_grad = False

    def Unfreeze_backbone(self):
        """Set ``requires_grad = True`` on the stem and all four stages."""
        backbone = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
        for module in backbone:
            for param in module.parameters():
                param.requires_grad = True

    # snake_case alias matching freeze_backbone; the original name is kept
    # so existing callers continue to work.
    unfreeze_backbone = Unfreeze_backbone

def resnet18(pretrained=False, progress=True, num_classes=1000):
    """Build a ``BasicBlock`` ResNet with stage depths [2, 2, 2, 2].

    Args:
        pretrained: load the weights at ``model_urls['resnet18']`` (cached
            under ``./model_data``).
        progress: forwarded to ``load_state_dict_from_url``.
        num_classes: when not 1000, ``fc`` is replaced by a freshly
            initialised linear layer of this size (after any weight loading).
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        weights = load_state_dict_from_url(model_urls['resnet18'], model_dir='./model_data',
                                           progress=progress)
        net.load_state_dict(weights)

    if num_classes == 1000:
        return net
    net.fc = nn.Linear(512 * net.block.expansion, num_classes)
    return net

def resnet34(pretrained=False, progress=True, num_classes=1000):
    """Build a ``BasicBlock`` ResNet with stage depths [3, 4, 6, 3].

    Args:
        pretrained: load the weights at ``model_urls['resnet34']`` (cached
            under ``./model_data``).
        progress: forwarded to ``load_state_dict_from_url``.
        num_classes: when not 1000, ``fc`` is replaced by a freshly
            initialised linear layer of this size (after any weight loading).
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        weights = load_state_dict_from_url(model_urls['resnet34'], model_dir='./model_data',
                                           progress=progress)
        net.load_state_dict(weights)

    if num_classes == 1000:
        return net
    net.fc = nn.Linear(512 * net.block.expansion, num_classes)
    return net

def resnet50(pretrained=False, progress=True, num_classes=1000):
    """Build a ``Bottleneck`` ResNet with stage depths [3, 4, 6, 3].

    Args:
        pretrained: load the weights at ``model_urls['resnet50']`` (cached
            under ``./model_data``).
        progress: forwarded to ``load_state_dict_from_url``.
        num_classes: when not 1000, ``fc`` is replaced by a freshly
            initialised linear layer of this size (after any weight loading).
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        weights = load_state_dict_from_url(model_urls['resnet50'], model_dir='./model_data',
                                           progress=progress)
        net.load_state_dict(weights)

    if num_classes == 1000:
        return net
    net.fc = nn.Linear(512 * net.block.expansion, num_classes)
    return net

def resnet101(pretrained=False, progress=True, num_classes=1000):
    """Build a ``Bottleneck`` ResNet with stage depths [3, 4, 23, 3].

    Args:
        pretrained: load the weights at ``model_urls['resnet101']`` (cached
            under ``./model_data``).
        progress: forwarded to ``load_state_dict_from_url``.
        num_classes: when not 1000, ``fc`` is replaced by a freshly
            initialised linear layer of this size (after any weight loading).
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        weights = load_state_dict_from_url(model_urls['resnet101'], model_dir='./model_data',
                                           progress=progress)
        net.load_state_dict(weights)

    if num_classes == 1000:
        return net
    net.fc = nn.Linear(512 * net.block.expansion, num_classes)
    return net

def resnet152(pretrained=False, progress=True, num_classes=1000):
    """Build a ``Bottleneck`` ResNet with stage depths [3, 8, 36, 3].

    Args:
        pretrained: load the weights at ``model_urls['resnet152']`` (cached
            under ``./model_data``).
        progress: forwarded to ``load_state_dict_from_url``.
        num_classes: when not 1000, ``fc`` is replaced by a freshly
            initialised linear layer of this size (after any weight loading).
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        weights = load_state_dict_from_url(model_urls['resnet152'], model_dir='./model_data',
                                           progress=progress)
        net.load_state_dict(weights)

    if num_classes == 1000:
        return net
    net.fc = nn.Linear(512 * net.block.expansion, num_classes)
    return net

Con el siguiente script se puede ver la estructura de la red:

#--------------------------------------------#
#   This code is only for inspecting the network structure; it is not test code
#--------------------------------------------#
import torch
from thop import clever_format, profile
from torchsummary import summary

from nets import get_model_from_name
# from nets import resnet_cbam # 使用哪个引入哪个即可
if __name__ == "__main__":
    # Spatial size of the dummy input as [height, width].
    input_shape = [224, 224]
    height, width = input_shape

    # Set this to your own number of classes.  When training image
    # segmentation, add one extra class for the background
    # (e.g. cats and dogs -> num_classes = 2 + 1).
    num_classes = 3
    # Alternative left by the author: backbone = "mobilenetv2"
    backbone = "resnet50"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    build_model = get_model_from_name[backbone]
    model = build_model(num_classes=num_classes, pretrained=False).to(device)

    # Per-layer table of output shapes and parameter counts.
    summary(model, (3, height, width))

    dummy_input = torch.randn(1, 3, height, width).to(device)
    flops, params = profile(model.to(device), (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops is doubled because profile does not count a
    #   convolution as two operations.  Some papers count the
    #   multiply and the add separately (multiply by 2); others
    #   count only multiplications (do not multiply by 2).
    #   This script multiplies by 2, following YOLOX.
    #--------------------------------------------------------#
    flops = flops * 2
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % flops)
    print('Total params: %s' % params)

La estructura es la siguiente: (resnet50) 

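ResNet-50 combina dos tipos de módulos residuales: el Conv Block, que se usa para cambiar el número de canales, y el Identity Block, que se usa para hacer la red más profunda.

En cada etapa, el primer bloque residual (Conv Block) contiene 4 convoluciones, contando la de la rama de atajo, y cada uno de los siguientes (Identity Block) contiene 3:

(4+3+3) + (4+3+3+3) + (4+3+3+3+3+3) + (4+3+3) = 52

Sumando la convolución inicial de 7×7, son 53 convoluciones en total (índices 0 a 52 en la tabla siguiente).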

----------------------------------------------------------------
Índice       Capa (tipo)       Forma de salida       N.º de parámetros
================================================================
 0           Conv2d-1 [-1, 64, 112, 112] 9,408
        BatchNorm2d-2 [-1, 64, 112, 112] 128
               ReLU-3 [-1, 64, 112, 112] 0
          MaxPool2d-4 [-1, 64, 56, 56] 0
 1           Conv2d-5 [-1, 64, 56, 56] 4,096
        BatchNorm2d-6 [-1, 64, 56, 56] 128
               ReLU-7 [-1, 64, 56, 56] 0
 2           Conv2d-8 [-1, 64, 56, 56] 36,864
       BatchNorm2d-9 [-1, 64, 56, 56] 128
             ReLU-10 [-1, 64, 56, 56] 0
3            Conv2d-11 [-1, 256, 56, 56] 16 384
      BatchNorm2d-12 [-1, 256, 56, 56] 512
 4          Conv2d-13 [-1, 256, 56, 56] 16 384
      BatchNorm2d-14 [-1, 256, 56, 56] 512
             ReLU-15 [-1, 256, 56, 56] 0
       Bottleneck-16 [-1, 256, 56, 56] 0
 5          Conv2d-17 [-1, 64, 56, 56] 16,384
      BatchNorm2d-18 [-1, 64, 56, 56] 128
             ReLU-19 [-1, 64, 56, 56] 0
6           Conv2d-20 [-1, 64, 56, 56] 36 864
      BatchNorm2d-21 [-1, 64, 56, 56] 128
             ReLU-22 [-1, 64, 56, 56] 0
7          Conv2d-23 [-1, 256, 56, 56] 16 384
      BatchNorm2d-24 [-1, 256, 56, 56] 512
             ReLU-25 [-1, 256, 56, 56] 0
       Bottleneck-26 [-1, 256, 56, 56] 0
8          Conv2d-27 [-1, 64, 56, 56] 16 384
      BatchNorm2d-28 [-1, 64, 56, 56] 128
             ReLU-29 [-1, 64, 56, 56] 0
9          Conv2d-30 [-1, 64, 56, 56] 36 864
      BatchNorm2d-31 [-1, 64, 56, 56] 128
             ReLU-32 [-1, 64, 56, 56] 0
10           Conv2d-33 [-1, 256, 56, 56] 16 384
      BatchNorm2d-34 [-1, 256, 56, 56] 512
             ReLU-35 [-1, 256, 56, 56] 0
       Bottleneck-36 [-1, 256, 56, 56] 0
11          Conv2d-37 [-1, 128, 56, 56] 32,768
      BatchNorm2d-38 [-1, 128, 56, 56] 256
             ReLU-39 [-1, 128, 56, 56] 0
 12          Conv2d-40 [-1, 128, 28, 28] 147 456
      BatchNorm2d-41 [-1, 128, 28, 28] 256
             ReLU-42 [-1, 128, 28, 28] 0
13          Conv2d-43 [-1, 512, 28, 28] 65 536
      BatchNorm2d-44 [-1, 512, 28, 28] 1024
 14          Conv2d-45 [-1, 512, 28, 28] 131 072
      BatchNorm2d-46 [-1, 512, 28, 28] 1024
             ReLU-47 [-1, 512, 28, 28] 0
       Bottleneck-48 [-1, 512, 28, 28] 0
15           Conv2d-49 [-1, 128, 28, 28] 65,536
      BatchNorm2d-50 [-1, 128, 28, 28] 256
             ReLU-51 [-1, 128, 28, 28] 0
16           Conv2d-52 [-1, 128, 28, 28] 147 456
      BatchNorm2d-53 [-1, 128, 28, 28] 256
             ReLU-54 [-1, 128, 28, 28] 0
17           Conv2d-55 [-1, 512, 28, 28] 65 536
      BatchNorm2d-56 [-1, 512, 28, 28] 1,024
             ReLU-57 [-1, 512, 28, 28] 0
       Bottleneck-58 [-1, 512, 28, 28] 0
18           Conv2d-59 [-1, 128, 28, 28] 65 536
      BatchNorm2d-60 [-1, 128, 28, 28] 256
             ReLU-61 [-1, 128, 28, 28] 0
 19          Conv2d-62 [-1, 128, 28, 28] 147 456
      BatchNorm2d-63 [-1, 128, 28, 28] 256
             ReLU-64 [-1, 128, 28, 28] 0
20           Conv2d-65 [-1, 512, 28, 28] 65 536
      BatchNorm2d-66 [-1, 512, 28, 28] 1024
             ReLU-67 [-1, 512, 28, 28] 0
       Bottleneck-68 [-1, 512, 28, 28] 0
 21          Conv2d-69 [-1, 128, 28, 28] 65 536
      BatchNorm2d-70 [-1, 128, 28, 28] 256
             ReLU-71 [-1, 128, 28, 28] 0
22            Conv2d-72 [-1, 128, 28, 28] 147 456
      BatchNorm2d-73 [-1, 128, 28, 28] 256
             ReLU-74 [-1, 128, 28, 28] 0
23           Conv2d-75 [-1, 512, 28, 28] 65 536
      BatchNorm2d-76 [-1, 512, 28, 28] 1024
             ReLU-77 [-1, 512, 28, 28] 0
       Bottleneck-78 [-1, 512, 28, 28] 0
24           Conv2d-79 [-1, 256, 28, 28] 131 072
      BatchNorm2d-80 [-1, 256, 28, 28] 512
             ReLU-81 [-1, 256, 28, 28] 0
25            Conv2d-82 [-1, 256, 14, 14] 589 824
      BatchNorm2d-83 [-1, 256, 14, 14] 512
             ReLU-84 [-1, 256, 14, 14] 0
26           Conv2d-85 [-1, 1024, 14, 14] 262 144
      BatchNorm2d-86 [-1, 1024, 14, 14] 2,048
 27           Conv2d-87 [-1, 1024, 14, 14] 524,288
      BatchNorm2d-88 [-1, 1024, 14, 14] 2,048
             ReLU-89 [-1, 1024, 14, 14] 0
       Bottleneck-90 [-1, 1024, 14, 14] 0
28           Conv2d-91 [-1, 256, 14, 14] 262,144
      BatchNorm2d-92 [-1, 256, 14, 14] 512
             ReLU-93 [-1, 256, 14, 14] 0
29           Conv2d-94 [-1, 256, 14, 14] 589,824
      BatchNorm2d-95 [-1, 256, 14, 14] 512
             ReLU-96 [-1, 256, 14, 14] 0
30           Conv2d-97 [-1, 1024, 14, 14] 262,144
      BatchNorm2d-98 [-1, 1024, 14, 14] 2,048
             ReLU-99 [-1, 1024, 14, 14] 0
      Bottleneck-100 [-1, 1024, 14, 14] 0
 31          Conv2d-101 [-1, 256, 14, 14] 262 144
     BatchNorm2d-102 [-1, 256, 14, 14] 512
            ReLU-103 [-1, 256, 14, 14] 0
 32         Conv2d-104 [-1, 256, 14, 14] 589 824
     BatchNorm2d-105 [-1, 256, 14, 14] 512
            ReLU-106 [-1, 256, 14, 14] 0
 33         Conv2d-107 [-1, 1024, 14, 14] 262,144
     BatchNorm2d-108 [-1, 1024, 14, 14] 2,048
            ReLU-109 [-1, 1024, 14, 14] 0
      Bottleneck-110 [-1, 1024, 14, 14] 0
 34          Conv2d-111 [-1, 256, 14, 14] 262,144
     BatchNorm2d-112 [-1, 256, 14, 14] 512
            ReLU-113 [-1, 256, 14, 14] 0
 35         Conv2d-114 [-1, 256, 14, 14] 589,824
     BatchNorm2d-115 [-1, 256, 14, 14] 512
            ReLU-116 [-1, 256, 14, 14] 0
 36          Conv2d-117 [-1, 1024, 14, 14] 262,144
     BatchNorm2d-118 [-1, 1024, 14, 14] 2,048
            ReLU-119 [-1, 1024, 14, 14] 0
      Bottleneck-120 [-1, 1024, 14, 14] 0
37           Conv2d-121 [-1, 256, 14, 14] 262 144
     BatchNorm2d-122 [-1, 256, 14, 14] 512
            ReLU-123 [-1, 256, 14, 14] 0
38           Conv2d-124 [-1, 256, 14, 14] 589 824
     BatchNorm2d-125 [-1, 256, 14, 14] 512
            ReLU-126 [-1, 256, 14, 14] 0
 39          Conv2d-127 [-1, 1024, 14, 14] 262,144
     BatchNorm2d-128 [-1, 1024, 14, 14] 2,048
            ReLU-129 [-1, 1024, 14, 14] 0
      Bottleneck-130 [-1, 1024, 14, 14] 0
40          Conv2d-131 [-1, 256, 14, 14] 262 144
     BatchNorm2d-132 [-1, 256, 14, 14] 512
            ReLU-133 [-1, 256, 14, 14] 0
 41          Conv2d-134 [-1, 256, 14, 14] 589,824
     BatchNorm2d-135 [-1, 256, 14, 14] 512
            ReLU-136 [-1, 256, 14, 14] 0
 42          Conv2d-137 [-1, 1024, 14, 14] 262,144
     BatchNorm2d-138 [-1, 1024, 14, 14] 2,048
            ReLU-139 [-1, 1024, 14, 14] 0
      Bottleneck-140 [-1, 1024, 14, 14] 0
 43          Conv2d-141 [-1, 512, 14, 14] 524,288
     BatchNorm2d-142 [-1, 512, 14, 14] 1,024
            ReLU-143 [-1, 512, 14, 14] 0
 44          Conv2d-144 [-1, 512, 7, 7] 2,359,296
     BatchNorm2d-145 [-1, 512, 7, 7] 1,024
            ReLU-146 [-1, 512, 7, 7] 0
 45          Conv2d-147 [-1, 2048, 7, 7] 1,048,576
     BatchNorm2d-148 [-1, 2048, 7, 7] 4,096
 46          Conv2d-149 [-1, 2048, 7, 7] 2,097,152
     BatchNorm2d-150 [-1, 2048, 7, 7] 4,096
            ReLU-151 [-1, 2048, 7, 7] 0
      Bottleneck-152 [-1, 2048, 7, 7] 0
 47         Conv2d-153 [-1, 512, 7, 7] 1,048,576
     BatchNorm2d-154 [-1, 512, 7, 7] 1,024
            ReLU-155 [-1, 512, 7, 7] 0
 48          Conv2d-156 [-1, 512, 7, 7] 2,359,296
     BatchNorm2d-157 [-1, 512, 7, 7] 1,024
            ReLU-158 [-1, 512, 7, 7] 0
 49        Conv2d-159 [-1, 2048, 7, 7] 1 048 576
     BatchNorm2d-160 [-1, 2048, 7, 7] 4096
            ReLU-161 [-1, 2048, 7, 7] 0
      Bottleneck-162 [-1, 2048, 7, 7] 0
50          Conv2d-163 [-1, 512, 7, 7] 1,048,576
     BatchNorm2d-164 [-1, 512, 7, 7] 1,024
            ReLU-165 [-1, 512, 7, 7] 0
 51         Conv2d-166 [-1, 512, 7, 7] 2,359,296
     BatchNorm2d-167 [-1, 512, 7, 7] 1,024
            ReLU-168 [-1, 512, 7, 7] 0
 52         Conv2d-169 [-1, 2048, 7, 7] 1 048 576
     BatchNorm2d-170 [-1, 2048, 7, 7] 4096
            ReLU-171 [-1, 2048, 7, 7] 0
      Bottleneck-172 [-1, 2048, 7, 7] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0
          Linear-174 [-1, 3] 6,147
================================================================
Parámetros totales: 23,514,179
Parámetros entrenables : 23,514,179
Parámetros no entrenables: 0
------------------------------------------ ----------------------
Tamaño de entrada (MB): 0,57
Tamaño de paso adelante/atrás (MB): 286,55
Tamaño de parámetros (MB): 89,70
Tamaño total estimado (MB): 376,82
------------------------------------------ ----------------------
GFLOPS totales: 8.263G
Parámetros totales: 23.514M

Código de visualización:

Referencia de este código: https://blog.csdn.net/qq_34769162/article/details/115567093

# https://blog.csdn.net/qq_34769162/article/details/115567093
import numpy as np

import torch
import torchvision
from PIL import Image
from torchvision import transforms as T

import matplotlib.pyplot as plt
import pylab

import torch
import torchvision

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torchvision's ResNet-50, constructed with pretrained=True.
feature_extractor = torchvision.models.resnet50(pretrained=True)
if device.type == 'cuda':
    feature_extractor.cuda()


class SaveOutput:
    """Callable that records every output it is handed, in call order.

    The call signature ``(module, module_in, module_out)`` is the one this
    script passes to ``register_forward_hook``.
    """

    def __init__(self) -> None:
        # Recorded outputs, oldest first.
        self.outputs: list = []

    def __call__(self, module, module_in, module_out) -> None:
        # Only the output is kept; the module and its input are ignored.
        self.outputs += [module_out]

    def clear(self) -> None:
        """Start a new, empty record (the previous list object is left untouched)."""
        self.outputs = list()


save_output = SaveOutput()

# Handles returned by register_forward_hook, kept so the hooks can be
# removed later with handle.remove().
hook_handles = []

# Record the output of every Conv2d layer, in the order the layers run.
for layer in feature_extractor.modules():
    if isinstance(layer, torch.nn.Conv2d):
        hook_handles.append(layer.register_forward_hook(save_output))


from PIL import Image
from torchvision import transforms as T

# Force three colour channels so grayscale, palette or RGBA files do not
# produce a tensor with a different channel count; the context manager
# closes the file once the pixels are copied.
with Image.open('img/rot.jpg') as image_file:
    image = image_file.convert('RGB')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)

# Inference mode: BatchNorm must use its stored running statistics instead of
# the statistics of this single image (and must not update them), and no
# autograd graph should be kept alive by the tensors the hooks store.
feature_extractor.eval()
with torch.no_grad():
    out = feature_extractor(X)


# Number of convolution outputs captured by the hooks.
print(len(save_output.outputs))
# Indices (into save_output.outputs) of the layers to inspect
# a_list = [0, 1, 6, 15, 28, 35]
a_list = [0, 1, 5, 11, 15, 24,28,43,47]


# 43: 512,14,14
# 47: 512, 7, 7
for i in a_list:
    print(save_output.outputs[i].cpu().detach().squeeze(0).shape)


def grid_gray_image(imgs, each_row: int):
    """Tile a stack of single-channel images into one grid image.

    Every image is min-max scaled to [0, 1] on its own, then the images are
    laid out left-to-right, top-to-bottom with ``each_row`` images per row.
    Images beyond the last complete row are dropped.  A constant image (max
    equals min) is rendered as all zeros instead of NaN.

    Args:
        imgs: array whose first axis indexes the images, e.g. shape
            (64, 32, 32) for 64 gray images of size 32x32.
        each_row: number of images per grid row; must be >= 1.

    Returns:
        A numpy array; for (N, H, W) input its shape is
        (N // each_row * H, each_row * W).

    Raises:
        ValueError: if ``each_row`` < 1 or there are fewer than ``each_row``
            images (no complete row can be formed).
    """
    imgs = np.asarray(imgs)
    if each_row < 1:
        raise ValueError("each_row must be >= 1, got {}".format(each_row))
    row_num = imgs.shape[0] // each_row
    if row_num == 0:
        raise ValueError(
            "need at least each_row={} images to fill one row, got {}".format(
                each_row, imgs.shape[0]))

    # Only the images that fill complete rows are used.
    tiles = imgs[:row_num * each_row]

    # Per-image min and range, shaped to broadcast against each image.
    per_image = (-1,) + (1,) * (tiles.ndim - 1)
    flat = tiles.reshape(tiles.shape[0], -1)
    lowest = flat.min(axis=1).reshape(per_image)
    span = flat.max(axis=1).reshape(per_image) - lowest
    # Avoid 0/0 on constant images: (img - min) is already 0 there.
    span[span == 0] = 1
    scaled = (tiles - lowest) / span

    rows = [
        np.hstack(list(scaled[r * each_row:(r + 1) * each_row]))
        for r in range(row_num)
    ]
    return np.vstack(rows)

# One entry per figure:
# (index into save_output.outputs, feature maps per grid row, figure size).
# The trailing comments are the author's notes on each layer's output shape.
panels = [
    (0, 8, (15, 15)),    # 64,112,112
    (1, 8, (15, 15)),    # 64,56,56
    (5, 8, (15, 15)),    # 64,56,56
    (11, 16, (30, 15)),  # 128,56,56
    (15, 16, (30, 15)),  # 128,28,28
    (24, 16, (30, 30)),  # 256,28,28
    (28, 16, (30, 30)),  # 256,14,14
    (43, 16, (45, 45)),  # 512,14,14
    (47, 16, (45, 45)),  # 512,7,7
]

# Build every grid first, then draw one figure per grid.
grids = []
for conv_index, per_row, _ in panels:
    feature_maps = save_output.outputs[conv_index].cpu().detach().squeeze(0)
    grids.append(grid_gray_image(feature_maps.numpy(), per_row))

for grid, (_, _, size) in zip(grids, panels):
    plt.figure(figsize=size)
    plt.imshow(grid, cmap='gray')

pylab.show()

Imagen original:

En ResNet-50, la primera capa convolucional usa un kernel de 7×7: aumenta los canales de los 3 de la imagen en color de entrada a 64 y reduce el tamaño a la mitad, de 224×224 a 112×112.

Es decir, la imagen de entrada pasa de (3, 224, 224) a (64, 112, 112).

 Visualizamos los resultados de extracción de la primera capa convolucional:

La última capa visualizada, de tamaño (512, 7, 7), es demasiado abstracta para distinguir algo con claridad.

Supongo que te gusta

Origin blog.csdn.net/m0_63172128/article/details/129709196
Recomendado
Clasificación