Tensorflow模型的保存、载入、使用(多方法,附代码)

相信看到这篇文章的时候,大家都已经看完了之前的mnist手写实例了,如果没有看,没关系,点击(Tensorflow实战mnist手写识别

通过这篇文章,你可以学到:

  • 两种方法模型的保存
  • 两种方法对于模型的载入和使用

方法一:tf.train.Saver()

涉及到的方法:

保存:
tf.add_to_collection('logits',y)
saver = tf.compat.v1.train.Saver()
saver.save(sess,"./model/model",global_step=200)

载入:
input_model=tf.train.import_meta_graph('./model/model-200.meta')
input_model.restore(sess,'./model/model-200')
sess.graph.get_operation_by_name('input_x').outputs[0]
tf.get_collection('logits')[0]

在这里插入图片描述
add_to_collection()
这个方法是将某一变量添加到集合中,可以在之后模型调用的时候再拿出来用。

saver.save(a,b,c)
< a参数是会话 也就是tf.compat.v1.Session()
< b参数是路径和名称,如图所示,model文件下,文件开头是model。
< c参数也可以当作给文件命名,如果你想在程序训练的第200步进行模型存储,可以给它赋值200 此时,文件名后面就会多一个-200

tf.train.import_meta_graph(a)
input_model.restore(c,d’)
这两个方法就是对模型的导入和恢复。
< a参数为之前保存的后缀名是.meta的文件地址
< c参数是会话sess,即tf.compat.v1.Session()
< d 参数是文件名,不带后缀 比如上图的model-200

下面简单说一下图中四个文件的用处
checkpoint 最新的文件保存地址记录
model-200.data-00000-of-00001、保存了所有的训练变量
model-200.meta 保存了整个模型的图(Graph)


tf.add_to_collection(‘logits’,y)

tf.get_collection(‘logits’)[0]

前者是在模型中保存变量,后者是在加载完模型中,获得模型中的变量

方法二:tf.saved_model.builder.SavedModelBuilder()

保存:
#builder = tf.saved_model.builder.SavedModelBuilder('check_path_mnist')#保存模型方法二
#builder.add_meta_graph_and_variables(sess,['predict_mnist'])#保存模型方法二
#builder.save()#保存模型方法二

载入:
tf.saved_model.loader.load(sess,['predict_mnist'],"check_path_mnist")

在这里插入图片描述
这里面的参数我就不解释了,跟方法一大同小异,大家对照了看看就行了,有问题可以底部留言。
具体代码大家可以去Github下载 地址


参考博客:
Tensorflow模型载入

发布了48 篇原创文章 · 获赞 34 · 访问量 23万+

猜你喜欢

转载自blog.csdn.net/lzx159951/article/details/100604989
今日推荐