Los parámetros relacionados con la red en Pytorch muestran [mapa de características] y modelan el cálculo de fallos y el cálculo del tiempo de razonamiento

1. Visualización de parámetros relacionados con la red

(1) Visualización de la estructura de la red, como se muestra en el cuadro rojo a continuación

(2) Visualización de mapas de características y cálculo de parámetros

Necesita instalar el paquete torchsummary:

pip install torchsummary

Los ejemplos son los siguientes:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = Net().to(device)

summary(model, (1, 28, 28))

Los resultados son los siguientes: 

#         每一层类型               特征图shape         每层的参数量
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 24, 24]             260
            Conv2d-2             [-1, 20, 8, 8]           5,020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]          16,050
            Linear-5                   [-1, 10]             510
================================================================
Total params: 21,840       # 模型整体的参数量(上面层参数量相加)
Trainable params: 21,840 
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00     # 图片预处理后的大小
Forward/backward pass size (MB): 0.06   # 正向/反向传播一次的内存大小
Params size (MB): 0.08
Estimated Total Size (MB): 0.15
--------------------------------------------------------------

Link de referencia:

https://github.com/sksq96/pytorch-summary

En segundo lugar, el cálculo de los fracasos del modelo.

(1) Instalar thop

pip install thop

(2) Uso básico

from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))

Referencia: https://www.jianshu.com/p/6514b8fb1ada

          https://zhuanlan.zhihu.com/p/337810633

Centrarse en la referencia: https://blog.csdn.net/junmuzi/article/details/83109660?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control&depth_1-utm_source=distribute.nonec_relevant. blog -BlogCommendFromMachineLearnPai2-2.control

3. El tiempo medio de inferencia del modelo.

   from torchvision.models import googlenet
   from thop import profile
   model = googlenet()
   input = torch.randn(1, 3, 224, 224)
   flops, params = profile(model, inputs=(input,))
   print("flops:",flops)
   print('params',params)

Referencia: https://www.jianshu.com/p/cbada26ea29d?from=groupmessage

Supongo que te gusta

Origin blog.csdn.net/MasterCayman/article/details/111478100
Recomendado
Clasificación