tensorflow 训练保存模型3 PB格式

上回说到图看不明白。
所以有了下面的方法:
首先读取刚刚的ckpt文件,保存为pb格式(当然训练的时候直接保存也么有问题)
保存Softmax应该就是把计算Softmax所有必须的变量结构都保存下来,无关的就不要了

import os
ckpt_dir =  "./pb_dir"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
    
with tf.Session() as sess:
    saver.restore(sess,"ckpt_dir/Test1.ckpt-19")
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['Softmax'])
    with tf.gfile.FastGFile(ckpt_dir+'/Test1.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

运行结果,显示把6组变量froze为const op
在这里插入图片描述
然后再开个工程,读取pb,保存到tensorboard中

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
from tensorflow.python.platform import gfile
with tf.gfile.FastGFile('pb_dir/Test1.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
    writer = tf.summary.FileWriter("logs/", sess.graph)
    NetOut=sess.run("Softmax:0",feed_dict={"Placeholder:0":mnist.test.images,"Placeholder_2:0":1.0})

同样的方法,打开tensorboard
在这里插入图片描述
这下清楚多了吧!每个节点的名字也能够直接在tensorboard上看到
再测试下准确率:

import numpy as np

Prediction = NetOut.argmax(axis=1)#找到最大值位置
TestLabels = mnist.test.labels.argmax(axis=1)#找到最大值位置
err = 0
for i in range(Prediction.shape[0]):
    if (Prediction[i]!=TestLabels[i]):
        err = err+1
err=err/Prediction.shape[0]
acc = 1-err
print("Acc=",acc)

非常OK!

猜你喜欢

转载自blog.csdn.net/masbbx123/article/details/85095266