Tensorflow中因未释放计算图而导致显存爆炸运行越来越慢的问题
问题描述
- 在一次模型修改中,因为需要在每运行一个batch时,调用另外一个模型计算相关的embedding导入进来当前模型,所以导致了会出现类似以下的调用方式:
for batch in batches:
with tf.Session() as sess:
tfops = tf add Ops ...
sess.run(tf.ops)
- 问题就出在了,虽然with sess会在每次调用后关闭session,但是其实没有释放掉计算图Graph,导致每次调用都会重新构建一个计算图,并存放在显存中。Tensorflow框架还会去维护这些建立好的计算图节点,所以导致训练速度越来越慢。
- 归根到底还是对Tensorflow的运行机制掌握不够熟练,在看了tensorflow sess.run()越来越慢的原因分析及其解决方法 - 知乎这篇博文后才恍然大悟,于是也参考链接里的方法解决了问题。
解决方法
- 在每次需要将动态建立的计算图销毁时,使用finalize()方法将其销毁,便不再占用显存。
for batch in batches:
tf.reset_default_graph() # 初始化计算图
with tf.Session() as sess:
tfops = tf add Ops ...
sess.run(tf.ops)
tf.get_default_graph().finalize() # 销毁计算图