[PyTorch] Printing Model Structure and Parameter Info with torchsummary

Contents

1. Installation

2. Import

3. Parameters

4. Usage Example


1. Installation

Activate your virtual environment in the Anaconda Prompt, then run the following command to install:

pip install torchsummary

2. Import

from torchsummary import summary

If this import succeeds, the installation worked.
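
If you want an explicit check (a tiny sketch, not from the original post), you can wrap the import in a try/except:

try:
    from torchsummary import summary  # noqa: F401
    print("torchsummary is installed")
except ImportError:
    print("torchsummary is missing; run `pip install torchsummary`")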

3. Parameters

torchsummary.summary(model, input_size, batch_size=-1, device="cuda")
  • model: the PyTorch model to summarize
  • input_size: the size of a single input sample, given as (C, H, W)
  • batch_size: defaults to -1; it only changes the batch dimension shown in each layer's output shape
  • device: "cuda" or "cpu"; the default is "cuda", so remember to change it if you are running on the CPU
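
Putting these parameters together, here is a minimal sketch on a small toy model (the toy network itself is an illustrative assumption, not from the original post):

import torch.nn as nn
from torchsummary import summary

toy = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),   # 3 x 32 x 32 -> 16 x 30 x 30
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 30 * 30, 10),
)

# input_size is (C, H, W); batch_size only changes the batch dimension shown
# in the table; device must match where the model actually lives ("cpu" here).
summary(toy, input_size=(3, 32, 32), batch_size=4, device="cpu")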

4. Usage Example

from torchsummary import summary
import torch
from torchvision.models import resnet18  # use resnet18 as the example network

input_shape = [512, 512]  # input height and width
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU if available
model = resnet18().to(device)   # instantiate the network and move it to the chosen device
summary(model, (3, input_shape[0], input_shape[1]), device=str(device))  # pass the device explicitly so this also runs on CPU-only machines

Output:


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 256, 256]           9,408
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
         MaxPool2d-4         [-1, 64, 128, 128]               0
            Conv2d-5         [-1, 64, 128, 128]          36,864
       BatchNorm2d-6         [-1, 64, 128, 128]             128
              ReLU-7         [-1, 64, 128, 128]               0
            Conv2d-8         [-1, 64, 128, 128]          36,864
       BatchNorm2d-9         [-1, 64, 128, 128]             128
             ReLU-10         [-1, 64, 128, 128]               0
       BasicBlock-11         [-1, 64, 128, 128]               0
           Conv2d-12         [-1, 64, 128, 128]          36,864
      BatchNorm2d-13         [-1, 64, 128, 128]             128
             ReLU-14         [-1, 64, 128, 128]               0
           Conv2d-15         [-1, 64, 128, 128]          36,864
      BatchNorm2d-16         [-1, 64, 128, 128]             128
             ReLU-17         [-1, 64, 128, 128]               0
       BasicBlock-18         [-1, 64, 128, 128]               0
           Conv2d-19          [-1, 128, 64, 64]          73,728
      BatchNorm2d-20          [-1, 128, 64, 64]             256
             ReLU-21          [-1, 128, 64, 64]               0
           Conv2d-22          [-1, 128, 64, 64]         147,456
      BatchNorm2d-23          [-1, 128, 64, 64]             256
           Conv2d-24          [-1, 128, 64, 64]           8,192
      BatchNorm2d-25          [-1, 128, 64, 64]             256
             ReLU-26          [-1, 128, 64, 64]               0
       BasicBlock-27          [-1, 128, 64, 64]               0
           Conv2d-28          [-1, 128, 64, 64]         147,456
      BatchNorm2d-29          [-1, 128, 64, 64]             256
             ReLU-30          [-1, 128, 64, 64]               0
           Conv2d-31          [-1, 128, 64, 64]         147,456
      BatchNorm2d-32          [-1, 128, 64, 64]             256
             ReLU-33          [-1, 128, 64, 64]               0
       BasicBlock-34          [-1, 128, 64, 64]               0
           Conv2d-35          [-1, 256, 32, 32]         294,912
      BatchNorm2d-36          [-1, 256, 32, 32]             512
             ReLU-37          [-1, 256, 32, 32]               0
           Conv2d-38          [-1, 256, 32, 32]         589,824
      BatchNorm2d-39          [-1, 256, 32, 32]             512
           Conv2d-40          [-1, 256, 32, 32]          32,768
      BatchNorm2d-41          [-1, 256, 32, 32]             512
             ReLU-42          [-1, 256, 32, 32]               0
       BasicBlock-43          [-1, 256, 32, 32]               0
           Conv2d-44          [-1, 256, 32, 32]         589,824
      BatchNorm2d-45          [-1, 256, 32, 32]             512
             ReLU-46          [-1, 256, 32, 32]               0
           Conv2d-47          [-1, 256, 32, 32]         589,824
      BatchNorm2d-48          [-1, 256, 32, 32]             512
             ReLU-49          [-1, 256, 32, 32]               0
       BasicBlock-50          [-1, 256, 32, 32]               0
           Conv2d-51          [-1, 512, 16, 16]       1,179,648
      BatchNorm2d-52          [-1, 512, 16, 16]           1,024
             ReLU-53          [-1, 512, 16, 16]               0
           Conv2d-54          [-1, 512, 16, 16]       2,359,296
      BatchNorm2d-55          [-1, 512, 16, 16]           1,024
           Conv2d-56          [-1, 512, 16, 16]         131,072
      BatchNorm2d-57          [-1, 512, 16, 16]           1,024
             ReLU-58          [-1, 512, 16, 16]               0
       BasicBlock-59          [-1, 512, 16, 16]               0
           Conv2d-60          [-1, 512, 16, 16]       2,359,296
      BatchNorm2d-61          [-1, 512, 16, 16]           1,024
             ReLU-62          [-1, 512, 16, 16]               0
           Conv2d-63          [-1, 512, 16, 16]       2,359,296
      BatchNorm2d-64          [-1, 512, 16, 16]           1,024
             ReLU-65          [-1, 512, 16, 16]               0
       BasicBlock-66          [-1, 512, 16, 16]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 328.01
Params size (MB): 44.59
Estimated Total Size (MB): 375.60
----------------------------------------------------------------

As shown above, torchsummary prints the network's layers in order, listing each layer's type, output shape, and parameter count, plus the total number of parameters, the size of the model, and an estimate of the memory required for one forward/backward pass.
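
If you only need the parameter counts, the "Total params" and "Trainable params" figures above can be cross-checked with plain PyTorch (a quick sketch, independent of torchsummary):

from torchvision.models import resnet18

model = resnet18()
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total:,}  Trainable params: {trainable:,}")
# Total params: 11,689,512  Trainable params: 11,689,512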

Reposted from blog.csdn.net/m0_70813473/article/details/130204898