python3 13.tensorflow中模型保存和恢复方法之checkpoint使用 学习笔记

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/mcyJacky/article/details/88516660

前言

     计算机视觉系列之学习笔记主要是本人进行学习人工智能(计算机视觉方向)的代码整理。本系列所有代码是用python3编写,在平台Anaconda中运行实现,在使用代码时,默认你已经安装相关的python库,这方面不做多余的说明。本系列所涉及的所有代码和资料可在我的github上下载到,gitbub地址:https://github.com/mcyJacky/DeepLearning-CV,如有问题,欢迎指出。

一、checkpoint模型保存

     Tensorflow训练后的模型可以保存checkpoint文件或pb文件,本篇主要讲checkpoint文件,checkpoint文件是结构与权重分离的四个文件,便于训练。其中生成的xxx.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,xxx.ckpt文件保存了TensorFlow程序中每一个变量的取值,checkpoint文件保存了一个目录下所有的模型文件列表。checkpoint文件的用途:可以记录模型训练的过程中的数据,并且可以实现在之前的训练基础上继续训练

     下面是一个之前篇章介绍的MNIST数据集简单分类的程序,在此基础上,我们使用checkpoint模型保存方式:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

# 每个批次64张照片
batch_size = 64
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size

# 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784], name='x_input')
y = tf.placeholder(tf.float32, [None, 10], name='y_input')


# 神经网络结构
W = tf.Variable(tf.truncated_normal([784,10],stddev=0.1))
b = tf.Variable(tf.zeros([10])+0.1)
prediction = tf.nn.softmax(tf.matmul(x,W) + b, name='output')

# 交叉熵代价函数
loss = tf.losses.softmax_cross_entropy(y, prediction)
# 使用Adam优化器
train_step = tf.train.AdamOptimizer(0.001).minimize(loss, name='train')

# 初始化变量
init = tf.global_variables_initializer()

#求准确率
correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

# 定义一个Saver用于保存
Saver = tf.train.Saver()

with tf.Session() as sess:
    # 初始化变化
    sess.run(init)
    for epoch in range(11):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
        #每个周期计算准确率
        acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
        print('iter: ' + str(epoch + 1) + " test accuracy: " + str(acc))
    #保存训练模型
    Saver.save(sess, 'test_ckpt_model/my_model.ckpt')

#训练计算结果:
# iter: 1 test accuracy: 0.9044
# iter: 2 test accuracy: 0.9154
# iter: 3 test accuracy: 0.9198
# iter: 4 test accuracy: 0.9238
# iter: 5 test accuracy: 0.9248
# iter: 6 test accuracy: 0.9263
# iter: 7 test accuracy: 0.9286
# iter: 8 test accuracy: 0.9273
# iter: 9 test accuracy: 0.9278
# iter: 10 test accuracy: 0.93
# iter: 11 test accuracy: 0.9281

     如上程序,我们在保存checkpoint模型时,会定义个Saver = tf.train.saver(),然后使用save()方法来进行指定路径下的模型保存。保存结果如下图1.1所示:

图1.1 checkpoint模型保存文件

二、checkpoint模型恢复方法之同时存在.ckpt和.meta文件

     当同时有模型文件.ckpt和.meta时,我们就可以直接从这两个文件中恢复模型结构和模型参数,以上述第一节模型保存结构为例,要恢复以上模型结构的具体使用方法如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 定义批次大小
batch_size = 64
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size

with tf.Session() as sess:
    # 载入模型结构
    saver = tf.train.import_meta_graph('test_ckpt_model/my_model.ckpt.meta')
    # 载入模型参数
    saver.restore(sess, 'test_ckpt_model/my_model.ckpt')

    #根据tensor的名字获取对应的tensor
    # 之前保存模型的时候模型输出保存为output,":0"是保存模型参数时自动加上的,所以这里也要写上
    output = sess.graph.get_tensor_by_name('output:0')
    # 根据tensor的名字获取到对应的tensor
    # 之前保存模型的时候准确率计算保存为accuracy,":0"是保存模型参数时自动加上的,所以这里也要写上
    accuracy = sess.graph.get_tensor_by_name('accuracy:0')
    # 之前保存模型的时候模型训练保存为train,注意这里的train是operation不是tensor
    train_step = sess.graph.get_operation_by_name('train')

    # 把测试集喂到网络中计算准确率
    # x-input是模型数据的输入,":0"是保存模型参数时自动加上的,所以这里也要写上
    # y-input是模型标签的输入,":0"是保存模型参数时自动加上的,所以这里也要写上
    print(sess.run(accuracy, feed_dict={'x_input_1:0':mnist.test.images, 'y_input_1:0':mnist.test.labels}))
    
    # 在原来模型的基础上再训练11个周期
    for epoch in range(11):
        for batch in range(n_batch):
            batch_xs, batch_ys =  mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={'x_input_1:0':batch_xs,'y_input_1:0':batch_ys})
        #准确率
        acc = sess.run(accuracy, feed_dict={'x_input_1:0':mnist.test.images, 'y_input_1:0':mnist.test.labels})
        print('iter: ' + str(epoch + 1) + ' test acc: ' + str(acc))
#训练结果如下:
# 0.9281
# iter: 1 test acc: 0.9303
# iter: 2 test acc: 0.9305
# iter: 3 test acc: 0.9301
# iter: 4 test acc: 0.9306
# iter: 5 test acc: 0.9316
# iter: 6 test acc: 0.9318
# iter: 7 test acc: 0.9326
# iter: 8 test acc: 0.9309
# iter: 9 test acc: 0.9319
# iter: 10 test acc: 0.9322
# iter: 11 test acc: 0.9317

     从这个程序可以知道,我们通过简单的import_meta_graph()方法可以导入网络结构,restore()方法恢复模型参数,get_tensor_by_name()方法得到计算图中的计算的tensor。在此基础上我们可以继续训练模型,这样就可以很好的避免重新模型的训练或者在模型训练过程中出现宕机而导致中途而废。

三、checkpoint模型恢复方法之无.meta文件

     在项目的实际应用过程中,我们往往能拿到.ckpt文件而没有.meta文件,这个时候我们需要恢复模型就需要自己搭建网络模型结构,以上述第一节模型保存结构为例,要在无.meta文件基础上恢复以上模型结构的具体使用方法如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
# 定义批次大小
batch_size = 64
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size

# 定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

# 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
# 这里的模型参数需要跟之前训练好的模型参数一样
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W) + b)

# 计算准确率
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

# 定义saver用来载入模型
# max_to_keep=5,在指定路径下最多保留5个模型,超过5个模型就会删除老的模型
saver = tf.train.Saver(max_to_keep=5)

# 定义会话
with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.global_variables_initializer())
    # 计算测试集准确率
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
	
	# 载入训练好的参数
    saver.restore(sess,'test_ckpt_model/my_model.ckpt')
	# 在此计算准确率
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))

	# 在原来模型的基础上再训练11个周期
    for epoch in range(11):
        for batch in range(n_batch):
            # 获取一个批次的数据和标签
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            # 训练模型
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        # 计算测试集准确率
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        # 打印信息
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
        # 保存模型,global_step可以用来表示模型的训练次数或者训练周期数
        saver.save(sess, 'test_ckpt_model/my_model.ckpt', global_step=epoch)

#训练结果如下:
# 0.098
# INFO:tensorflow:Restoring parameters from test_ckpt_model/my_model.ckpt
# 0.9281
# Iter 0,Testing Accuracy 0.9301
# Iter 1,Testing Accuracy 0.9289
# Iter 2,Testing Accuracy 0.931
# Iter 3,Testing Accuracy 0.9314
# Iter 4,Testing Accuracy 0.9313
# Iter 5,Testing Accuracy 0.9321
# Iter 6,Testing Accuracy 0.9319
# Iter 7,Testing Accuracy 0.933
# Iter 8,Testing Accuracy 0.9315
# Iter 9,Testing Accuracy 0.9327
# Iter 10,Testing Accuracy 0.9333

     如上程序,我们在无.meta文件的基础上,要重新搭建网络结构,然后用restore()方法恢复模型数据,在此基础上继续训练模型,并继续保存模型。这边唯一的不同是在定义saver = tf.train.Saver(max_to_keep=5)时加入了max_to_keep参数,表示模型最多保存5个。模型保存图如下3.1所示:

图3.1 checkpoint模型保存文件

     
     
     
     
【参考】:
     1. 城市数据团课程《AI工程师》计算机视觉方向
     2. deeplearning.ai 吴恩达《深度学习工程师》
     3. 《机器学习》作者:周志华
     4. 《深度学习》作者:Ian Goodfellow


转载声明:
版权声明:非商用自由转载-保持署名-注明出处
署名 :mcyJacky
文章出处:https://blog.csdn.net/mcyJacky

猜你喜欢

转载自blog.csdn.net/mcyJacky/article/details/88516660