pytorch模型可视化,torchviz和tensorboardX方式

torchviz方式:

1 from torchviz import make_dot
2 inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HIGTHT, WIDTH).requires_grad_(True)  #有.requires_grad_(True)显示输入形状
3 model = vgg()  #model是vgg类的实例
4 vis_graph = make_dot(model(inputs_fake), params=dict(list(model.named_parameters()) + [('x', inputs_fake)]))
5 vis_graph.view()

tensorboardX方式:

from tensorboardX import SummaryWriter
inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HIGTHT, WIDTH)
model = vgg() with SummaryWriter(comment
='vgg') as w: w.add_graph(model, (inputs_fake,))

torchviz生成一个pdf,pdf怎样命名还不知道,或许只能默认命名。

猜你喜欢

转载自www.cnblogs.com/zhangziyan/p/11909771.html
今日推荐