Cuando entrenamos un modelo con imágenes, a menudo queremos ver qué mapas de características extrae cada capa de la red. Después de mucho buscar, el esfuerzo valió la pena: encontré un código que lo hace y lo modifiqué:
Código de ResNet:
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
import math
# Download locations of the pretrained checkpoints, keyed by architecture name.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
class GhostModule(nn.Module):
    """Produce ``oup`` channels from a primary conv plus a cheap grouped conv.

    The primary convolution outputs ``ceil(oup / ratio)`` channels. A second,
    grouped convolution (one group per primary channel) derives
    ``ratio - 1`` extra maps from each primary channel. Both results are
    concatenated on the channel axis and truncated to ``oup`` channels.

    Args:
        inp: Number of input channels.
        oup: Number of output channels returned by ``forward``.
        kernel_size: Kernel size of the primary convolution.
        ratio: ``oup`` divided by the number of primary channels (rounded up).
        dw_size: Kernel size of the grouped ("cheap") convolution.
        stride: Stride of the primary convolution.
        relu: If true, apply ReLU after each BatchNorm; otherwise an empty
            ``nn.Sequential`` is used as a pass-through.
    """

    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super(GhostModule, self).__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)
        new_channels = init_channels * (ratio - 1)

        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size // 2,
                      groups=init_channels, bias=False),
            nn.BatchNorm2d(new_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

    def forward(self, x):
        """Return the concatenated primary and cheap features, cut to ``oup`` channels."""
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        # init_channels * ratio can exceed oup when oup is not divisible by
        # ratio, so drop the surplus channels.
        return out[:, :self.oup, :, :]
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before it is added
            to the convolutional branch.
        groups: Must be 1 for this block.
        base_width: Must be 64 for this block.
        dilation: Must be 1 for this block.
        norm_layer: Normalization layer constructor; defaults to
            ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut branch, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when a downsample module was supplied.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    """Residual block of 1x1 -> 3x3 -> 1x1 convolutions.

    The inner width is ``int(planes * (base_width / 64.)) * groups`` and the
    block outputs ``planes * expansion`` channels.

    Args:
        inplanes: Number of input channels.
        planes: Base channel count of the block.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input before it is added
            to the convolutional branch.
        groups: Number of groups of the 3x3 convolution.
        base_width: Width per group, relative to 64.
        dilation: Dilation of the 3x3 convolution.
        norm_layer: Normalization layer constructor; defaults to
            ``nn.BatchNorm2d``.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv-bn stages, add the shortcut branch, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when a downsample module was supplied.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """ResNet: a stem, four residual stages, global pooling and a linear head.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Four integers, the number of blocks in each stage.
        num_classes: Output size of the final linear layer.
        zero_init_residual: If true, zero the weight of the last norm layer of
            every residual block.
        groups: Forwarded to every block.
        width_per_group: Forwarded to every block as ``base_width``.
        replace_stride_with_dilation: ``None`` or three booleans; for stages
            2-4, a true value replaces the stride-2 downsampling with dilation.
        norm_layer: Normalization layer constructor; defaults to
            ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have exactly
            three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.block = block
        self.groups = groups
        self.base_width = width_per_group

        # The shape comments below are for a 224x224 input and a block with
        # expansion 4 (Bottleneck).
        # 224,224,3 -> 112,112,64
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # 112,112,64 -> 56,56,64
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 56,56,64 -> 56,56,256
        self.layer1 = self._make_layer(block, 64, layers[0])
        # 56,56,256 -> 28,28,512
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        # 28,28,512 -> 14,14,1024
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        # 14,14,1024 -> 7,7,2048
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # 7,7,2048 -> 2048
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # 2048 -> num_classes
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero the last norm layer of each residual branch so every block
        # starts out as an identity mapping. BasicBlock ends in bn2 and
        # Bottleneck in bn3; both are handled so the flag also takes effect
        # for ResNet-18/34.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block may change stride and channel count and receives a
        1x1-conv + norm shortcut when needed; the remaining blocks keep the
        shape. With ``dilate`` set, the stride is folded into
        ``self.dilation`` instead of being applied.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Conv_block
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            # identity_block
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run stem, four stages, global average pooling and the linear head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def freeze_backbone(self):
        """Set ``requires_grad = False`` on the stem and all four stages (not ``fc``)."""
        backbone = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
        for module in backbone:
            for param in module.parameters():
                param.requires_grad = False

    def Unfreeze_backbone(self):
        """Set ``requires_grad = True`` on the stem and all four stages."""
        backbone = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
        for module in backbone:
            for param in module.parameters():
                param.requires_grad = True
def resnet18(pretrained=False, progress=True, num_classes=1000):
    """Build a ResNet-18 (``BasicBlock``, stages of 2/2/2/2 blocks).

    Args:
        pretrained: If true, load the checkpoint at ``model_urls['resnet18']``
            (cached under ``./model_data``).
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: If not 1000, the ``fc`` head is replaced by a new linear
            layer of this size after any checkpoint has been loaded.
    """
    # Built with the default 1000-class head so the checkpoint keys and shapes match.
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['resnet18'], model_dir='./model_data',
                                              progress=progress)
        model.load_state_dict(state_dict)

    if num_classes != 1000:
        model.fc = nn.Linear(512 * model.block.expansion, num_classes)
    return model
def resnet34(pretrained=False, progress=True, num_classes=1000):
    """Build a ResNet-34 (``BasicBlock``, stages of 3/4/6/3 blocks).

    Args:
        pretrained: If true, load the checkpoint at ``model_urls['resnet34']``
            (cached under ``./model_data``).
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: If not 1000, the ``fc`` head is replaced by a new linear
            layer of this size after any checkpoint has been loaded.
    """
    # Built with the default 1000-class head so the checkpoint keys and shapes match.
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['resnet34'], model_dir='./model_data',
                                              progress=progress)
        model.load_state_dict(state_dict)

    if num_classes != 1000:
        model.fc = nn.Linear(512 * model.block.expansion, num_classes)
    return model
def resnet50(pretrained=False, progress=True, num_classes=1000):
    """Build a ResNet-50 (``Bottleneck``, stages of 3/4/6/3 blocks).

    Args:
        pretrained: If true, load the checkpoint at ``model_urls['resnet50']``
            (cached under ``./model_data``).
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: If not 1000, the ``fc`` head is replaced by a new linear
            layer of this size after any checkpoint has been loaded.
    """
    # Built with the default 1000-class head so the checkpoint keys and shapes match.
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['resnet50'], model_dir='./model_data',
                                              progress=progress)
        model.load_state_dict(state_dict)

    if num_classes != 1000:
        model.fc = nn.Linear(512 * model.block.expansion, num_classes)
    return model
def resnet101(pretrained=False, progress=True, num_classes=1000):
    """Build a ResNet-101 (``Bottleneck``, stages of 3/4/23/3 blocks).

    Args:
        pretrained: If true, load the checkpoint at ``model_urls['resnet101']``
            (cached under ``./model_data``).
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: If not 1000, the ``fc`` head is replaced by a new linear
            layer of this size after any checkpoint has been loaded.
    """
    # Built with the default 1000-class head so the checkpoint keys and shapes match.
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['resnet101'], model_dir='./model_data',
                                              progress=progress)
        model.load_state_dict(state_dict)

    if num_classes != 1000:
        model.fc = nn.Linear(512 * model.block.expansion, num_classes)
    return model
def resnet152(pretrained=False, progress=True, num_classes=1000):
    """Build a ResNet-152 (``Bottleneck``, stages of 3/8/36/3 blocks).

    Args:
        pretrained: If true, load the checkpoint at ``model_urls['resnet152']``
            (cached under ``./model_data``).
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: If not 1000, the ``fc`` head is replaced by a new linear
            layer of this size after any checkpoint has been loaded.
    """
    # Built with the default 1000-class head so the checkpoint keys and shapes match.
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['resnet152'], model_dir='./model_data',
                                              progress=progress)
        model.load_state_dict(state_dict)

    if num_classes != 1000:
        model.fc = nn.Linear(512 * model.block.expansion, num_classes)
    return model
Puede ver la estructura de la red:
#--------------------------------------------#
# This script only prints the network structure; it is not test code.
#--------------------------------------------#
import torch
from thop import clever_format, profile
from torchsummary import summary
from nets import get_model_from_name
# from nets import resnet_cbam  # import whichever model module you actually use
if __name__ == "__main__":
    # Print a per-layer summary of the chosen backbone, then its FLOPs and
    # parameter count.
    input_shape = [224, 224]
    # Set this to your own number of classes. For image segmentation add one
    # extra class for the background, e.g. cats and dogs -> num_classes = 2 + 1.
    num_classes = 3

    # backbone = "mobilenetv2"
    backbone = "resnet50"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model_from_name[backbone](num_classes=num_classes, pretrained=False).to(device)

    summary(model, (3, input_shape[0], input_shape[1]))

    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
    flops, params = profile(model.to(device), (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops * 2 because profile does not count a convolution
    #   as two operations.
    #   Some papers count the multiply and the add as two
    #   operations; in that case multiply by 2.
    #   Other papers count only multiplications and ignore
    #   additions; in that case do not multiply by 2.
    #   This code multiplies by 2, following YOLOX.
    #--------------------------------------------------------#
    flops = flops * 2
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))
La estructura es la siguiente: (resnet50)
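ResNet50 usa dos tipos de módulos residuales: el Conv Block, que cambia la cantidad de canales, y el Identity Block, que sirve para profundizar la red.
El primer bloque residual de cada etapa (Conv Block) tiene 4 convoluciones (las 3 de la rama principal más la del atajo); los siguientes (Identity Block) tienen 3.
(4+3+3) + (4+3+3+3) + (4+3+3+3+3+3) + (4+3+3) = 52
Sumando la convolución inicial de 7×7: 52 + 1 = 53 convoluciones en total (numeradas de 0 a 52 en la tabla).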
----------------------------------------------------------------
N.º Capa (tipo) Forma de salida N.º de parámetros
================================================================
0 Conv2d-1 [-1, 64, 112, 112] 9408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
1 Conv2d-5 [-1, 64, 56, 56] 4096
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
2 Conv2d-8 [-1, 64, 56, 56] 36 864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
3 Conv2d-11 [-1, 256, 56, 56] 16 384
BatchNorm2d-12 [-1, 256, 56, 56] 512
4 Conv2d-13 [-1, 256, 56, 56] 16 384
BatchNorm2d-14 [-1, 256, 56, 56] 512
ReLU-15 [-1, 256, 56, 56] 0
Cuello de botella-16 [-1, 256, 56, 56] 0
5 Conv2d-17 [-1, 64, 56, 56] 16 384
BatchNorm2d-18 [-1, 64, 56, 56] 128
ReLU-19 [-1, 64, 56, 56] 0
6 Conv2d-20 [-1, 64, 56, 56] 36 864
BatchNorm2d-21 [-1, 64, 56, 56] 128
ReLU-22 [-1, 64, 56, 56] 0
7 Conv2d-23 [-1, 256, 56, 56] 16 384
BatchNorm2d-24 [-1, 256, 56, 56] 512
ReLU-25 [-1, 256, 56, 56] 0
Cuello de botella-26 [-1, 256, 56, 56] 0
8 Conv2d-27 [-1, 64, 56, 56] 16 384
BatchNorm2d-28 [-1, 64, 56, 56] 128
ReLU-29 [-1, 64, 56, 56] 0
9 Conv2d-30 [-1, 64, 56, 56] 36 864
BatchNorm2d-31 [-1, 64, 56, 56] 128
ReLU-32 [-1, 64, 56, 56] 0
10 Conv2d-33 [-1, 256, 56, 56] 16 384
BatchNorm2d-34 [-1, 256, 56, 56] 512
ReLU-35 [-1, 256, 56, 56] 0
Cuello de botella-36 [-1, 256, 56, 56] 0
11 Conv2d-37 [-1, 128, 56, 56] 32 768
BatchNorm2d-38 [-1, 128, 56, 56] 256
ReLU-39 [-1, 128, 56, 56] 0
12 Conv2d-40 [-1, 128, 28, 28] 147 456
BatchNorm2d-41 [-1, 128, 28, 28] 256
ReLU-42 [-1, 128, 28, 28] 0
13 Conv2d-43 [-1, 512, 28, 28] 65 536
BatchNorm2d-44 [-1, 512, 28, 28] 1024
14 Conv2d-45 [-1, 512, 28, 28] 131 072
BatchNorm2d-46 [-1, 512, 28, 28] 1024
ReLU-47 [-1, 512, 28, 28] 0
Cuello de botella-48 [-1, 512, 28, 28] 0
15 Conv2d-49 [-1, 128, 28, 28] 65 536
BatchNorm2d-50 [-1, 128, 28, 28] 256
ReLU-51 [-1, 128, 28, 28] 0
16 Conv2d-52 [-1, 128, 28, 28] 147 456
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
17 Conv2d-55 [-1, 512, 28, 28] 65 536
BatchNorm2d-56 [-1, 512, 28, 28] 1,024
ReLU-57 [-1, 512, 28, 28] 0
Cuello de botella-58 [-1, 512, 28, 28] 0
18 Conv2d-59 [-1, 128, 28, 28] 65 536
BatchNorm2d-60 [-1, 128, 28, 28] 256
ReLU-61 [-1, 128, 28, 28] 0
19 Conv2d-62 [-1, 128, 28, 28] 147 456
BatchNorm2d-63 [-1, 128, 28, 28] 256
ReLU-64 [-1, 128, 28, 28] 0
20 Conv2d-65 [-1, 512, 28, 28] 65 536
BatchNorm2d-66 [-1, 512, 28, 28] 1024
ReLU-67 [-1, 512, 28, 28] 0
Cuello de botella-68 [-1, 512, 28, 28] 0
21 Conv2d-69 [-1, 128, 28, 28] 65 536
BatchNorm2d-70 [-1, 128, 28, 28] 256
ReLU-71 [-1, 128, 28, 28] 0
22 Conv2d-72 [-1, 128, 28, 28] 147 456
BatchNorm2d-73 [-1, 128, 28, 28] 256
ReLU-74 [-1, 128, 28, 28] 0
23 Conv2d-75 [-1, 512, 28, 28] 65 536
BatchNorm2d-76 [-1, 512, 28, 28] 1024
ReLU-77 [-1, 512, 28, 28] 0
Cuello de botella-78 [-1, 512, 28, 28] 0
24 Conv2d-79 [-1, 256, 28, 28] 131 072
BatchNorm2d-80 [-1, 256, 28, 28] 512
ReLU-81 [-1, 256, 28, 28] 0
25 Conv2d-82 [-1, 256, 14, 14] 589 824
BatchNorm2d-83 [-1, 256, 14, 14] 512
ReLU-84 [-1, 256, 14, 14] 0
26 Conv2d-85 [-1, 1024, 14, 14] 262 144
BatchNorm2d-86 [-1, 1024, 14, 14] 2,048
27 Conv2d-87 [-1, 1024, 14, 14] 524,288
BatchNorm2d-88 [-1, 1024, 14, 14] 2,048
ReLU-89 [-1, 1024, 14, 14] 0
Cuello de botella-90 [-1, 1024, 14, 14] 0
28 Conv2d-91 [-1, 256, 14, 14] 262 144
BatchNorm2d-92 [-1, 256, 14, 14] 512
ReLU-93 [-1, 256, 14, 14] 0
29 Conv2d-94 [-1, 256, 14, 14] 589 824
BatchNorm2d-95 [-1, 256, 14, 14] 512
ReLU-96 [-1, 256, 14, 14] 0
30 Conv2d-97 [-1, 1024, 14, 14] 262,144
BatchNorm2d-98 [-1, 1024, 14, 14] 2,048
ReLU-99 [-1, 1024, 14, 14] 0
Cuello de botella-100 [-1, 1024, 14, 14] 0
31 Conv2d-101 [-1, 256, 14, 14] 262 144
BatchNorm2d-102 [-1, 256, 14, 14] 512
ReLU-103 [-1, 256, 14, 14] 0
32 Conv2d-104 [-1, 256, 14, 14] 589 824
BatchNorm2d-105 [-1, 256, 14, 14] 512
ReLU-106 [-1, 256, 14, 14] 0
33 Conv2d-107 [-1, 1024, 14, 14] 262,144
BatchNorm2d-108 [-1, 1024, 14, 14] 2,048
ReLU-109 [-1, 1024, 14, 14] 0
Cuello de botella-110 [-1, 1024, 14, 14] 0
34 Conv2d-111 [-1, 256, 14, 14] 262,144
BatchNorm2d-112 [-1, 256, 14, 14] 512
ReLU-113 [-1, 256, 14, 14] 0
35 Conv2d-114 [-1, 256, 14, 14] 589,824
BatchNorm2d-115 [-1, 256, 14, 14] 512
ReLU-116 [-1, 256, 14, 14] 0
36 Conv2d-117 [-1, 1024, 14, 14] 262,144
BatchNorm2d-118 [-1, 1024, 14, 14] 2,048
ReLU-119 [-1, 1024, 14, 14] 0
Cuello de botella-120 [-1, 1024, 14, 14] 0
37 Conv2d-121 [-1, 256, 14, 14] 262 144
BatchNorm2d-122 [-1, 256, 14, 14] 512
ReLU-123 [-1, 256, 14, 14] 0
38 Conv2d-124 [-1, 256, 14, 14] 589 824
BatchNorm2d-125 [-1, 256, 14, 14] 512
ReLU-126 [-1, 256, 14, 14] 0
39 Conv2d-127 [-1, 1024, 14, 14] 262,144
BatchNorm2d-128 [-1, 1024, 14, 14] 2,048
ReLU-129 [-1, 1024, 14, 14] 0
Cuello de botella-130 [-1, 1024, 14, 14] 0
40 Conv2d-131 [-1, 256, 14, 14] 262 144
BatchNorm2d-132 [-1, 256, 14, 14] 512
ReLU-133 [-1, 256, 14, 14] 0
41 Conv2d-134 [-1, 256, 14, 14] 589,824
BatchNorm2d-135 [-1, 256, 14, 14] 512
ReLU-136 [-1, 256, 14, 14] 0
42 Conv2d-137 [-1, 1024, 14, 14] 262,144
BatchNorm2d-138 [-1, 1024, 14, 14] 2,048
ReLU-139 [-1, 1024, 14, 14] 0
Cuello de botella-140 [-1, 1024, 14, 14] 0
43 Conv2d-141 [-1, 512, 14, 14] 524,288
BatchNorm2d-142 [-1, 512, 14, 14] 1,024
ReLU-143 [-1, 512, 14, 14] 0
44 Conv2d-144 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-145 [-1, 512, 7, 7] 1,024
ReLU-146 [-1, 512, 7, 7] 0
45 Conv2d-147 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-148 [-1, 2048, 7, 7] 4,096
46 Conv2d-149 [-1, 2048, 7, 7] 2,097,152
BatchNorm2d-150 [-1, 2048, 7, 7] 4,096
ReLU-151 [-1, 2048, 7, 7] 0
Cuello de botella-152 [-1, 2048 , 7, 7] 0
47 Conv2d-153 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-154 [-1, 512, 7, 7] 1,024
ReLU-155 [-1, 512, 7, 7] 0
48 Conv2d-156 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-157 [-1, 512, 7, 7] 1,024
ReLU-158 [-1, 512, 7, 7] 0
49 Conv2d-159 [-1, 2048, 7, 7] 1 048 576
BatchNorm2d-160 [-1, 2048, 7, 7] 4096
ReLU-161 [-1, 2048, 7, 7] 0
Cuello de botella-162 [-1, 2048 , 7, 7] 0
50 Conv2d-163 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-164 [-1, 512, 7, 7] 1,024
ReLU-165 [-1, 512, 7, 7] 0
51 Conv2d-166 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-167 [-1, 512, 7, 7] 1,024
ReLU-168 [-1, 512, 7, 7] 0
52 Conv2d-169 [-1, 2048, 7, 7] 1 048 576
BatchNorm2d-170 [-1, 2048, 7, 7] 4096
ReLU-171 [-1, 2048, 7, 7] 0
Cuello de botella-172 [-1, 2048 , 7, 7] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0
Lineal-174 [-1, 3] 6,147
==================== ===========================================
Parámetros totales: 23,514,179
Parámetros entrenables : 23,514,179
Parámetros no entrenables: 0
------------------------------------------ ----------------------
Tamaño de entrada (MB): 0,57
Tamaño de paso adelante/atrás (MB): 286,55
Tamaño de parámetros (MB): 89,70
Tamaño total estimado (MB): 376,82
------------------------------------------ ----------------------
GFLOPS totales: 8.263G
Parámetros totales: 23.514M
Código de visualización:
Referencia de este código: https://blog.csdn.net/qq_34769162/article/details/115567093
# https://blog.csdn.net/qq_34769162/article/details/115567093
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms as T
import matplotlib.pyplot as plt
import pylab
import torch
import torchvision
# Choose the device once; the model and the input tensor both use it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Pretrained ResNet-50 whose convolution outputs will be visualized.
feature_extractor = torchvision.models.resnet50(pretrained=True)
feature_extractor.to(device)
# Switch to inference mode so BatchNorm uses its stored running statistics
# instead of the statistics of the single input image.
feature_extractor.eval()
class SaveOutput:
    """Forward-hook callable that records every module output it is called with.

    Instances have the ``(module, module_in, module_out)`` call signature, so
    they can be passed to ``register_forward_hook``. Outputs are appended to
    ``self.outputs`` in call order.
    """

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        """Append ``module_out`` to ``self.outputs``; the other arguments are ignored."""
        self.outputs.append(module_out)

    def clear(self):
        """Discard all recorded outputs."""
        self.outputs = []
# Attach one shared recorder to every Conv2d layer of the model. The handles
# are kept so the hooks can be removed later with handle.remove().
save_output = SaveOutput()
hook_handles = []
for layer in feature_extractor.modules():
    if isinstance(layer, torch.nn.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)
from PIL import Image
from torchvision import transforms as T
# Force three channels so grayscale or RGBA files still match the model's
# 3-channel first convolution.
image = Image.open('img/rot.jpg').convert('RGB')
# NOTE(review): no mean/std normalization is applied here; confirm whether the
# pretrained weights expect normalized input before comparing activations.
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)

# Only the hooked activations are needed, so skip autograd bookkeeping.
with torch.no_grad():
    out = feature_extractor(X)
print(len(save_output.outputs))

# Indices (into save_output.outputs) of the conv outputs to look at
# a_list = [0, 1, 6, 15, 28, 35]
a_list = [0, 1, 5, 11, 15, 24, 28, 43, 47]
# 43: 512,14,14
# 47: 512, 7, 7
for i in a_list:
    print(save_output.outputs[i].cpu().detach().squeeze(0).shape)
def _min_max_normalize(img):
    """Scale ``img`` linearly to [0, 1]; a constant image becomes all zeros."""
    lowest = img.min()
    span = img.max() - lowest
    # A flat map has span == 0; divide by 1 instead so the result is zeros
    # rather than NaN from 0 / 0.
    return (img - lowest) / (span if span else 1)


def grid_gray_image(imgs, each_row: int):
    """Tile gray images into one 2-D grid, normalizing each image to [0, 1].

    Args:
        imgs: Array of shape ``(count, height, width)``, e.g. 64x32x32 for 64
            gray images of size 32x32.
        each_row: Number of images placed side by side in each grid row.

    Returns:
        A 2-D array with ``count // each_row`` rows of ``each_row`` images.
        Images beyond the last complete row are dropped.

    Raises:
        ValueError: If ``each_row`` is smaller than 1 or ``imgs`` holds fewer
            than ``each_row`` images, so not even one row can be built.
    """
    if each_row < 1:
        raise ValueError("each_row must be at least 1, got {}".format(each_row))
    row_num = imgs.shape[0] // each_row
    if row_num == 0:
        raise ValueError(
            "need at least each_row={} images, got {}".format(each_row, imgs.shape[0])
        )

    grid_rows = []
    for i in range(row_num):
        tiles = [_min_max_normalize(imgs[i * each_row + j]) for j in range(each_row)]
        grid_rows.append(np.hstack(tiles))
    return np.vstack(grid_rows)
# One entry per visualized conv output:
#   (index into save_output.outputs, feature maps per grid row, figure size)
# The trailing comment is the (channels, height, width) of that output.
feature_map_plots = [
    (0, 8, (15, 15)),    # 64,112,112
    (1, 8, (15, 15)),    # 64,56,56
    (5, 8, (15, 15)),    # 64,56,56
    (11, 16, (30, 15)),  # 128,56,56
    (15, 16, (30, 15)),  # 128,28,28
    (24, 16, (30, 30)),  # 256,28,28
    (28, 16, (30, 30)),  # 256,14,14
    (43, 16, (45, 45)),  # 512,14,14
    (47, 16, (45, 45)),  # 512,7,7
]

# Draw each selected layer as a grid of gray feature maps in its own figure.
for layer_idx, maps_per_row, fig_size in feature_map_plots:
    feature_maps = save_output.outputs[layer_idx].cpu().detach().squeeze(0)
    grid = grid_gray_image(feature_maps.numpy(), maps_per_row)
    plt.figure(figsize=fig_size)
    plt.imshow(grid, cmap='gray')

pylab.show()
Imagen original:
En resnet50, la primera capa convolucional usa un kernel de 7×7: aumenta los tres canales de la imagen en color de entrada a 64 y reduce el tamaño a la mitad, de 224×224 a 112×112.
Es decir, la imagen de entrada pasa de (3, 224, 224) a (64, 112, 112).
Visualizamos los resultados de extracción de la primera capa convolucional:
La última capa visualizada, de tamaño (512, 7, 7), es demasiado abstracta para distinguir algo con claridad.