Common Deep Learning Code (PyTorch)

1. Getting a model's FLOPs and parameter count

  • Method 1: get_model_complexity_info (ptflops)

Install the library from PyPI:
pip install ptflops
or install the latest version from GitHub:
pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git

import torchvision.models as models
from ptflops import get_model_complexity_info

# The model stays on the CPU; no CUDA device is required
net = models.vgg16()

# input_res is (channels, height, width), without the batch dimension.
# ptflops reports MACs (multiply-accumulate operations); FLOPs ≈ 2 × MACs.
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))
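
ptflops prints a MAC count for every layer. As a rough illustration of where such per-layer numbers come from, the following pure-Python sketch applies the standard formulas for a 2-D convolution with a square kernel. The helper names conv2d_out_size and conv2d_cost are hypothetical, and the MAC count deliberately leaves out the bias term, which individual tools may or may not add.

```python
def conv2d_out_size(size: int, kernel: int, stride: int = 1,
                    padding: int = 0, dilation: int = 1) -> int:
    """Output height/width of a convolution (same formula as nn.Conv2d)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_cost(c_in: int, c_out: int, h_in: int, w_in: int, kernel: int,
                stride: int = 1, padding: int = 0, groups: int = 1,
                bias: bool = True) -> tuple:
    """Return (MACs, parameters) of one Conv2d layer for a batch of one.

    Hypothetical helper: the MAC count excludes the bias additions.
    """
    h_out = conv2d_out_size(h_in, kernel, stride, padding)
    w_out = conv2d_out_size(w_in, kernel, stride, padding)
    # Each output element needs (c_in / groups) * k * k multiply-accumulates
    macs = h_out * w_out * c_out * (c_in // groups) * kernel * kernel
    params = c_out * (c_in // groups) * kernel * kernel + (c_out if bias else 0)
    return macs, params


# VGG16's first layer: Conv2d(3, 64, kernel_size=3, padding=1) on a 224x224 input
macs, params = conv2d_cost(3, 64, 224, 224, kernel=3, padding=1)
print(macs, params)          # 86704128 1792
print("FLOPs ≈", 2 * macs)   # counting multiplications and additions separately
```

Summing such per-layer values over the whole network is, in spirit, what the counting tools do; a tool that also counts the bias adds H_out × W_out × C_out operations per convolution.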

  • Method 2: profile (thop)

import torch
from thop import profile
from torchvision.models import resnet50

model = resnet50()
dummy_input = torch.randn(1, 3, 224, 224)  # a batch of one 3×224×224 image

# thop's count is in MACs (multiply-accumulates); FLOPs ≈ 2 × MACs
macs, params = profile(model, inputs=(dummy_input, ))
print('GMACs:', macs / (1000 ** 3))
print('params (M):', params / (1000 ** 2))
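
thop (like ptflops) works by attaching hooks to modules and accumulating a cost per layer during one forward pass. The sketch below is a simplified, hypothetical re-creation of that idea using only PyTorch's register_forward_hook; it is not thop's actual implementation, it handles only nn.Conv2d and nn.Linear, it ignores biases and activations, and the helper name count_macs and the toy model are made up for illustration.

```python
import torch
import torch.nn as nn


def count_macs(model: nn.Module, example_input: torch.Tensor) -> int:
    """Count multiply-accumulates of Conv2d and Linear layers in one forward pass."""
    total = 0

    def conv_hook(module: nn.Conv2d, inputs, output: torch.Tensor) -> None:
        nonlocal total
        k_h, k_w = module.kernel_size
        # Each output element needs (C_in / groups) * K_h * K_w multiply-accumulates
        total += output.numel() * (module.in_channels // module.groups) * k_h * k_w

    def linear_hook(module: nn.Linear, inputs, output: torch.Tensor) -> None:
        nonlocal total
        # Each output element is a dot product of length in_features
        total += output.numel() * module.in_features

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    with torch.no_grad():
        model(example_input)

    # Remove the hooks so the model is left unchanged
    for h in handles:
        h.remove()
    return total


# Hypothetical toy model: one 3x3 convolution followed by a linear classifier
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
x = torch.randn(1, 3, 32, 32)
print(count_macs(model, x))  # 221184 (conv) + 81920 (linear) = 303104
```

Because output.numel() includes the batch dimension, this count scales with the batch size, which is why the inputs in these examples use a batch of one.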

  • Method 3: FlopCountAnalysis, parameter_count_table (fvcore)
import torch
from torchvision.models import resnet50
from fvcore.nn import FlopCountAnalysis, parameter_count_table

# Create a ResNet-50 network
model = resnet50(num_classes=1000)

# Create the input to the network, passed as a tuple of tensors
tensor = (torch.rand(1, 3, 224, 224),)

# Analyze FLOPs (fvcore counts one fused multiply-add as one flop)
flops = FlopCountAnalysis(model, tensor)
print("FLOPs (G): ", flops.total() / (1000 ** 3))

# Analyze parameters: print a per-module parameter count table
print(parameter_count_table(model))
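
To cross-check the total printed by parameter_count_table (or the params value returned by the other two methods), parameters can also be counted directly in PyTorch. A minimal sketch, assuming a small toy model in place of ResNet-50; count_parameters is a hypothetical helper name.

```python
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Total number of parameter elements; optionally only those with requires_grad."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)


# Hypothetical toy model: 10*20 + 20 = 220 and 20*5 + 5 = 105 parameters
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
print(count_parameters(model))                       # 325

# Freeze the first layer, e.g. when fine-tuning only the head
for p in model[0].parameters():
    p.requires_grad = False
print(count_parameters(model, trainable_only=True))  # 105
```

model.parameters() yields each shared parameter only once, so tied weights are not double-counted.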
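  • Notes:
    • get_model_complexity_info takes only an input shape; feeding it a custom input requires its input_constructor argument. The other two methods take the actual input tensor(s), so custom inputs can be passed directly.
    • ptflops and thop report MACs, and fvcore counts one fused multiply-add as one flop, so their numbers are on the same scale; counting multiplications and additions separately gives roughly twice the value.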

Reposted from blog.csdn.net/qq_35759272/article/details/127728732