Use fvcore para calcular la cantidad de parámetros y FLOP de un modelo en Pytorch

fvcoreEs una biblioteca central liviana de código abierto de Facebook, que proporciona funciones comunes y básicas en varios marcos de visión por computadora. Estos incluyen parámetros de modelos estadísticos y FLOP.

fvcore es una biblioteca central liviana que proporciona la funcionalidad más común y esencial compartida en varios marcos de visión por computadora

Dirección de código abierto del proyecto:
https://github.com/facebookresearch/fvcore

Instalar fvcore en entorno python

pip install fvcore

Ejemplo:
supongamos que necesito calcular la cantidad de parámetros para el siguiente resnet50 junto con el parámetro FLOP.

import torch
from torchvision.models import resnet50
from fvcore.nn import FlopCountAnalysis, parameter_count_table

# 创建resnet50网络
model = resnet50(num_classes=1000)

# 创建输入网络的tensor
tensor = (torch.rand(1, 3, 224, 224),)

# 分析FLOPs
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())

# 分析parameters
print(parameter_count_table(model))

Los resultados de salida del terminal son los siguientes, FLOP es 4089184256y la cantidad de parámetros del modelo es aproximadamente 25.6M(la cantidad de parámetros aquí es algo diferente de lo que calculé yo mismo, principalmente en el módulo BN, aquí solo se calculan dos parámetros de entrenamiento, no hay estadísticas y betados parámetros), puede ver el problema que mencioné en el sitio web oficial para más detalles . A través de la información impresa en el terminal, podemos encontrar que la capa BN no está incluida en el cálculo de FLOP, y la capa de agrupación también tiene una operación de suma común (descubrí que no hay una regulación uniforme al calcular FLOP, y el cálculo El proyecto FLOP visto en github básicamente cada uno es diferente, pero los resultados calculados son similares).gammamoving_meanmoving_var

Skipped operation aten::batch_norm 53 time(s)
Skipped operation aten::max_pool2d 1 time(s)
Skipped operation aten::add_ 16 time(s)
Skipped operation aten::adaptive_avg_pool2d 1 time(s)
FLOPs:  4089184256
| name                   | #elements or shape   |
|:-----------------------|:---------------------|
| model                  | 25.6M                |
|  conv1                 |  9.4K                |
|   conv1.weight         |   (64, 3, 7, 7)      |
|  bn1                   |  0.1K                |
|   bn1.weight           |   (64,)              |
|   bn1.bias             |   (64,)              |
|  layer1                |  0.2M                |
|   layer1.0             |   75.0K              |
|    layer1.0.conv1      |    4.1K              |
|    layer1.0.bn1        |    0.1K              |
|    layer1.0.conv2      |    36.9K             |
|    layer1.0.bn2        |    0.1K              |
|    layer1.0.conv3      |    16.4K             |
|    layer1.0.bn3        |    0.5K              |
|    layer1.0.downsample |    16.9K             |
|   layer1.1             |   70.4K              |
|    layer1.1.conv1      |    16.4K             |
|    layer1.1.bn1        |    0.1K              |
|    layer1.1.conv2      |    36.9K             |
|    layer1.1.bn2        |    0.1K              |
|    layer1.1.conv3      |    16.4K             |
|    layer1.1.bn3        |    0.5K              |
|   layer1.2             |   70.4K              |
|    layer1.2.conv1      |    16.4K             |
|    layer1.2.bn1        |    0.1K              |
|    layer1.2.conv2      |    36.9K             |
|    layer1.2.bn2        |    0.1K              |
|    layer1.2.conv3      |    16.4K             |
|    layer1.2.bn3        |    0.5K              |
|  layer2                |  1.2M                |
|   layer2.0             |   0.4M               |
|    layer2.0.conv1      |    32.8K             |
|    layer2.0.bn1        |    0.3K              |
|    layer2.0.conv2      |    0.1M              |
|    layer2.0.bn2        |    0.3K              |
|    layer2.0.conv3      |    65.5K             |
|    layer2.0.bn3        |    1.0K              |
|    layer2.0.downsample |    0.1M              |
|   layer2.1             |   0.3M               |
|    layer2.1.conv1      |    65.5K             |
|    layer2.1.bn1        |    0.3K              |
|    layer2.1.conv2      |    0.1M              |
|    layer2.1.bn2        |    0.3K              |
|    layer2.1.conv3      |    65.5K             |
|    layer2.1.bn3        |    1.0K              |
|   layer2.2             |   0.3M               |
|    layer2.2.conv1      |    65.5K             |
|    layer2.2.bn1        |    0.3K              |
|    layer2.2.conv2      |    0.1M              |
|    layer2.2.bn2        |    0.3K              |
|    layer2.2.conv3      |    65.5K             |
|    layer2.2.bn3        |    1.0K              |
|   layer2.3             |   0.3M               |
|    layer2.3.conv1      |    65.5K             |
|    layer2.3.bn1        |    0.3K              |
|    layer2.3.conv2      |    0.1M              |
|    layer2.3.bn2        |    0.3K              |
|    layer2.3.conv3      |    65.5K             |
|    layer2.3.bn3        |    1.0K              |
|  layer3                |  7.1M                |
|   layer3.0             |   1.5M               |
|    layer3.0.conv1      |    0.1M              |
|    layer3.0.bn1        |    0.5K              |
|    layer3.0.conv2      |    0.6M              |
|    layer3.0.bn2        |    0.5K              |
|    layer3.0.conv3      |    0.3M              |
|    layer3.0.bn3        |    2.0K              |
|    layer3.0.downsample |    0.5M              |
|   layer3.1             |   1.1M               |
|    layer3.1.conv1      |    0.3M              |
|    layer3.1.bn1        |    0.5K              |
|    layer3.1.conv2      |    0.6M              |
|    layer3.1.bn2        |    0.5K              |
|    layer3.1.conv3      |    0.3M              |
|    layer3.1.bn3        |    2.0K              |
|   layer3.2             |   1.1M               |
|    layer3.2.conv1      |    0.3M              |
|    layer3.2.bn1        |    0.5K              |
|    layer3.2.conv2      |    0.6M              |
|    layer3.2.bn2        |    0.5K              |
|    layer3.2.conv3      |    0.3M              |
|    layer3.2.bn3        |    2.0K              |
|   layer3.3             |   1.1M               |
|    layer3.3.conv1      |    0.3M              |
|    layer3.3.bn1        |    0.5K              |
|    layer3.3.conv2      |    0.6M              |
|    layer3.3.bn2        |    0.5K              |
|    layer3.3.conv3      |    0.3M              |
|    layer3.3.bn3        |    2.0K              |
|   layer3.4             |   1.1M               |
|    layer3.4.conv1      |    0.3M              |
|    layer3.4.bn1        |    0.5K              |
|    layer3.4.conv2      |    0.6M              |
|    layer3.4.bn2        |    0.5K              |
|    layer3.4.conv3      |    0.3M              |
|    layer3.4.bn3        |    2.0K              |
|   layer3.5             |   1.1M               |
|    layer3.5.conv1      |    0.3M              |
|    layer3.5.bn1        |    0.5K              |
|    layer3.5.conv2      |    0.6M              |
|    layer3.5.bn2        |    0.5K              |
|    layer3.5.conv3      |    0.3M              |
|    layer3.5.bn3        |    2.0K              |
|  layer4                |  15.0M               |
|   layer4.0             |   6.0M               |
|    layer4.0.conv1      |    0.5M              |
|    layer4.0.bn1        |    1.0K              |
|    layer4.0.conv2      |    2.4M              |
|    layer4.0.bn2        |    1.0K              |
|    layer4.0.conv3      |    1.0M              |
|    layer4.0.bn3        |    4.1K              |
|    layer4.0.downsample |    2.1M              |
|   layer4.1             |   4.5M               |
|    layer4.1.conv1      |    1.0M              |
|    layer4.1.bn1        |    1.0K              |
|    layer4.1.conv2      |    2.4M              |
|    layer4.1.bn2        |    1.0K              |
|    layer4.1.conv3      |    1.0M              |
|    layer4.1.bn3        |    4.1K              |
|   layer4.2             |   4.5M               |
|    layer4.2.conv1      |    1.0M              |
|    layer4.2.bn1        |    1.0K              |
|    layer4.2.conv2      |    2.4M              |
|    layer4.2.bn2        |    1.0K              |
|    layer4.2.conv3      |    1.0M              |
|    layer4.2.bn3        |    4.1K              |
|  fc                    |  2.0M                |
|   fc.weight            |   (1000, 2048)       |
|   fc.bias              |   (1000,)            |

Process finished with exit code 0

Para obtener más métodos de uso, puede ir al proyecto original para ver la documentación de uso.

Supongo que te gusta

Origin blog.csdn.net/qq_37541097/article/details/117471650
Recomendado
Clasificación