pytorch借助tensorboard实现模型可视化

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_19332527/article/details/80791160

补充 : 刚发现貌似sqrt操作是不支持的

python库:

pytorch(>=0.3) , onnx, tensorboardX

原理:

Open Neural Network Exchange (ONNX)是开放生态系统的第一步,它使人工智能开发人员可以在项目的发展过程中选择合适的工具;ONNX为AI models提供了一种开源格式。它定义了一个可以扩展的计算图模型,同时也定义了内置操作符和标准数据类型。最初我们关注的是推理(评估)所需的能力。

Caffe2, PyTorch, Microsoft Cognitive Toolkit, Apache MXNet 和其他工具都在对ONNX进行支持。在不同的框架之间实现互操作性,并简化从研究到产品化的过程,将提高人工智能社区的创新速度。

简单来说就是借助onnx将pytorch的模型存为model.proto的文件,然后借助于tensorboardX这个工具将model.proto转换为tensorboar的graph.

代码:

#对于pytorch0.3以上的版本
import tensorboardX 
import torch
from torchvision.models import resnet34
import torch.onnx

x=torch.autograd.Variable(torch.rand(1,3,224,224)) #随便定义一个输入
model=resnet34()

proto=torch.onnx.export(model,x,"resnet34.proto",verbose=True) #将model的结构和参数全部保存为 resnet32.proto

writer=tensorboardX.SummaryWriter("./logs/")  #定义一个tensorboardX的写对象 
writer.add_graph_onnx("./resnet34.proto")  #将proto格式的文件转换为tensorboard中的graph

对于pytorch 0.2来说可以直接来画:

import tensorboardX 
import torch
from torchvision.models import resnet34
import torch.onnx

x=torch.autograd.Variable(torch.rand(1,3,224,224)) #随便定义一个输入
model=resnet34()
writer=tensorboardX.SummaryWriter("./logs/")  #定义一个tensorboardX的写对象 
writer.add_graph(model,x,verbose=True)  #将proto格式的文件转换为tensorboard中的graph

效果如下 ,确实有点丑,不如tensorflow那样五颜六色,也没有更加详细的操作:

拉近的图片:

补充,刚才有人说好像max_pool2d是不支持的,我自己的测试时可以的,建议检查一下tensorboardX和ONNX的版本,代码如下: 我的tensorboardX版本是1.4的,onnx版本是1.3.0

import torch
import  torch.nn.functional as F
import torch.onnx
import tensorboardX

class ResNet(nn.Module):

	def __init__(self, block, layers, num_classes=1000):
		super(ResNet, self).__init__()
	def forward(self, x):
		#这儿就是我加的操作
		x=F.max_pool2d(x,kernel_size=7)
		return x

def resnet50():
	"""Constructs a ResNet-50 model.

	Args:
		pretrained (bool): If True, returns a model pre-trained on ImageNet
	"""
	model = ResNet(Bottleneck, [3, 4, 6, 3])
	return model

if __name__=="__main__":
	x=torch.autograd.Variable(torch.rand(1,3,224,224)) #随便定义一个输入
	model=resnet50()
	 
	proto=torch.onnx.export(model,x,"resnet50.proto",verbose=True) #将model的结构和参数全部保存为 resnet32.proto
	 
	writer=tensorboardX.SummaryWriter("./logs/")  #定义一个tensorboardX的写对象 
	writer.add_graph_onnx("./resnet50.proto")

猜你喜欢

转载自blog.csdn.net/qq_19332527/article/details/80791160