[tensorflow应用之路]模型的存储、读取和预测(c++/python)

之前的文章中讲了如何使用tensorflow源码编译一个c++版的动态库。同时留下了一个问题:能否在C++中读取预先训练好的模型呢?———答案是肯定的。
下面,就来一一介绍tensorflow模型在python中的存储和读取,在c++中的读取方式。为什么不讲如何用C++去存储一个模型呢?因为不建议大家用c++训练模型,其中的原因有三点:

其一,基本上99%的tensorflow神经网络都是用python写的,如果你想照抄一个网络,用python最方便。
其二,python有强大的第三方库,高级的语法特性,这些c++上实现需要花费巨大精力。
其三,tensorflow的C++接口支持的并不好。

一、python存储模型的方法

好了,进入正题,在python中如何存储tensorflow模型。

  1. tf.saved_model.builder(推荐
    tf.saved_model是tensorflow官网推荐的一个保存模型的方法,只要你输入保存模型的路径,就可以使用。基本使用方式如下:

    import tensorflow as tf
    
    input=...
    export_dir=...
    ...
    build net...
    ...
     #指定存储路径
    builder =tf.saved_model.builder.SavedModelBuilder(export_dir)
    with tf.Session() as sess:
        #下段话只能调用一次
        builder.add_meta_graph_and_variables(sess,['custom'])
    builder.save()

    其中,export_dir必须指定为一个不存在的路径,否则会报错。上面一段代码中,我们建立了一个名叫’custom’的网络,并将其保存在export_dir中,文件结构如下:

    |-saved_model.pb-|
    |-variables-|
        |-variables.data-00000-of-00001-|
        |-variables.index-|

    其实和下一个方法tf.save存出来的文件差不多。pb文件中是网络结构信息,index文件中是参数值。

  2. tf.train.saver
    tf.train.saver是1.3版本之前主要的模型存储方式,在新版本中也兼容,但已经不是最推荐的方式了。它的使用方式也很简单:

    import tensorflow as tf
    
    input=...
    model_path=...
    ...
    build net
    ...
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.save(sess,model_path)

    tf.save.saver产生的文件结构如下:

    |-saved_model.meta-|
    |-saved_model.data-00000-of-00001-|
    |-saved_model.index-|

    meta文件中存储网络结构,index文件中存储参数信息。

  3. tf.saved_model.builder和tf.train.saver方法比较

    tf.saved_model.builder方法:
    优点:
    1.只需要指定一个存储路径。存储、读取都很方便。
    2.可以存多段网络,参数可以复用。比如现在有一个GAN网络模型,用tf.saved_model.builder指定相应tag以后,可以同时存生成网络、鉴别网络和整个网络。之后读取时,只要读需要的那一部分即可,大大加快读取速度。提升内存利用率。
    3.在tensorflow推荐的estimate(一种更高级的机器学习API,以后填坑)流程中,扮演主要的模型存储方法。
    4.便于分布式读取及使用
    缺点:
    1.只能保存一次参数
    2.对于一个目录,只能导出一个模型。(但可以改变目录名)
    3.不灵活。
    4.速度慢。


    tf.train.saver方法:
    优点:
    1.灵活。可以指定保存模型的名称、后缀、多长时间保存一次、最多保存多少个模型等等。
    2.应用范围广。如果你使用tf.contrib.Slim库(类似tensorlayer的一种高级库)训练模型,那么只能用此方法保存模型。
    3.速度快。
    缺点:
    1.保存多个模型比较复杂。

二、python读取模型的方法

tensorflow读取模型的方法也很简单。我对应的介绍一下。

  1. tf.saved_model.loader
    如果你使用tf.saved_model.builder存储模型的话,那么可以使用tf.saved_model.loader读取模型。只输入一个模型存储的路径即可。简单的例子:

    export_dir = ...
    ...
    build net...
    ...
    with tf.Session(graph=tf.Graph()) as sess:
      tf.saved_model.loader.load(sess, ['custom'], export_dir)
      ...

    可以看到,该方式读取模型非常简单,只需要模型路径和网络标签即可,函数内部会自动加载网络模型和恢复参数。

  2. tf.train.saver.restore
    该方法需要先恢复网络结构(如果你有了定义网络的py文件,可以跳过此步,等价的),再读取参数。简单的例子:

    model_path=...
     #恢复网络结构
    saver = tf.train.import_meta_graph(model_path + '.meta')
    with tf.Session() as sess:
        #读取参数
        saver.restore(sess, model_path)
        graph = sess.graph
        input = graph.get_tensor_by_name('input:0')
        ...
        prediction...
        ...

pythond的模型存取方式就介绍到这里,更多有关tf.train.save和tf.saved_model的区别请点这里


c++读取模型的方式

此章将会辅助一些截图说明。原因是相对于python,tensorflow的c++接口的有点烂。一开始也许你会卡在某一步骤,但是耐心的一步步排查,终将能成功。

  1. LoadSavedModel(对应tf.save_model.builder方式)
    先上代码

     #include <string>
     #include <cc/saved_model/loader.h>
     #include <google/protobuf/message.h>
    
    tensorflow::Status LoadGraph(std::string modelDir, std::unique_ptr<tensorflow::Session>* sess) {
        //定义初始环境
        const std::string export_dir = modelDir;
        tensorflow::SessionOptions session_options;
        tensorflow::RunOptions run_options;
        tensorflow::SavedModelBundle bundle;
        tensorflow::Status status;
        constexpr char kSavedModelTagServe[] = "train";
        //存储模型
        status=LoadSavedModel(session_options, run_options, export_dir, { kSavedModelTagServe },&bundle);
        if (!status.ok()) {
            std::cerr << "Error reading graph definition from " + modelDir+ ": " + status.ToString() << std::endl;
            return status;
        }
        *sess =std::move(bundle.session);
        return status;
    };

    其中,modelDir是模型目录,sess是载入的图模型环境。运行完LoadSavedModel方法后,你得到的status状态应该是空的,bundle中应该已经有内容了:
    这里写图片描述
    如果发生错误,status里会有相应的问题描述,可以根据它尝试解决一下问题。

  2. LoadSavedModel(对应tf.save_model.builder方式)
    简单例子如下:

     #include <string>
     #include <cc/framework/scope.h>
     #include <core/public/session.h>
     #include <core/protobuf/meta_graph.pb.h>
     using namespace tensorflow;
    
    tensorflow::Status LoadGraph(std::string checkpointPath, std::unique_ptr<tensorflow::Session>* sess) {
        string metaGraphPath= checkpointPath + ".meta";
        if (*sess == nullptr) {
            (*sess).reset(tensorflow::NewSession(tensorflow::SessionOptions()));
        }
        Status status;
        auto scp = ::Scope::NewRootScope();
        // 读网络
        tensorflow::MetaGraphDef graph_def;
        status = ReadBinaryProto(Env::Default(), metaGraphPath, &graph_def);
        if (!status.ok()) {
            std::cerr << "Error reading graph definition from " + metaGraphPath+ ": " + status.ToString() << std::endl;
            return status;
        }
        // 将网络加入sess中
        status = (*sess)->Create(graph_def.graph_def());
        if (!status.ok()) {
            std::cerr << "Error creating graph: " + status.ToString() << std::endl;
        }
        // 读参数
        Tensor checkpointPathTensor(DT_STRING, TensorShape());
        checkpointPathTensor.scalar<std::string>()() = checkpointPath;
        status = (*sess)->Run(
        { { graph_def.saver_def().filename_tensor_name(), checkpointPathTensor }, },
        {},
        { graph_def.saver_def().restore_op_name() },
            nullptr);
        if (!status.ok()) {
            std::cerr << "Error loading checkpoint from " + checkpointPath + ": " + status.ToString() << std::endl;
        }
        return status;
    };

    其中,checkpointPath是模型路径,sess是载入的图模型环境。运行完后,你得到的status状态应该是空的,graph_def.meta_info_def中有值,如下:

    这里写图片描述
    如果错误,status会返回错误方法,可以根据描述修复问题。


好了,tensorflow模型存取的方法介绍到这里。接口中的一些参数我没有仔细讲,需要深入研究的童鞋可以去tensorflow官网看一下参数介绍。如果你掌握了本章所说的方法,那么基本上tensorflow的应用已经不成问题了。以后我会多讲一讲具体的tensorflow预测网络应用。

最后祝您身体健康,再见!

猜你喜欢

转载自blog.csdn.net/h8832077/article/details/79006116