I had previously been unclear about how to inspect a model's parameters and structure, so let's study it for a moment.
First, here is the ResNet-20 definition:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from models.res_utils import DownsampleA


class ResNetBasicblock(nn.Module):
    """
    ResNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua)
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()
        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)
        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)

        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(x)

        return F.relu(residual + basicblock, inplace=True)


class CifarResNet(nn.Module):
    """
    ResNet optimized for the CIFAR dataset, as specified in
    https://arxiv.org/abs/1512.03385.pdf
    """
    def __init__(self, block, depth, num_classes):
        """Constructor
        Args:
          depth: number of layers.
          num_classes: number of classes
        """
        super(CifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 models
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6
        print('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))

        self.num_classes = num_classes

        self.conv_1_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialization for conv layers
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


def resnet20(num_classes=10):
    """Constructs a ResNet-20 model for CIFAR-10 (by default)
    Args:
      num_classes (uint): number of classes
    """
    model = CifarResNet(ResNetBasicblock, 20, num_classes)
    return model
```
The DownsampleA used in the shortcut looks like this:
```python
class DownsampleA(nn.Module):
    """Option-A shortcut: downsample spatially with the given stride, then
    zero-pad the channel dimension by concatenating x with x * 0."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        return torch.cat((x, x.mul(0)), 1)
```
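Before poking at the parameters, a quick sanity check that the model runs (a minimal sketch; it assumes the DownsampleA class above is saved as models/res_utils.py, or pasted in place of the import):

```python
import torch

# Build ResNet-20 and push a dummy CIFAR-10-sized batch through it.
net = resnet20(num_classes=10)
x = torch.randn(2, 3, 32, 32)  # batch of 2 RGB images, 32x32
out = net(x)
print(out.shape)  # torch.Size([2, 10])
```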
The overall structure: a stem conv + BN preprocessing layer, then three stages with three basic blocks each, and finally average pooling followed by a fully connected classifier.
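The quickest way to see that structure is to print the model itself: print(net) dumps the module hierarchy, and named_modules() lets you walk it programmatically (both are standard nn.Module behavior):

```python
# Dumps the hierarchy: conv_1_3x3, bn_1, stage_1..stage_3 (3 blocks each),
# avgpool, classifier, with each layer's configuration.
print(net)

# The same hierarchy as (name, module) pairs.
for name, module in net.named_modules():
    print(name, type(module).__name__)
```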
1. model.named_parameters(): iterating over model.named_parameters() yields a (name, param) pair at each step, so you see both the parameter's name and its tensor.
```python
for name, param in net.named_parameters():
    print(name, param.requires_grad)
    param.requires_grad = False
```

which prints:

```
conv_1_3x3.weight False
bn_1.weight False
bn_1.bias False
stage_1.0.conv_a.weight False
stage_1.0.bn_a.weight False
stage_1.0.bn_a.bias False
stage_1.0.conv_b.weight False
stage_1.0.bn_b.weight False
stage_1.0.bn_b.bias False
stage_1.1.conv_a.weight False
stage_1.1.bn_a.weight False
stage_1.1.bn_a.bias False
stage_1.1.conv_b.weight False
stage_1.1.bn_b.weight False
stage_1.1.bn_b.bias False
stage_1.2.conv_a.weight False
stage_1.2.bn_a.weight False
stage_1.2.bn_a.bias False
stage_1.2.conv_b.weight False
stage_1.2.bn_b.weight False
stage_1.2.bn_b.bias False
stage_2.0.conv_a.weight False
stage_2.0.bn_a.weight False
stage_2.0.bn_a.bias False
stage_2.0.conv_b.weight False
stage_2.0.bn_b.weight False
stage_2.0.bn_b.bias False
stage_2.1.conv_a.weight False
stage_2.1.bn_a.weight False
stage_2.1.bn_a.bias False
stage_2.1.conv_b.weight False
stage_2.1.bn_b.weight False
stage_2.1.bn_b.bias False
stage_2.2.conv_a.weight False
stage_2.2.bn_a.weight False
stage_2.2.bn_a.bias False
stage_2.2.conv_b.weight False
stage_2.2.bn_b.weight False
stage_2.2.bn_b.bias False
stage_3.0.conv_a.weight False
stage_3.0.bn_a.weight False
stage_3.0.bn_a.bias False
stage_3.0.conv_b.weight False
stage_3.0.bn_b.weight False
stage_3.0.bn_b.bias False
stage_3.1.conv_a.weight False
stage_3.1.bn_a.weight False
stage_3.1.bn_a.bias False
stage_3.1.conv_b.weight False
stage_3.1.bn_b.weight False
stage_3.1.bn_b.bias False
stage_3.2.conv_a.weight False
stage_3.2.bn_a.weight False
stage_3.2.bn_a.bias False
stage_3.2.conv_b.weight False
stage_3.2.bn_b.weight False
stage_3.2.bn_b.bias False
classifier.weight False
classifier.bias False
```
This loop also shows that you can toggle whether a parameter is trainable. On the first pass the printout shows True; the output above is from a second pass, after requires_grad had already been set to False.
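This is exactly how partial fine-tuning is usually done: freeze layers by matching on the name, then hand the optimizer only what still requires gradients. A sketch (freezing the stem and stage_1 is just an illustrative choice, as are the SGD hyperparameters):

```python
import torch.optim as optim

# Freeze the stem and the first stage by name prefix, unfreeze the rest.
for name, param in net.named_parameters():
    param.requires_grad = not name.startswith(('conv_1_3x3', 'bn_1', 'stage_1'))

# Only pass the still-trainable parameters to the optimizer.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.1, momentum=0.9)
```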
2. model.parameters(): iterating over model.parameters() yields only the param element at each step, without the name; that is the difference between it and named_parameters(). Both can be used to change the requires_grad attribute of the parameters.
```python
for index, param in enumerate(net.parameters()):
    print(param.shape)
```

which prints:

```
torch.Size([16, 3, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16, 16, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16, 16, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16, 16, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16, 16, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16, 16, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16, 16, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([32, 16, 3, 3])
torch.Size([32])
torch.Size([32])
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([32])
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([32])
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([32])
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([32])
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([32])
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([10, 64])
torch.Size([10])
```
From this output you can read off the shape of every parameter tensor.
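Since parameters() already yields every tensor, one handy use is counting weights; numel() returns the number of elements in a tensor:

```python
total = sum(p.numel() for p in net.parameters())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('total: {}, trainable: {}'.format(total, trainable))
# For this ResNet-20 the total is roughly 0.27M parameters.
```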