tf.train.import_meta_graph TensorFlow内存溢出问题

TensorFlow运行程序的时候发现内存占用非常恐怖,运行没多久内存一下子就达到了99%,经过查资料发现使用sess.graph.finalize() 方法可以锁住图,让其变为只读,运行过程中一旦新增节点就会报错。

先看看原来的有问题的代码

def restore_test_model(data):
    sess = tf.Session()
    sess.graph.finalize() #锁住图,不允许新增节点
    model_file = tf.train.latest_checkpoint('ckpt/')
    saver = tf.train.import_meta_graph(model_file+'.meta') #报错
    saver.restore(sess, model_file)
    graph = tf.get_default_graph()
    out = graph.get_tensor_by_name("out:0")
    X = graph.get_tensor_by_name("input_data/X:0")
    y_num = sess.run(out, feed_dict={X: data[:, 1:]})
    sess.close()
    return y_num

代码在运行到saver = tf.train.import_meta_graph(model_file+’.meta’) 这一句的时候报错了,说明在这里新增了节点。

我自己的解决办法是把模型加载的动作写到with语句里面。

graph = tf.Graph()
    with graph.as_default():
        model_file = tf.train.latest_checkpoint('ckpt/')
        saver = tf.train.import_meta_graph(str(model_file) + '.meta')
        with tf.Session() as sess:
            saver.restore(sess, model_file)
出现内存溢出的问题可以尝试sess.graph.finalize()来排除错误

参考资料:https://blog.csdn.net/ferriswym/article/details/77996555

发布了22 篇原创文章 · 获赞 15 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_28566521/article/details/88984556