onnx2Tensorflow实战

上文介绍了将pytorch模型转为onnx模型的方法,本文将介绍得到.onnx文件后通过onnx_tf将其转为tensorflow文件的方法。

注意事项

onnx_tf中明确说明了onnx只支持tensorflow>=1.15.0,但是由于安卓端的tensorflow更新的比较慢,一般只支持版本较低的tensorflow,如demo中只支持tensorflow 1.12.0(或1.13.0),因此由onnx转换成tensorflow的过程中要转成tensorflow 1.12.0可使用的pb文件。

  1. 若使用tensorflow 1.15.0环境去输出由tensorflow 1.12.0环境下生成的pb文件是可以正常输出的。
  2. 若使用tensorflow 1.12.0环境去输出由tensorflow 1.15.0环境下生成的pb文件,则会产生以下报错:在这里插入图片描述
  3. 若直接使用官方onnx_tf和tensorflow 1.12.0生成pb文件则会产生以下两个报错:在这里插入图片描述
    在这里插入图片描述

因此需要将onnx_tf的代码进行一些修改。
如果安装onnx_tf时使用pip install onnx_tf命令,则将anaconda3\Lib\site-packages\onnx_tf\handlers\backend\is_inf.py文件下的 @tf_func(tf.math.is_inf)改为 @tf_func(tf.is_inf),将anaconda3\Lib\site-packages\onnx_tf\handlers\backend\scatter_nd.py文件下的@tf_func(tf.tensor_scatter_nd_update)改为@tf_func(tf.scatter_nd_update)
如果去官网下载安装onnx_tf,则将onnx-tensorflow\onnx_tf\handlers\backend\is_inf.py文件下的@tf_func(tf.math.is_inf)改为 @tf_func(tf.is_inf),将onnx-tensorflow\onnx_tf\handlers\backend\scatter_nd.py文件下的@tf_func(tf.tensor_scatter_nd_update)改为@tf_func(tf.scatter_nd_update)

代码

onnx2pb

# venv-tf1.12
import onnx
import numpy as np
from onnx_tf.backend import prepare

# 给定输入图片或者随机输入,尺寸要跟.onnx模型生成时dummy_input一样
img = np.load('random.npy')
# img = img.reshape([1, 3, 300, 300])

# 导入onnx到tensorflow中,并获得输出
model = onnx.load('model_005000.onnx')
# 这里必须strict=False,不然生成的pb文件输出会报错
tf_rep = prepare(model, strict=False)
onnx_output = tf_rep.run(img)
print("onnx-tensorflow output: \n",onnx_output)

# 将onnx-tensorflow模型导出成pb格式
name = "model_005000.pb"
tf_rep.export_graph(name)

save_npy

输入图片random.npy可以通过运行以下代码得到:

import numpy as np

img = np.random.rand(1, 3, 300, 300)
np.save(r'E:\Pythonworkspace\pth2pb\random', img) #保存的路径,random表示文件名

test_pb

验证pb文件:

import tensorflow as tf
import numpy as np
# from PIL import Image

model_path = 'model_005000.pb'

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()

    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        inp = sess.graph.get_tensor_by_name('actual_input_1:0')   #以下节点通过运行print_tensorname得到
        out0 = sess.graph.get_tensor_by_name('output1:0')
        out1 = sess.graph.get_tensor_by_name('73:0')
        out2 = sess.graph.get_tensor_by_name('74:0')

        img = np.load('random.npy')
        # img = img.reshape([1, 3, 300, 300])
        pre_num = sess.run([out0, out1, out2], feed_dict={inp: img})
        print(pre_num)

print_tensorname

只有能打印出来的节点才能作为输入输出!注意tensor有两种,一种是保存固定值的Const节点,是各个层训练完得到的固定权重偏置等值;一种是会因输入不同而得到不同数值的变量节点,也是我们所需要的tensor。这里输出的tensor名字只是一半,一般来说后一半都是0或者1,如“481:0”、“Add_51:0”等。
打印pb图中的节点:

import tensorflow as tf

model_name = 'model_005000.pb'
with tf.gfile.GFile(model_name , 'rb') as f:
    # 使用tf.GraphDef()定义一个空的Graph
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # 把当前流图读入graph_def中
    tf.import_graph_def(graph_def, name='')
# 打印所有tensor名称
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
    print(tensor_name,'\n')

结果如下图所示:
节点名称

有问题欢迎在评论区留言,本人水平有限,有错误希望大家指正。

猜你喜欢

转载自blog.csdn.net/Creama_/article/details/105048096