Pytorchのネットワーク関連パラメータは[機能マップ]とモデルフロップスの計算と推論時間の計算を示しています

1.ネットワーク関連のパラメータ表示

(1)下の赤いボックスに示されているネットワーク構造の視覚化

(2)フィーチャマップの視覚化とパラメータ計算

torchsummaryパッケージをインストールする必要があります:

pip install torchsummary

例は次のとおりです。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = Net().to(device)

summary(model, (1, 28, 28))

結果は次のとおりです。 

#         每一层类型               特征图shape         每层的参数量
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 24, 24]             260
            Conv2d-2             [-1, 20, 8, 8]           5,020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]          16,050
            Linear-5                   [-1, 10]             510
================================================================
Total params: 21,840       # 模型整体的参数量(上面层参数量相加)
Trainable params: 21,840 
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00     # 图片预处理后的大小
Forward/backward pass size (MB): 0.06   # 正向/反向传播一次的内存大小
Params size (MB): 0.08
Estimated Total Size (MB): 0.15
--------------------------------------------------------------

参照リンク:

https://github.com/sksq96/pytorch-summary

次に、モデルのフロップ計算

(1)thopをインストールします

pip install thop

(2)基本的な使い方

from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))

参照:https//www.jianshu.com/p/6514b8fb1ada

          https://zhuanlan.zhihu.com/p/337810633

参照に焦点を当てる:https//blog.csdn.net/junmuzi/article/details/83109660?utm_medium = distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control&depth_1-utm_source = distribute.pc_relevant.none-task-ブログ-BlogCommendFromMachineLearnPai2-2.control

3.モデルの平均推論時間

   from torchvision.models import googlenet
   from thop import profile
   model = googlenet()
   input = torch.randn(1, 3, 224, 224)
   flops, params = profile(model, inputs=(input,))
   print("flops:",flops)
   print('params',params)

参照:https//www.jianshu.com/p/cbada26ea29d?from = groupmessage

おすすめ

転載: blog.csdn.net/MasterCayman/article/details/111478100