tensorflow 多个模型测试阶段速度越来越慢问题的解决方法

版权声明:原创博客未经允许请勿转载! https://blog.csdn.net/holmes_MX/article/details/82659869

0. 写作目的

好记性不如烂笔头。

1. 问题描述

tensorflow中多个模型在测试阶段,出现测试速度越来越慢的情况,通过查阅资料发现,由于tensorflow的图是静态图,但是如果直接加在不同的图(即不同的模型),应该都会存在内存中,因此造成了测试速度越来越慢,甚至导致机器卡顿(博主在测试100个模型时,一般测试20个模型左右出现卡顿),因此有必要探究更快的测试速度方法。

2. 方法的解决

1) 如果采用对于每一个模型均在一个tf.Session()中,会报错。

该中方式的代码大致为:

for modelNumber in range(modelTotalNumber):
    with tf.Session() as sess:
        model_file = tf.train.latest_checkpoint( modelDir )
        saver = tf.train.import( model_file + '.meta' ) ## load the Graph without weights
        saver.restore(sess, 
            tf.train.latest_checkpoint( modelDir ))
        

        ###XXXXX predict code

2) 如果在将多个模型放在同一个tf.Session()中,测试时会出现越来越慢的情况

 该方式的代码大致为:

with tf.Session() as sess:
    for modelNumber in range(modelTotalNumber):   
        model_file = tf.train.latest_checkpoint( modelDir )
        saver = tf.train.import( model_file + '.meta' ) ## load the Graph without weights
        saver.restore(sess, 
            tf.train.latest_checkpoint( modelDir ))
        

        ###XXXXX predict code

3) 解决多个模型测试时越来越慢的问题

需要在每次加载模型前,对模型默认模型进行重置,类似于将原模型“丢弃”

该方式的代码大致为:

for modelNumber in range(modelTotalNumber):
    tf.reset_default_graph()
    with tf.Session() as sess:
        model_file = tf.train.latest_checkpoint( modelDir )
        saver = tf.train.import( model_file + '.meta' ) ## load the Graph without weights
        saver.restore(sess, 
            tf.train.latest_checkpoint( modelDir ))
        

        ###XXXXX predict code
   

2. 总结

在多个模型测试时,出现测试速度越来越慢情况,采用tf.reset_default_graph()类解决,

在多个模型训练阶段,暂时为未发现越来越慢情况。

[Reference]

[1] https://zhuanlan.zhihu.com/p/31619020

[2] https://www.tensorflow.org/programmers_guide/graphs?hl=zh-cn

猜你喜欢

转载自blog.csdn.net/holmes_MX/article/details/82659869