A simple method to count the parameters of a PyTorch model

nelement() function

Tensor.nelement() is an alias for Tensor.numel(), which is the method form of torch.numel(input).
All three return the same result: the total number of elements in the input tensor.
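A quick check that the three calls agree. The 2 x 3 x 4 shape is an arbitrary choice for illustration.

```python
import torch

# A 2 x 3 x 4 tensor holds 2 * 3 * 4 = 24 elements.
t = torch.zeros(2, 3, 4)

print(t.nelement())    # 24
print(t.numel())       # 24
print(torch.numel(t))  # 24
```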

Applying this function to every parameter tensor of a model and summing the results gives the model's total parameter count.

parameters() function

Module.parameters(recurse=True)
Returns an iterator over module parameters.
recurse (bool): if True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Each yielded item is a Parameter, i.e. a tensor that acts as a parameter of the module.
If recurse is True: returns the parameters of the current module and of all its submodules.
If recurse is False: returns only the parameters that are direct members of the current module, not those of its submodules.
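A small sketch of the difference between the two settings. The Block module below is a hypothetical example: it owns one parameter directly (scale) and one submodule (fc), so only scale is returned when recurse=False.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        # Direct member parameter of Block: 3 elements
        self.scale = nn.Parameter(torch.ones(3))
        # Submodule: Linear(3, 2) has a 2x3 weight and a 2-element bias = 8 elements
        self.fc = nn.Linear(3, 2)

block = Block()
all_params = list(block.parameters())               # recurse=True (default)
own_params = list(block.parameters(recurse=False))  # direct members only

print(len(all_params), sum(p.numel() for p in all_params))  # 3 11
print(len(own_params), sum(p.numel() for p in own_params))  # 1 3
```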

Counting a model's parameters

from torchvision.models import resnet50

model = resnet50()
# Sum the element count of every parameter tensor (recurse=True by default)
total = sum(param.nelement() for param in model.parameters())
print("total =", total)
Output:
total = 25557032
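The same one-liner can be wrapped in a reusable helper. The name count_parameters and its trainable_only flag are hypothetical, added here for illustration; the flag assumes "trainable" means requires_grad=True, which is how frozen layers are usually marked. The small Sequential model stands in for resnet50 so the numbers are easy to verify by hand.

```python
import torch.nn as nn

def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Hypothetical helper: total number of parameter elements in `model`.

    If trainable_only is True, parameters with requires_grad=False are skipped.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
# Freeze the first layer: its 10*5 + 5 = 55 elements are no longer trainable.
for p in model[0].parameters():
    p.requires_grad = False

print(count_parameters(model))                       # 61
print(count_parameters(model, trainable_only=True))  # 6
```

Called on the resnet50 model above with the default trainable_only=False, the helper sums exactly the same quantities as the one-liner, so it returns the same total.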


Source: blog.csdn.net/qq_55796594/article/details/128982501