在 PyTorch 中查看模型(model)的参数(parameters)

在 PyTorch 中查看模型(model)的参数(parameters)

示例1:torchvision 自带的 Faster R-CNN 模型

import torch
import torchvision
from torch import nn

# Build torchvision's Faster R-CNN (ResNet-50 + FPN) with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- confirm against the installed version before switching.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Option 1: named_parameters() yields (name, parameter) pairs.
for name, p in model.named_parameters():
    print(name)
    print(p.requires_grad)
    # The original had a literal `print(...)` placeholder here, which prints
    # "Ellipsis"; show the parameter's shape as a concrete example instead.
    print(p.shape)

# Option 2: parameters() yields only the parameter tensors, without names.
for p in model.parameters():
    print(p)
    print(p.shape)

示例2:自定义网络模型

class Net(nn.Module):
    """VGG-style convolutional feature extractor.

    The network is a single ``nn.Sequential`` stored in ``self.features``.
    It is built from a layer configuration in which an integer ``c`` adds a
    ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` group with ``c`` output
    channels, and the string ``'M'`` adds a 2x2 max-pool with stride 2.

    Args:
        cfg: Optional iterable of ints and ``'M'`` markers describing the
            layers. When omitted, ``DEFAULT_CFG`` is used, which reproduces
            the original hard-coded layout (13 conv groups, 4 max-pools).
    """

    # Default layout: ints are conv output channels, 'M' is a max-pool.
    DEFAULT_CFG = (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                   512, 512, 512, 'M', 512, 512, 512)

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            cfg = self.DEFAULT_CFG
        self.features = self._vgg_layers(cfg)

    def _vgg_layers(self, cfg):
        """Translate ``cfg`` into an ``nn.Sequential`` of layers.

        The first convolution expects 3 input channels; each subsequent
        convolution takes the previous convolution's output channel count.
        """
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.extend([
                    nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True),
                ])
                # The next conv consumes what this one produced.
                in_channels = x

        return nn.Sequential(*layers)

    def forward(self, data):
        """Run ``data`` through ``self.features`` and return the feature map."""
        return self.features(data)
    
# Bind the instance to `model` (the original bound it to `Model`, so the
# loops below iterated a different/undefined `model`).
model = Net()

# Option 1: named_parameters() yields (name, parameter) pairs.
for name, p in model.named_parameters():
    print(name)
    print(p.requires_grad)
    # The original had a literal `print(...)` placeholder here, which prints
    # "Ellipsis"; show the parameter's shape as a concrete example instead.
    print(p.shape)

# Option 2: parameters() yields only the parameter tensors, without names.
for p in model.parameters():
    print(p)
    print(p.shape)

在自定义网络中,model.parameters() 和 model.named_parameters() 方法都继承自 nn.Module,因此无需自己实现即可直接调用。

猜你喜欢

转载自blog.csdn.net/qq_38600065/article/details/105552816