版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_38358305/article/details/89042902
保存模型(ckpt)
仅需两行即可保存模型
saver = tf.train.Saver(tf.global_variables(), max_to_keep= 5)
#第二个参数填任意数字(用于区别各个保存的模型)
path = saver.save(sess, '../model/textCNN/model/my-model',global_step = currentStep)
注意:(保存模型需要自己先建立路径文件夹)
if not os.path.exists('../model/textCNN/model'):
os.makedirs('../model/textCNN/model')#makedirs可以建立多层文件夹
加载模型
(调用ckpt查看模型地址)
#只需替换ckpt地址即可
ckpt = tf.train.get_checkpoint_state(r'C:\Users\Jaykie\Desktop\textClassifier\model\textCNN\model' + '/')#ckpt地址
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()
或者直接输入模型地址
new_saver = tf.train.import_meta_graph('%s.meta' % (parameters["mod_trained"]))
with tf.Session() as sess:
new_saver.restore(sess, '%s' % (parameters["mod_trained"]))
配置feed_dict和输出结果
#tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node] # 得到当前图中所有变量的名称
#print(tensor_name_list)
#根据需要配置变量
_inputX = graph.get_tensor_by_name('inputX:0')
_dropoutKeepProb = graph.get_tensor_by_name('dropoutKeepProb:0')
feed_dict = {
_inputX: trainReviews,
_dropoutKeepProb: 1
}
#根据需要配置输出
y = graph.get_tensor_by_name('output/binaryPreds:0')
#run
predict = sess.run([y],feed_dict)
变量名根据计算图中的各个占位符的名称,用graph.get_tensor_by_name导出,记得加:0或者[0]。
如果未定义变量名称,可以通过tensor_name_list = [tensor.name for tensor in gragh.as_graph_def().node]来得到变量名称的列表