1. 获取模型的 FLOPs 和 parameters(参数量)
- 方法1:get_model_complexity_info(ptflops 库)
安装基本库(以下两条命令二选一:从 PyPI 安装,或从 GitHub 源码安装最新版):
pip install ptflops
pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
# Measure VGG-16 with ptflops while CUDA device 0 is selected.
# The statements below form the body of the `with` block and must be indented.
with torch.cuda.device(0):
    net = models.vgg16()
    # (3, 224, 224) is the input shape handed to ptflops; no batch dimension is given here.
    # as_strings=True asks for the results as formatted strings,
    # print_per_layer_stat=True additionally prints a per-layer breakdown.
    macs, params = get_model_complexity_info(
        net,
        (3, 224, 224),
        as_strings=True,
        print_per_layer_stat=True,
        verbose=True,
    )
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))
- 方法2:profile(thop 库)
import torch
from thop import profile
from torchvision.models import resnet50
# Profile a ResNet-50 with thop on a single random 3x224x224 image.
model = resnet50()
# Named dummy_input rather than `input` so the builtin input() is not shadowed.
dummy_input = torch.randn(1, 3, 224, 224)
# profile() receives the forward-pass arguments as a tuple.
# NOTE(review): thop's README names the first return value MACs - confirm
# before reporting this number as FLOPs.
flops, params = profile(model, inputs=(dummy_input,))
# Scale to units of 10^9 operations and 10^6 parameters respectively.
print('flops:', flops / (1000 ** 3))
print('params:', params / (1000 ** 2))
- 方法3:FlopCountAnalysis, parameter_count_table(fvcore 库)
import torch
from torchvision.models import resnet50
from fvcore.nn import FlopCountAnalysis, parameter_count_table
# Instantiate ResNet-50 with a 1000-class output layer.
model = resnet50(num_classes=1000)
# The model inputs are passed as a tuple; here it holds one random
# tensor of shape (1, 3, 224, 224).
sample_inputs = (torch.rand(1, 3, 224, 224),)
# Run the FLOP analysis and report the total divided by 1000 ** 3.
flop_analysis = FlopCountAnalysis(model, sample_inputs)
total_gflops = flop_analysis.total() / (1000 ** 3)
print("FLOPs: ", total_gflops)
import torch
from torchvision.models import resnet50
from fvcore.nn import FlopCountAnalysis, parameter_count_table
# Instantiate ResNet-50 with a 1000-class output layer.
model = resnet50(num_classes=1000)
# Build the parameter-count table for the model, then print it.
param_table = parameter_count_table(model)
print(param_table)
- 说明:
- get_model_complexity_info 默认只接收输入尺寸元组(如 (3, 224, 224))并自行构造输入;如需自定义输入,需通过 input_constructor 参数传入一个构造输入的函数。其他两个方法可以直接传入自定义的输入张量。