Pytorch中的网络相关参数展示【特征图】和模型的flops计算以及推理时间计算

一、网络相关参数展示

(1)网络结构可视化,如下图红框部分所示

(2)特征图可视化和参数量计算

需要安装 torchsummary包:

pip install torchsummary

实例如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = Net().to(device)

summary(model, (1, 28, 28))

结果如下: 

#         每一层类型               特征图shape         每层的参数量
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 24, 24]             260
            Conv2d-2             [-1, 20, 8, 8]           5,020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]          16,050
            Linear-5                   [-1, 10]             510
================================================================
Total params: 21,840       # 模型整体的参数量(上面层参数量相加)
Trainable params: 21,840 
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00     # 图片预处理后的大小
Forward/backward pass size (MB): 0.06   # 正向/反向传播一次的内存大小
Params size (MB): 0.08
Estimated Total Size (MB): 0.15
--------------------------------------------------------------

参考链接:

https://github.com/sksq96/pytorch-summary

二、模型的flops计算

(1)安装thop

pip install thop

(2)基本使用

from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))

扫描二维码关注公众号,回复: 12248076 查看本文章

参考:https://www.jianshu.com/p/6514b8fb1ada

          https://zhuanlan.zhihu.com/p/337810633

着重参考:https://blog.csdn.net/junmuzi/article/details/83109660?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control

三、模型平均推理时间

   from torchvision.models import googlenet
   from thop import profile
   model = googlenet()
   input = torch.randn(1, 3, 224, 224)
   flops, params = profile(model, inputs=(input,))
   print("flops:",flops)
   print('params',params)

参考:https://www.jianshu.com/p/cbada26ea29d?from=groupmessage

猜你喜欢

转载自blog.csdn.net/MasterCayman/article/details/111478100
今日推荐