Methods for calculating the parameters of a convolutional neural network (CNN) model (empirical version)

Review the old and learn the new, and you can become a teacher!

1. Reference materials

How to use torchsummary and torchstat and analyze their results

2. Related background

1. Parameter count and computation cost

Reference blog: An easy-to-understand look at the parameter count and computation cost of CNN models
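The referenced blog walks through the derivations; as a quick reminder of the standard formulas (stated here for convenience, not quoted from that blog), a convolutional layer with kernel size $D_k$, $C_i$ input channels, $C_o$ output channels and an $H \times W$ output feature map has

$\text{params} = D_k \cdot D_k \cdot C_i \cdot C_o + C_o$ parameters and performs roughly $D_k \cdot D_k \cdot C_i \cdot C_o \cdot H \cdot W$ multiply-adds (MACs).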

3. Common methods for calculating model parameters

Recommended method: torchstat

1. torchstat

Purpose: view the model's size and floating-point operation count.

# Install torchstat
pip install torchstat

import torch
from torchvision.models import vgg16
from torchstat import stat


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
net = vgg16()

# stat() takes the model and the input shape (C, H, W)
stat(net, (3, 224, 224))

Result

[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
        module name  input shape output shape       params memory(MB)              MAdd             Flops   MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        features.0    3 224 224   64 224 224       1792.0      12.25     173,408,256.0      89,915,392.0     609280.0   12845056.0       2.69%   13454336.0
1        features.1   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.22%   25690112.0
2        features.2   64 224 224   64 224 224      36928.0      12.25   3,699,376,128.0   1,852,899,328.0   12992768.0   12845056.0       9.99%   25837824.0
3        features.3   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.18%   25690112.0
4        features.4   64 224 224   64 112 112          0.0       3.06       2,408,448.0       3,211,264.0   12845056.0    3211264.0       2.51%   16056320.0
5        features.5   64 112 112  128 112 112      73856.0       6.12   1,849,688,064.0     926,449,664.0    3506688.0    6422528.0       4.81%    9929216.0
6        features.6  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.11%   12845056.0
7        features.7  128 112 112  128 112 112     147584.0       6.12   3,699,376,128.0   1,851,293,696.0    7012864.0    6422528.0       8.55%   13435392.0
8        features.8  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.08%   12845056.0
9        features.9  128 112 112  128  56  56          0.0       1.53       1,204,224.0       1,605,632.0    6422528.0    1605632.0       0.90%    8028160.0
10      features.10  128  56  56  256  56  56     295168.0       3.06   1,849,688,064.0     925,646,848.0    2786304.0    3211264.0       4.41%    5997568.0
11      features.11  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.05%    6422528.0
12      features.12  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       8.55%    8782848.0
13      features.13  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.05%    6422528.0
14      features.14  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       8.41%    8782848.0
15      features.15  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.04%    6422528.0
16      features.16  256  56  56  256  28  28          0.0       0.77         602,112.0         802,816.0    3211264.0     802816.0       0.47%    4014080.0
17      features.17  256  28  28  512  28  28    1180160.0       1.53   1,849,688,064.0     925,245,440.0    5523456.0    1605632.0       4.56%    7129088.0
18      features.18  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.02%    3211264.0
19      features.19  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       8.84%   12650496.0
20      features.20  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.02%    3211264.0
21      features.21  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       8.48%   12650496.0
22      features.22  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.02%    3211264.0
23      features.23  512  28  28  512  14  14          0.0       0.38         301,056.0         401,408.0    1605632.0     401408.0       0.25%    2007040.0
24      features.24  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       2.77%   10242048.0
25      features.25  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.01%     802816.0
26      features.26  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       2.39%   10242048.0
27      features.27  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.01%     802816.0
28      features.28  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       2.38%   10242048.0
29      features.29  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.01%     802816.0
30      features.30  512  14  14  512   7   7          0.0       0.10          75,264.0         100,352.0     401408.0     100352.0       0.08%     501760.0
31          avgpool  512   7   7  512   7   7          0.0       0.10               0.0               0.0          0.0          0.0       0.08%          0.0
32     classifier.0        25088         4096  102764544.0       0.02     205,516,800.0     102,760,448.0  411158528.0      16384.0      15.48%  411174912.0
33     classifier.1         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.04%      32768.0
34     classifier.2         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.04%          0.0
35     classifier.3         4096         4096   16781312.0       0.02      33,550,336.0      16,777,216.0   67141632.0      16384.0       1.96%   67158016.0
36     classifier.4         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.01%      32768.0
37     classifier.5         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.01%          0.0
38     classifier.6         4096         1000    4097000.0       0.00       8,191,000.0       4,096,000.0   16404384.0       4000.0       0.49%   16408384.0
total                                          138357544.0     109.39  30,958,666,264.0  15,503,489,024.0   16404384.0       4000.0     100.00%  783170624.0
============================================================================================================================================================
Total params: 138,357,544
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 109.39MB
Total MAdd: 30.96GMAdd
Total Flops: 15.5GFlops
Total MemR+W: 746.89MB

Description of the fields

  • params: the number of parameters in the network.
  • memory: the space occupied by the intermediate activations the model produces during computation.
  • Flops: the floating-point operations performed by the network.
  • MAdd: the number of multiply-add operations performed by the network. One multiply-add = one multiplication + one addition, so roughly MAdd ≈ 2 × Flops (the totals above confirm this: 30.96 GMAdd ≈ 2 × 15.5 GFlops).
  • MemRead: the amount of memory read while the network runs.
  • MemWrite: the amount of memory written while the network runs.
  • MemR+W: MemR+W = MemRead + MemWrite.

For example:

float32 is a 32-bit single-precision floating-point type and occupies 4 bytes; $1\,\mathrm{MB} = 2^{10}\,\mathrm{KB} = 2^{20}\,\mathrm{B}$.

        module name  input shape output shape       params memory(MB)              MAdd             Flops   MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        features.0    3 224 224   64 224 224       1792.0      12.25     173,408,256.0      89,915,392.0     609280.0   12845056.0       2.69%   13454336.0

params (parameter count): $D_k \cdot D_k \cdot C_i \cdot C_o + C_o = 3 \cdot 3 \cdot 3 \cdot 64 + 64 = 1{,}792$;
memory (output activations): $224 \cdot 224 \cdot 64 \cdot 4\,\mathrm{B} = 12.25\,\mathrm{MB}$.
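A minimal hand calculation (a sketch written for this post, not part of torchstat) that reproduces the features.0 row above, assuming a 3×3 kernel, stride 1, padding 1 and float32 activations:

# Reproduce the torchstat row for features.0 by hand
# (assumptions: 3x3 kernel, stride 1, padding 1, float32 -> 4 bytes per value)
k, c_in, c_out, h, w = 3, 3, 64, 224, 224      # kernel size, channels, output H/W

params = k * k * c_in * c_out + c_out          # 1,792 (weights + biases)
memory_mb = c_out * h * w * 4 / 2**20          # 12.25 MB of output activations
madd = 2 * k * k * c_in * c_out * h * w        # 173,408,256 multiplies + adds
flops = (k * k * c_in + 1) * c_out * h * w     # 89,915,392 (~madd/2 plus one bias add per output element)

print(params, round(memory_mb, 2), madd, flops)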

2. torchsummary

Purpose: view the model structure and the input/output dimensions of each layer.

# Install torchsummary
pip install torchsummary

torchsummary.summary(model, input_size, batch_size=-1, device="cuda")

Parameter explanation

  • model: the PyTorch model; it must inherit from nn.Module.
  • input_size: the model's input size, with shape (C, H, W).
  • batch_size: defaults to -1; the batch size displayed when showing the output shape of each layer.
  • device: "cuda" or "cpu"; the default is device="cuda".
import torch
from torchvision.models import vgg16
from torchsummary import summary


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
net = vgg16()

net.to(DEVICE)
# summary() prints its table itself and returns None, hence the trailing "None" below
print(summary(net, input_size=(3, 224, 224), device=DEVICE))

Result

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-32            [-1, 512, 7, 7]               0
           Linear-33                 [-1, 4096]     102,764,544
             ReLU-34                 [-1, 4096]               0
          Dropout-35                 [-1, 4096]               0
           Linear-36                 [-1, 4096]      16,781,312
             ReLU-37                 [-1, 4096]               0
          Dropout-38                 [-1, 4096]               0
           Linear-39                 [-1, 1000]       4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.78
Params size (MB): 527.79
Estimated Total Size (MB): 747.15
----------------------------------------------------------------
None
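The size figures at the bottom follow directly from the float32 assumption (4 bytes per value, 1 MB = 2^20 B); a quick check, written for this post rather than produced by torchsummary:

# Sanity-check torchsummary's size estimates (float32 = 4 bytes, 1 MB = 2**20 B)
total_params = 138_357_544
print(total_params * 4 / 2**20)      # ~527.79  -> "Params size (MB)"
print(3 * 224 * 224 * 4 / 2**20)     # ~0.57    -> "Input size (MB)"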

3. profile (thop)

# Install thop
pip install thop

import torch
from torchvision.models import vgg16
from thop import profile
from thop import clever_format


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
net = vgg16()

x = torch.rand(1, 3, 224, 224)

# profile() runs a forward pass and returns the operation count and parameter count
flops, params = profile(net, inputs=(x, ))
# print(flops, params)

# clever_format() turns the raw numbers into human-readable strings
macs, params = clever_format([flops, params], "%.3f")
print(macs, params)

Result

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
15.484G 138.358M
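Note that the first value returned by profile() is a multiply-accumulate (MAC) count, which is why 15.484G matches torchstat's 15.5 GFlops and ptflops' 15.52 GMac in this post; if you want FLOPs with multiplications and additions counted separately, a common convention is to double it. A small follow-up to the snippet above (not part of thop):

# Continuing from the snippet above: double the MAC count for an approximate FLOP count
print(f"{flops / 1e9:.3f} GMac  ~  {2 * flops / 1e9:.3f} GFLOPs")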

4. ptflops

# Install ptflops
pip install ptflops

import torch
from torchvision.models import vgg16
from ptflops import get_model_complexity_info

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
net = vgg16()

flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print("%s %s" % (flops, params))

Result

VGG(
  138.36 M, 100.000% Params, 15.5 GMac, 99.873% MACs, 
  (features): Sequential(
    14.71 M, 10.635% Params, 15.38 GMac, 99.076% MACs, 
    (0): Conv2d(1.79 k, 0.001% Params, 89.92 MMac, 0.579% MACs, 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(0, 0.000% Params, 3.21 MMac, 0.021% MACs, inplace=True)
    (2): Conv2d(36.93 k, 0.027% Params, 1.85 GMac, 11.936% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(0, 0.000% Params, 3.21 MMac, 0.021% MACs, inplace=True)
    (4): MaxPool2d(0, 0.000% Params, 3.21 MMac, 0.021% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(73.86 k, 0.053% Params, 926.45 MMac, 5.968% MACs, 64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(0, 0.000% Params, 1.61 MMac, 0.010% MACs, inplace=True)
    (7): Conv2d(147.58 k, 0.107% Params, 1.85 GMac, 11.926% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(0, 0.000% Params, 1.61 MMac, 0.010% MACs, inplace=True)
    (9): MaxPool2d(0, 0.000% Params, 1.61 MMac, 0.010% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(295.17 k, 0.213% Params, 925.65 MMac, 5.963% MACs, 128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(0, 0.000% Params, 802.82 KMac, 0.005% MACs, inplace=True)
    (12): Conv2d(590.08 k, 0.426% Params, 1.85 GMac, 11.921% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(0, 0.000% Params, 802.82 KMac, 0.005% MACs, inplace=True)
    (14): Conv2d(590.08 k, 0.426% Params, 1.85 GMac, 11.921% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(0, 0.000% Params, 802.82 KMac, 0.005% MACs, inplace=True)
    (16): MaxPool2d(0, 0.000% Params, 802.82 KMac, 0.005% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(1.18 M, 0.853% Params, 925.25 MMac, 5.960% MACs, 256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(0, 0.000% Params, 401.41 KMac, 0.003% MACs, inplace=True)
    (19): Conv2d(2.36 M, 1.706% Params, 1.85 GMac, 11.918% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(0, 0.000% Params, 401.41 KMac, 0.003% MACs, inplace=True)
    (21): Conv2d(2.36 M, 1.706% Params, 1.85 GMac, 11.918% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(0, 0.000% Params, 401.41 KMac, 0.003% MACs, inplace=True)
    (23): MaxPool2d(0, 0.000% Params, 401.41 KMac, 0.003% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(2.36 M, 1.706% Params, 462.52 MMac, 2.980% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(0, 0.000% Params, 100.35 KMac, 0.001% MACs, inplace=True)
    (26): Conv2d(2.36 M, 1.706% Params, 462.52 MMac, 2.980% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(0, 0.000% Params, 100.35 KMac, 0.001% MACs, inplace=True)
    (28): Conv2d(2.36 M, 1.706% Params, 462.52 MMac, 2.980% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(0, 0.000% Params, 100.35 KMac, 0.001% MACs, inplace=True)
    (30): MaxPool2d(0, 0.000% Params, 100.35 KMac, 0.001% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(0, 0.000% Params, 25.09 KMac, 0.000% MACs, output_size=(7, 7))
  (classifier): Sequential(
    123.64 M, 89.365% Params, 123.65 MMac, 0.797% MACs, 
    (0): Linear(102.76 M, 74.275% Params, 102.76 MMac, 0.662% MACs, in_features=25088, out_features=4096, bias=True)
    (1): ReLU(0, 0.000% Params, 4.1 KMac, 0.000% MACs, inplace=True)
    (2): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.5, inplace=False)
    (3): Linear(16.78 M, 12.129% Params, 16.78 MMac, 0.108% MACs, in_features=4096, out_features=4096, bias=True)
    (4): ReLU(0, 0.000% Params, 4.1 KMac, 0.000% MACs, inplace=True)
    (5): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.5, inplace=False)
    (6): Linear(4.1 M, 2.961% Params, 4.1 MMac, 0.026% MACs, in_features=4096, out_features=1000, bias=True)
  )
)
15.52 GMac 138.36 M
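If you need the numbers programmatically rather than as formatted strings, get_model_complexity_info also accepts as_strings=False; a small variation on the call above:

# Same call, but returning plain numbers instead of formatted strings
macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         as_strings=False,
                                         print_per_layer_stat=False)
print(macs, params)   # raw MAC and parameter counts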

5. Custom function (example)

Get the model's parameter count in two simple steps, plus a derivation of how the input size changes as it passes through the convolutions.

import torch
from torchvision.models import vgg16


def print_networks(model, verbose):
    """Print the total number of parameters in the network and (if verbose) the network architecture

    Parameters:
        model (torch.nn.Module): the PyTorch model to inspect
        verbose (bool): whether to print the model architecture
    """
    # Count all trainable parameters
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters in the model: {num_params}, {num_params / 1e6:.3f} M")

    # Optionally print the model architecture
    if verbose:
        print(model)


net = vgg16()
print_networks(net, verbose=True)
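print_networks() only counts parameters with requires_grad=True; for an unmodified vgg16 this equals the total, but the following one-liners (plain PyTorch, added here as a usage note) show both counts:

# Trainable vs. total parameter count (identical for an unmodified vgg16)
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
total = sum(p.numel() for p in net.parameters())
print(f"trainable: {trainable:,}  total: {total:,}")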

6. Custom function (example)

PyTorch: computing model parameters
[pytorch] Basic usage | 5. Model parameters / weights |

Purpose: use torch's built-in methods to write a custom function.

'''Method 1: custom function, adapted from https://blog.csdn.net/qq_33757398/article/details/109210240'''
def model_structure(model):
    blank = ' '
    print('-' * 90)
    print('|' + ' ' * 11 + 'weight name' + ' ' * 10 + '|' \
          + ' ' * 15 + 'weight shape' + ' ' * 15 + '|' \
          + ' ' * 3 + 'number' + ' ' * 3 + '|')
    print('-' * 90)
    num_para = 0
    type_size = 1  # use 4 for float32 parameters (4 bytes each)

    for index, (key, w_variable) in enumerate(model.named_parameters()):
        # Pad the parameter name and shape so the columns line up
        if len(key) <= 30:
            key = key + (30 - len(key)) * blank
        shape = str(w_variable.shape)
        if len(shape) <= 40:
            shape = shape + (40 - len(shape)) * blank
        # Number of elements in this parameter tensor
        each_para = 1
        for k in w_variable.shape:
            each_para *= k
        num_para += each_para
        str_num = str(each_para)
        if len(str_num) <= 10:
            str_num = str_num + (10 - len(str_num)) * blank

        print('| {} | {} | {} |'.format(key, shape, str_num))
    print('-' * 90)
    print('The total number of parameters: ' + str(num_para))
    print('The parameters of Model {}: {:4f}M'.format(model._get_name(), num_para * type_size / 1000 / 1000))
    print('-' * 90)

model_structure(net)

Result

------------------------------------------------------------------------------------------
|           weight name          |               weight shape               |   number   |
------------------------------------------------------------------------------------------
| features.0.weight              | torch.Size([64, 3, 3, 3])                | 1728       |
| features.0.bias                | torch.Size([64])                         | 64         |
| features.2.weight              | torch.Size([64, 64, 3, 3])               | 36864      |
| features.2.bias                | torch.Size([64])                         | 64         |
| features.5.weight              | torch.Size([128, 64, 3, 3])              | 73728      |
| features.5.bias                | torch.Size([128])                        | 128        |
| features.7.weight              | torch.Size([128, 128, 3, 3])             | 147456     |
| features.7.bias                | torch.Size([128])                        | 128        |
| features.10.weight             | torch.Size([256, 128, 3, 3])             | 294912     |
| features.10.bias               | torch.Size([256])                        | 256        |
| features.12.weight             | torch.Size([256, 256, 3, 3])             | 589824     |
| features.12.bias               | torch.Size([256])                        | 256        |
| features.14.weight             | torch.Size([256, 256, 3, 3])             | 589824     |
| features.14.bias               | torch.Size([256])                        | 256        |
| features.17.weight             | torch.Size([512, 256, 3, 3])             | 1179648    |
| features.17.bias               | torch.Size([512])                        | 512        |
| features.19.weight             | torch.Size([512, 512, 3, 3])             | 2359296    |
| features.19.bias               | torch.Size([512])                        | 512        |
| features.21.weight             | torch.Size([512, 512, 3, 3])             | 2359296    |
| features.21.bias               | torch.Size([512])                        | 512        |
| features.24.weight             | torch.Size([512, 512, 3, 3])             | 2359296    |
| features.24.bias               | torch.Size([512])                        | 512        |
| features.26.weight             | torch.Size([512, 512, 3, 3])             | 2359296    |
| features.26.bias               | torch.Size([512])                        | 512        |
| features.28.weight             | torch.Size([512, 512, 3, 3])             | 2359296    |
| features.28.bias               | torch.Size([512])                        | 512        |
| classifier.0.weight            | torch.Size([4096, 25088])                | 102760448  |
| classifier.0.bias              | torch.Size([4096])                       | 4096       |
| classifier.3.weight            | torch.Size([4096, 4096])                 | 16777216   |
| classifier.3.bias              | torch.Size([4096])                       | 4096       |
| classifier.6.weight            | torch.Size([1000, 4096])                 | 4096000    |
| classifier.6.bias              | torch.Size([1000])                       | 1000       |
------------------------------------------------------------------------------------------
The total number of parameters: 138357544
The parameters of Model VGG: 138.357544M
------------------------------------------------------------------------------------------
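Because type_size is left at 1, the last line reports the raw parameter count in millions; with type_size = 4 (float32) the same formula gives an approximate storage size instead (note it divides by 10^6, whereas torchsummary's 527.79 MB above uses 2^20 bytes per MB):

# Storage size with 4 bytes per float32 parameter, using decimal MB as in model_structure()
print(138357544 * 4 / 1000 / 1000)   # ~553.43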

Origin: blog.csdn.net/m0_37605642/article/details/134128004