Tensorflow笔记(二)--模型的加载和保存

Tensorflow笔记系列:

一、flag函数
二、模型的加载和保存

1. Tensorflow模型文件

我们在checkpoint_dir目录下保存的文件结构如下:

|--checkpoint_dir
|    |--checkpoint
|    |--MyModel.meta
|    |--MyModel.data-00000-of-00001
|    |--MyModel.index

1.1 meta文件

MyModel.meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。

1.2 ckpt文件

ckpt文件是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之前,保存在.ckpt文件中。0.11后,通过两个文件保存,如:

MyModel.data-00000-of-00001
MyModel.index

1.3 checkpoint文件

我们还可以看,checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model

2 保存Tensorflow模型

tensorflow 提供了tf.train.Saver类来保存模型,值得注意的是,在tensorflow中,变量是存在于Session环境中,也就是说,只有在Session环境下才会存有变量值,因此,保存模型时需要传入session:

saver = tf.train.Saver()
saver.save(sess,"./checkpoint_dir/MyModel")
看一个简单例子:

import tensorflow as tf
 
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './checkpoint_dir/MyModel')

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.data-00000-of-00001
MyModel.index
MyModel.meta

另外,如果想要在1000次迭代后,再保存模型,只需设置global_step参数即可

保存的模型文件名称会在后面加-1000,如下:

checkpoint
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-1000.meta

3 导入训练好的模型

在第1小节中我们介绍过,tensorflow将图和变量数据分开保存为不同的文件。因此,在导入模型时,也要分为2步:构造网络图和加载参数

3.1 构造网络图

一个比较笨的方法是,手敲代码,实现跟模型一模一样的图结构。其实,我们既然已经保存了图,那就没必要在去手写一次图结构代码。

saver=tf.train.import_meta_graph(’./checkpoint_dir/MyModel-1000.meta’)
上面一行代码,就把图加载进来了

3.2 加载参数

仅仅有图并没有用,更重要的是,我们需要前面训练好的模型参数(即weights、biases等),本文第2节提到过,变量值需要依赖于Session,因此在加载参数时,先要构造好Session:

import tensorflow as tf
with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))
此时,W1和W2加载进了图,并且可以被访问:

import tensorflow as tf
with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./checkpoint_dir'))
    print(sess.run('w1:0'))

##Model has been restored. Above statement will print the saved value
执行后,打印如下:

[ 0.51480412 -0.56989086]

发布了68 篇原创文章 · 获赞 31 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/qq_35307005/article/details/90166465
今日推荐