版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u010122972/article/details/79093479
刚开始直接采用调用一个模型的方法:
(1)定义网络
(2)新建sess:sess = tf.Session(config=config)
(3)定义saver:saver = tf.train.Saver()
(4)导入权重:saver.restore(sess, xxx)
但是,如果在一个项目中同时导入多个模型,会报错,应该是graph冲突,所以需要给每个模型单独新建graph:
g1 = tf.Graph()
isess = tf.Session(graph=g1)
with g1.as_default():
(定义网络模型结构)
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, xxx)#xxx为ckpt路径
g2 = tf.Graph()
isess2 = tf.Session(graph=g2)
with g2.as_default():
(定义网络模型结构)
isess2.run(tf.global_variables_initializer())
saver2 = tf.train.Saver()
saver2.restore(isess2, xxx)
g3...
...