Tensorflow中因未释放计算图而导致显存爆炸运行越来越慢的问题

Tensorflow中因未释放计算图而导致显存爆炸运行越来越慢的问题


问题描述

  • 在一次模型修改中,因为需要在每运行一个batch时,调用另外一个模型计算相关的embedding导入进来当前模型,所以导致了会出现类似以下的调用方式:
    for batch in batches:
        with tf.Session() as sess:
            tfops = tf add Ops ...
            sess.run(tf.ops)
  • 问题就出在了,虽然with sess会在每次调用后关闭session,但是其实没有释放掉计算图Graph,导致每次调用都会重新构建一个计算图,并存放在显存中。Tensorflow框架还会去维护这些建立好的计算图节点,所以导致训练速度越来越慢。
  • 归根到底还是对Tensorflow的运行机制掌握不够熟练,在看了tensorflow sess.run()越来越慢的原因分析及其解决方法 - 知乎这篇博文后才恍然大悟,于是也参考链接里的方法解决了问题。

解决方法

  • 在每次需要将动态建立的计算图销毁时,使用finalize()方法将其销毁,便不再占用显存。
    for batch in batches:
        tf.reset_default_graph()   # 初始化计算图
        with tf.Session() as sess:
            tfops = tf add Ops ...
            sess.run(tf.ops)
        tf.get_default_graph().finalize()  # 销毁计算图
原创文章 25 获赞 34 访问量 7万+

猜你喜欢

转载自blog.csdn.net/HOMEGREAT/article/details/99687233