如何使用tensorflow加载keras训练好的模型

感谢

How to convert your Keras models to Tensorflow

前言

最近实验室碰到一个奇怪的需求,大家分别构建不同的NLP模型,最后需要进行整合,可是由于有的同学使用的是keras,有的同学喜欢使用TensorFlow,这样导致在构建接口时无法统一不同模型load的方式,每一个模型单独使用一种load的方式的话导致了很多重复开发,效率不高的同时也对项目的可扩展性造成了巨大的破坏。于是需要一种能够统一TensorFlow和keras模型的load过程的方法。

正文

1.构建keras模型
首先假设我们build了一个非常简单的keras模型,如下所示:

x = np.vstack((np.random.rand(1000,10),-np.random.rand(1000,10)))
y = np.vstack((np.ones((1000,1)),np.zeros((1000,1))))
print(x.shape)
print(y.shape)

model = Sequential()
model.add(Dense(units = 32, input_shape=(10,), activation ='relu'))
model.add(Dense(units = 16, activation ='relu'))
model.add(Dense(units = 1, activation ='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['binary_accuracy'])
model.fit(x = x, y=y, epochs = 2, validation_split=0.2) 

2.将keras模型保存为Protocol Buffers的格式
由于TensorFlow是支持将模型保存为Protocol Buffers(.pb)格式的,如果我们有一种方法能将keras模型保存为(.pb)格式的话,那我们的问题就解决了。可是天不遂人愿,keras没有直接提供这样一个将模型保存为(.pb)格式的方法,所以我们必须自己实现这样一个方法,如果你看过keras的源码的话,你会发现keras backend提供了一个get_session()的函数(只有基于TensorFlow的backend有),该函数会返回一个TensorFlow Session,这样一来我们就另辟蹊径,使用这个Session来保存keras模型,而不使用keras已经提供的保存模型的函数,方法如下:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    将输入的Session保存为静态的计算图结构.
    创建一个新的计算图,其中的节点以及权重和输入的Session相同. 新的计算图会将输入Session中不参与计算的部分删除。
    @param session 需要被保存的Session.
    @param keep_var_names 一个记录了需要被保存的变量名的list,若为None则默认保存所有的变量.
    @param output_names 计算图相关输出的name list.
    @param clear_devices 若为True的话会删除不参与计算的部分,这样更利于移植,否则可能移植失败
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

我们通过如下方法调用上述函数来保存模型:

from keras import backend as K
frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, wkdir, pb_filename, as_text=False)

3.在TensorFlow中载入保存的模型
载入保存模型的例子如下:

from tensorflow.python.platform import gfile
with tf.Session() as sess:
    # 从(.pb)文件中载入模型
    with gfile.FastGFile(wkdir+'/'+pb_filename,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        g_in = tf.import_graph_def(graph_def)

猜你喜欢

转载自blog.csdn.net/u014475479/article/details/84709301