python3 14.tensorflow中模型保存和恢复方法之protocol_buffer模式 学习笔记

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/mcyJacky/article/details/88540351

前言

     计算机视觉系列之学习笔记主要是本人进行学习人工智能(计算机视觉方向)的代码整理。本系列所有代码是用python3编写,在平台Anaconda中运行实现,在使用代码时,默认你已经安装相关的python库,这方面不做多余的说明。本系列所涉及的所有代码和资料可在我的github上下载到,gitbub地址:https://github.com/mcyJacky/DeepLearning-CV,如有问题,欢迎指出。

一、protocol_buffer(pb)模型保存

     Tensorflow在训练模型过程中进行模型保存除了前面一篇介绍的checkpoint方法外,还有一种重要的方式就是将模型保存为.pb文件。它与checkpoint文件的区别是:该模型比checkpoint文件占用空间少,且保存的是常量,可以用来做预测,但不能用于继续训练

     下面就是建立一个简单的MNIST数据集分类的神经网络,然后使用protocol_buffer的模型保存方式:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 每个批次64张照片
batch_size = 64
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size

# 定义两个placeholder
# 给模型数据输入的入口起名为x-input
x = tf.placeholder(tf.float32,[None,784], name='x-input')
# 给模型标签输入的入口起名为y-input
y = tf.placeholder(tf.float32,[None,10], name='y-input')

# 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.truncated_normal([784,10],stddev=0.1))
b = tf.Variable(tf.zeros([10])+0.1)
# 给模型输出起名为output
prediction = tf.nn.softmax(tf.matmul(x,W) + b, name='output')

# 交叉熵代价函数
loss = tf.losses.softmax_cross_entropy(y, prediction)
# 使用Adam优化器,给优化器operation起名为train
train_step = tf.train.AdamOptimizer(0.001).minimize(loss, name='train')

# 初始化变量
init = tf.global_variables_initializer()

# 求准确率
# tf.argmax(y,1)中的1表示取y中第1个维度中最大值所在的位置
# tf.equal表示比较两个值是否相等,相等返回True,不相等返回False
# 最后correct_prediction是一个布尔型的列表
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
# tf.cast表示数据格式转换,把布尔型转为float类型,True变成1.0,False变成0.0
# tf.reduce_mean求平均值
# 最后accuracy为准确率
# 给准确率tensor起名为accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

# 创建会话
with tf.Session() as sess:
    # 变量初始化
    sess.run(init)
    # 运行11个周期
    for epoch in range(11):
        for batch in range(n_batch):
            # 获取一个批次的数据和标签
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            # 喂到模型中做训练
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        # 每个周期计算一次测试集准确率
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        # 打印信息
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))

    # 保存模型参数和结构,把变量变成常量
    # output_node_names设置可以输出的tensor
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output', 'accuracy'])	
	# 保存模型到目录下的test_pb-model文件夹中
	 with tf.gfile.FastGFile('test_pb-model/my_model.pb', mode='wb') as f:
	 	f.write(output_graph_def.SerializeToString())

#输出结果:
# Iter 0,Testing Accuracy 0.9026
# Iter 1,Testing Accuracy 0.9132
# Iter 2,Testing Accuracy 0.9171
# Iter 3,Testing Accuracy 0.9223
# Iter 4,Testing Accuracy 0.9235
# Iter 5,Testing Accuracy 0.9251
# Iter 6,Testing Accuracy 0.9266
# Iter 7,Testing Accuracy 0.9283
# Iter 8,Testing Accuracy 0.929
# Iter 9,Testing Accuracy 0.9295
# Iter 10,Testing Accuracy 0.9311
# INFO:tensorflow:Froze 2 variables.
# INFO:tensorflow:Converted 2 variables to const ops.

     如上程序,通过convert_variables_to_constants()方法,将需要保存的模型结点变量变成常量,再写入到xxx.pb文件中。生成的pb文件如下图1.1所示:

图1.1 pb文件模型

二、pb数据恢复

     在实际将模型用于预测时,我们还需要将pb数据进行恢复,下面将对上述程序保存的pb模型进行恢复:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

# 载入模型
with tf.gfile.FastGFile('test_pb-model/my_model.pb', 'rb') as f:
    # 创建一个序列图
    graph_def = tf.GraphDef()
    # 把模型文件载入到图中
    graph_def.ParseFromString(f.read())
    # 载入图到当前环境中
    tf.import_graph_def(graph_def, name='')

    
# 定义会话
with tf.Session() as sess:
    # 根据tensor的名字获取到对应的tensor
    # 之前保存模型的时候模型输出保存为output,":0"是保存模型参数时自动加上的,所以这里也要写上
    output = sess.graph.get_tensor_by_name('output:0')
    # 根据tensor的名字获取到对应的tensor
    # 之前保存模型的时候准确率计算保存为accuracy,":0"是保存模型参数时自动加上的,所以这里也要写上
    accuracy = sess.graph.get_tensor_by_name('accuracy:0')
    # 预测准确率
    print(sess.run(accuracy, feed_dict={'x-input:0':mnist.test.images,'y-input:0':mnist.test.labels}))

#输出结果:
#0.1051

     
     
     
     
【参考】:
     1. 城市数据团课程《AI工程师》计算机视觉方向
     2. deeplearning.ai 吴恩达《深度学习工程师》
     3. 《机器学习》作者:周志华
     4. 《深度学习》作者:Ian Goodfellow


转载声明:
版权声明:非商用自由转载-保持署名-注明出处
署名 :mcyJacky
文章出处:https://blog.csdn.net/mcyJacky

猜你喜欢

转载自blog.csdn.net/mcyJacky/article/details/88540351