tensorflow保存、加载模型并预测数据

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_38358305/article/details/89042902

保存模型(ckpt)

仅需两行即可保存模型
saver = tf.train.Saver(tf.global_variables(), max_to_keep= 5)
#第二个参数填任意数字(用于区别各个保存的模型)
path = saver.save(sess, '../model/textCNN/model/my-model',global_step = currentStep)
wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

注意:(保存模型需要自己先建立路径文件夹)

if not os.path.exists('../model/textCNN/model'):
    os.makedirs('../model/textCNN/model')#makedirs可以建立多层文件夹

加载模型

(调用ckpt查看模型地址)

#只需替换ckpt地址即可
ckpt = tf.train.get_checkpoint_state(r'C:\Users\Jaykie\Desktop\textClassifier\model\textCNN\model' + '/')#ckpt地址
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = tf.get_default_graph()
wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

或者直接输入模型地址

new_saver = tf.train.import_meta_graph('%s.meta' % (parameters["mod_trained"]))
with tf.Session() as sess:
    new_saver.restore(sess, '%s' % (parameters["mod_trained"]))
wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

配置feed_dict和输出结果

#tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node]  # 得到当前图中所有变量的名称
#print(tensor_name_list)

#根据需要配置变量
_inputX = graph.get_tensor_by_name('inputX:0')
_dropoutKeepProb = graph.get_tensor_by_name('dropoutKeepProb:0')
feed_dict = {
    _inputX: trainReviews,
    _dropoutKeepProb: 1
}
#根据需要配置输出
y = graph.get_tensor_by_name('output/binaryPreds:0')
#run
predict = sess.run([y],feed_dict)

变量名根据计算图中的各个占位符的名称,用graph.get_tensor_by_name导出,记得加:0或者[0]。

如果未定义变量名称,可以通过tensor_name_list = [tensor.name for tensor in gragh.as_graph_def().node]来得到变量名称的列表

猜你喜欢

转载自blog.csdn.net/qq_38358305/article/details/89042902
今日推荐