TensorFlow模型参数的保存和加载(含演示代码)

版权声明:本文为博主原创文章,如需转载请注明出处。因博主水平有限,如有疏忽遗漏,敬请指出。 https://blog.csdn.net/ShadowN1ght/article/details/78598834

当我们通过TensorFlow构建了一个训练模型,譬如人脸识别或场景分类网络,并且找到合适的数据集,经过较长时间的训练后,识别率令人满意,这时候我们希望把训练结果保存下来,下次使用时可以直接调用,而不需要重新训练。这就涉及到一个如何保存和加载TensorFlow训练参数的问题。


为方便使用者保存训练结果,TensorFlow提供了tf.train.Saver模块用于保存当前会话中所有的变量值(Variables)。当构建网络模型时,需要为待保存的变量指定name属性,如下所示:

W = tf.Variable(tf.zeros([784, 10]),name="var_W")


当后面进行加载恢复操作(restore)时,只需要指定变量名,就可以直接获取到上一次训练保存的变量值,代码如下:

W = sess.graph.get_tensor_by_name("var_W:0")


当训练结束时,只需要执行如下语句,即可保存训练结果到指定目录:

saver = tf.train.Saver()

saver_path = saver.save(sess,"%smodel.ckpt"%(SAVER_DIR))


上述操作可总结为:

1. 构建网络时为Variables指定名字;

2. 在完成训练迭代之后,调用tf.train.Saver()的save()函数,保存训练结果;

3. 在进行识别任务时,调用sess.graph.get_tensor_by_name()获取上一次训练结果。

完整演示代码如下:

#!/usr/bin/python3.5
# -*- coding: utf-8 -*-

import os
import sys

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


SAVER_DIR = "train-saver/"


mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.float32, [None, 10])

print ('本脚本须输入参数save或restore')
print ('如果当前目录下没有MNIST_data数据,可能需要花费几分钟等待mnist数据下载')
print ('如果下载缓慢,可以从百度网盘http://pan.baidu.com/s/1c2k3gkw直接下载,放到运行脚本同一目录下即可')



if __name__ =='__main__' and sys.argv[1]=='save':
    W = tf.Variable(tf.zeros([784, 10]), name="var_W")
    b = tf.Variable(tf.zeros([10]), name="var_b")

    # 构建网络op
    soft_result = tf.nn.softmax(tf.matmul(x, W) + b)

    cross_entropy = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(soft_result), reduction_indices=[1]))

    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

    sess = tf.InteractiveSession()

    tf.global_variables_initializer().run()

    for i in range(1000):
        if i%10 == 0:
            print ("正在进行第 %d 次训练迭代......" % (i))
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, labels: batch_ys})

    # 保存训练结果
    if not os.path.exists(SAVER_DIR):
        print ('不存在训练数据保存目录,现在创建保存目录')
        os.makedirs(SAVER_DIR)
    # 初始化saver
    saver = tf.train.Saver()            
    saver_path = saver.save(sess, "%smodel.ckpt"%(SAVER_DIR))

    print ('mnist手写体数字训练结果已保存!')



if __name__ =='__main__' and sys.argv[1]=='restore':
    sess = tf.InteractiveSession()
    
    # 导入保存训练结果的文件
    saver = tf.train.import_meta_graph("%smodel.ckpt.meta"%(SAVER_DIR))
    model_file=tf.train.latest_checkpoint(SAVER_DIR)
    saver.restore(sess, model_file)

    # 通过指定变量名获取训练结果中的变量值
    W = sess.graph.get_tensor_by_name("var_W:0")
    b = sess.graph.get_tensor_by_name("var_b:0")

    # 执行识别
    soft_result = tf.nn.softmax(tf.matmul(x, W) + b)
    correct_prediction = tf.equal(tf.argmax(soft_result,1), tf.argmax(labels,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # 打印识别结果
    print("mnist手写体数字识别准确率为:%0.3f%%" % (sess.run(accuracy, feed_dict={x: mnist.test.images, labels: mnist.test.labels})*100))

运行上述代码时,需要注意的是,命令行启动脚本的参数不一样。假设脚本文件名为tensor-restore.py,则训练时启动命令为:

python tensor-restore.py save

识别时启动命令为:

python tensor-restore.py restore

当第一次运行该脚本的时候,如果当前目录没有mnist数据集,则会自动下载数据集,如果网络不稳定,那么下载过程会很缓慢。为方便使用,我把mnist数据集上传到百度网盘,链接地址如下:

http://pan.baidu.com/s/1c2k3gkw

下载后将整个MNIST_data文件夹放到脚本同一目录,运行时就不会触发下载了。


op的保存和加载方法可参考《TensorFlow模型op的保存和加载》一文。

猜你喜欢

转载自blog.csdn.net/ShadowN1ght/article/details/78598834