pytorch计算Parameter和FLOP

版权声明:原创作品,欢迎转发!转发附上链接 https://blog.csdn.net/qq_26369907/article/details/89857021

深度学习中,模型训练完后,查看模型的参数量和浮点计算量,在此记录下:

1 THOP

在pytorch中有现成的包thop用于计算参数数量和FLOP,首先安装thop:

pip install thop

注意安装thop时可能出现如下错误:
在这里插入图片描述
解决方法:

pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git  # 下载源码安装

使用方法如下:

from torchvision.models import resnet50  # 引入ResNet50模型
from thop import profile

model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))  #  profile(模型,输入数据)

对于自己构建的函数也一样,例如shuffleNetV2

    from thop import profile
    from utils.ShuffleNetV2 import shufflenetv2  # 导入shufflenet2 模块
    import torch 
    
    model_shuffle = shufflenetv2(width_mult=0.5)
    model = torch.nn.DataParallel(model_shuffle)   # 调用shufflenet2 模型,该模型为自己定义的
    flop, para = profile(model, input_size=(1, 3, 224, 224),)  
    print("%.2fM" % (flop/1e6), "%.2fM" % (para/1e6))

更多细节,可参考thop GitHub链接: https://github.com/Lyken17/pytorch-OpCounter

2 计算参数

pytorch本身带有计算参数的方法

    from thop import profile
    from utils.ShuffleNetV2 import shufflenetv2  # 导入shufflenet2 模块
    import torch 
    
    model_shuffle = shufflenetv2(width_mult=0.5)
    model = torch.nn.DataParallel(model_shuffle)
    total = sum([param.nelement() for param in model.parameters()])
    print("Number of parameter: %.2fM" % (total / 1e6))

猜你喜欢

转载自blog.csdn.net/qq_26369907/article/details/89857021
今日推荐