在一个项目中导入多个不同tensorflow模型

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u010122972/article/details/79093479

刚开始直接采用调用一个模型的方法:
(1)定义网络
(2)新建sess:sess = tf.Session(config=config)
(3)定义saver:saver = tf.train.Saver()
(4)导入权重:saver.restore(sess, xxx)
但是,如果在一个项目中同时导入多个模型,会报错,应该是graph冲突,所以需要给每个模型单独新建graph:

g1 = tf.Graph()
isess = tf.Session(graph=g1)
with g1.as_default():
    (定义网络模型结构)
    isess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(isess, xxx)#xxx为ckpt路径
g2 = tf.Graph()
isess2 = tf.Session(graph=g2)
with g2.as_default():
    (定义网络模型结构)
    isess2.run(tf.global_variables_initializer())
    saver2 = tf.train.Saver()
    saver2.restore(isess2, xxx)
g3...
...

猜你喜欢

转载自blog.csdn.net/u010122972/article/details/79093479