Netron 可视化Pytorh模型架构
前言
当训练别人的模型并加入DDP时,发现模型中有部分参数没有被使用而报错。而通过Print输出模型架构又太冗余,且不直观。因此在网上寻找一种可以可视化模型架构的工具,要求该工具可以实现对Pytorch模型的可视化,且该工具处于活跃状态(更新周期短),并且有大量用户使用(Star10k+)。
Netron1恰好符合上述需求,更新周期短,最近更新周期 3 Hours前,且issue有回复,star 20k+,完美符合。而Tensorwatch2虽然是微软开发的,但距离上次更新已经 3 years前,star 3k+,显然已经属于被淘汰的工具箱。其它工具箱如torchvis等也不符合上述需求。
Preparatory works
在使用Netron之前,需要有如下准备:
- 搭建好完整的模型架构
- 将搭建好的模型进行torch.save保存,存储为pth,pt格式。
Netron的安装
需要安装netron和onnx,因为netron目前仅支持如下格式:
- ONNX (.onnx, .pb)
- Keras (.h5, .keras)
-CoreML (.mlmodel) - TensorFlow Lite (.tflite)
pip install netron
pip install onnx
orconda install ***
Netron的使用
导出onnx格式的pth文件
import torch.onnx
from torch.autograd import Variable
from torchvision.models import resnet18 # 以 resnet18 为例
myNet = resnet18() # 实例化 resnet18
x = torch.randn(16, 3, 40, 40) # 随机生成一个输入
modelData = "demo.pth" # 定义模型数据保存的路径
torch.onnx.export(myNet, x, modelData) # 将 pytorch 模型以 onnx 格式导出并保存
前往NETRON GITHUB下载对应mac或windows对应版本的nerton软件,运行即可。
总结: 发现导出图后也没那么直观,还是用ipad手画结构推导更方便。