keras和Tensorflow同时加载多个模型,以及与keras模型混用

可以参考https://stackoverflow.com/questions/51127344/tensor-is-not-an-element-of-this-graph-deploying-keras-model?r=SearchResults

Tensorflow同时加载使用多个模型(keras同理,只要是tf的后端)

Tensorflow,所有操作对象都包装在相应的session中,所以想要使用不同的模型就要将这些模型加载到不同session中,并且声明使用的时候申请是哪个session,从而避免由于session和想使用的模型不匹配导致错误,而使用多个graph就需要为每个graph使用不同的session,但是每个graph也可以在多个session中使用,这个时候就需要在每个session中使用的时候明确使用的graph。

g1 = tf.Graph() # 加载到Session 1的graph
g2 = tf.Graph() # 加载到Session 2的graph
 
sess1 = tf.Session(graph=g1) # Session1
sess2 = tf.Session(graph=g2) # Session2
 
# 加载第一个模型
with sess1.as_default(): 
    with g1.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(“model1/save/path”)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)
# 加载第二个模型
with sess2.as_default():  # 1
    with g2.as_default():  
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(“model2/save/path”)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)
 
...
 
# 使用的时候
with sess1.as_default():
    with sess1.graph.as_default():  # 2
        ...
 
with sess2.as_default():
    with sess2.graph.as_default():
        ...
 
# 关闭sess
sess1.close()
sess2.close()

在使用as_default使session在离开的时候并不关闭,在后面可以继续使用直到手动关闭,由于有多个graph,所以sess.graph与tf.get_default_value的值是不相等的,因此在进入sess的时候必须sess.graph.as_default()明确什么sess.graph为当前默认graph,否则会报错

不同框架的模型在加载的时候可能导致底层的cuDNN分配出问题从而报错,这种一般可以尝试通过模型的加载顺序而解决。

参考:https://www.cnblogs.com/arkenstone/p/7016481.html

https://blog.csdn.net/jmh1996/article/details/78793650/

https://blog.csdn.net/seniusen/article/details/82926850

keras和tensorflow多模型线上混用模型需要注意的

如果用tensorflow写的model,一般来说每个model都有自己的session和graph

但是在keras,会经常忽略掉session和graph,这时候需要添加session和好几个地方加with graph,伪代码如下:

seg_graph = tf.Graph()
sess = tf.Session(graph=seg_graph)
K.set_session(sess)
 
#保证代码
with seg_graph.as_default():
 
     self.keras_model = self.build(mode=mode, config=config)
     #上面一行代码会调用KM.Model
     
     以及这类函数
      topology.load_weights_fromXXX()
     
     以及predict函数

线上环境就这三次需要增加with seg_graph.as_default()

参考:https://blog.csdn.net/dayuqi/article/details/85295070

猜你喜欢

转载自blog.csdn.net/u013066730/article/details/107857785
今日推荐