Pytorch 网络结构可视化

安装

conda install graphviz
conda install tensorwatch


载入库

import sys
import torch
import tensorwatch as tw
import torchvision.models


网络结构可视化

alexnet_model = torchvision.models.alexnet()
tw.draw_model(alexnet_model, [1, 3, 224, 224])


载入alexnet,draw_model函数需要传入三个参数,第一个为model,第二个参数为input_shape,第三个参数为orientation,可以选择'LR'或者'TB',分别代表左右布局与上下布局。
在notebook中,执行完上面的代码会显示如下的图,将网络的结构及各个层的name和shape进行了可视化。

统计网络参数

通过model_stats方法统计各层的参数情况。
tw.model_stats(alexnet_model, [1, 3, 224, 224])
alexnet_model.features
alexnet_model.classifier
来源:https://zhuanlan.zhihu.com/p/66320870

猜你喜欢

转载自www.cnblogs.com/jeshy/p/11126077.html