pytorch打印模型信息——torchinfo

 用torchinfo库的sumary函数可以打印模型信息,示例如下。注意输入张量中不需要batch维度。

import numpy as np

import torch
import torchvision
import torchinfo

model = torchvision.models.resnet50(pretrained=True)
torchinfo.summary(model, (3, 224, 224), batch_dim=0,
                  col_names=('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose=1
        )

猜你喜欢

转载自blog.csdn.net/wxyczhyza/article/details/128118418