用torchinfo库的sumary函数可以打印模型信息,示例如下。注意输入张量中不需要batch维度。
import numpy as np
import torch
import torchvision
import torchinfo
model = torchvision.models.resnet50(pretrained=True)
torchinfo.summary(model, (3, 224, 224), batch_dim=0,
col_names=('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose=1
)