c++中启动一个thrift服务加载tensorflow训练的模型

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/luoyexuge/article/details/81871430

前面有几篇博客已经说明如何编译tensorflow,供c++接口调用,前面博客也已经提到怎么利用thrift搭建一个服务,分为客户端和服务端,实际两个综合到一块来相对还是比较容易,下面看下,简单的实现,首先是tensorflow.thrift编写,写的相对比较简单:

// Thrift IDL for the TensorFlow inference service.
service Serv {
    // Runs the loaded model once and returns the first predicted label.
    i32 getresult();
}

然后用thrift -r --gen cpp tensorflow.thrift生成C++代码即可,把生成的代码拷贝到项目下面的thrifttensorflow目录即可,下面分别写server和client客户端:

server部分:

#include <iostream>
#include "thrifttensorflow/Serv.h"
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TServerSocket.h>
#include <thrift/transport/TBufferTransports.h>
#include <vector>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

using namespace tensorflow;

using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;
using namespace std;

class ServHandler : virtual public ServIf {
public:
    Session* session;
    Status status;
    GraphDef  graphDef;
    ServHandler(const string &path) {
        status=NewSession(SessionOptions(),&session);
        if(!status.ok()){
            cout<<status.ToString()<<"\n";
        }
        status =ReadBinaryProto(Env::Default(),path,&graphDef);
        if(!status.ok()){
            cout<<status.ToString()<<"\n";

        }
        status=session->Create(graphDef);
        if(!status.ok()){
            cout<<status.ToString()<<"\n";

        }
    }

    int32_t getresult() {

        printf("string=%s", "start....");
        vector<int> vec={7997, 1945, 8471, 14127, 17565, 7340, 20224, 17529, 3796, 16033, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
        int ndim=vec.size();
        Tensor x(tensorflow::DT_INT32, tensorflow::TensorShape({1, ndim})); // New Tensor shape [1, ndim]
        auto x_map = x.tensor<int, 2>();
        for (int j = 0; j < ndim; j++) {
            x_map(0, j) = vec[j];
        }
        std::vector<std::pair<string, tensorflow::Tensor>> inputs;
        inputs.push_back(std::pair<std::string, tensorflow::Tensor>("input_x", x));

        Tensor keep_prob(tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
        keep_prob.vec<float>()(0) = 1.0f;

        inputs.push_back(std::pair<std::string, tensorflow::Tensor>("keep_prob", keep_prob));
        Tensor tensor_out(tensorflow::DT_INT32, TensorShape({1,ndim}));
        std::vector<tensorflow::Tensor> outputs={{ tensor_out }};

        status = session->Run(inputs, {"crf_pred/ReverseSequence_1"}, {}, &outputs);
//        for(int i=0;i<40;++i) {
//            printf("%d %s", outputs[0].matrix<int>()(0,i)," ");
//        }
        return  outputs[0].matrix<int>()(0,0);
    }

};

int main(int argc, char **argv) {
    int port = 9090;

    if (argc<2){
        cout<<"请传入模型路径,用于加载模型";
        return 1;
    }else{
        modelpath=argv[1];
    }

    stdcxx::shared_ptr<ServHandler> handler(new ServHandler(modelpath));
    stdcxx::shared_ptr<TProcessor> processor(new ServProcessor(handler));
    stdcxx::shared_ptr<TServerTransport> serverTransport(new TServerSocket(port));
    stdcxx::shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
    stdcxx::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());

    TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
    server.serve();
    return 0;
}


client部分:

#include "thrifttensorflow/Serv.h"
#include <iostream>
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TSocket.h>
#include <thrift/transport/TBufferTransports.h>


using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;

using  namespace std;

int main(int argc, char **argv)
{
    stdcxx::shared_ptr<TSocket> socket(new TSocket("localhost", 9090));
    stdcxx::shared_ptr<TTransport> transport(new TBufferedTransport(socket));
    stdcxx::shared_ptr<TProtocol> protocol(new TBinaryProtocol(transport));


    transport->open();

    //调用server服务
    ServClient client(protocol);

    int result=client.getresult();

    transport->close();

    cout<<"请求结束,结果是:"<<result<<endl;

    return 0;
}


CMakeLists.txt部分:

cmake_minimum_required(VERSION 3.10)
project(credit_thrifttensorflow)

# Set the language standard before any targets are created.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

include_directories(/usr/local/include/thrift)
include_directories(/usr/local/include/boost)

link_directories(/usr/local/lib)
link_directories(/Users/zhoumeixu/Documents/tensorflow/bazel-bin/tensorflow)

include_directories(
        /Users/zhoumeixu/Documents/tensorflow
        /Users/zhoumeixu/Documents/tensorflow/bazel-genfiles
        /Users/zhoumeixu/Documents/tensorflow/bazel-bin/tensorflow
        /Users/zhoumeixu/Downloads/eigen3)

# Generated thrift stubs only.  Do NOT add server.cpp here: it defines
# main(), and compiling it into the library as well as the server
# executable risks a duplicate-symbol link error.
set(thrifttensorflow_SOURCES
        thrifttensorflow/Serv.cpp
        thrifttensorflow/tensorflow_constants.cpp
        thrifttensorflow/tensorflow_types.cpp
        )

# Static library holding the generated thrift service stubs.
add_library(thrifttensorflow STATIC ${thrifttensorflow_SOURCES})
target_link_libraries(thrifttensorflow thrift)

# Server executable: stubs + thrift + the TensorFlow C++ runtime.
add_executable(server server.cpp)
target_link_libraries(server thrifttensorflow thrift tensorflow_cc tensorflow_framework)

# Client executable: stubs + thrift only (no TensorFlow needed).
add_executable(client client.cpp)
target_link_libraries(client thrifttensorflow thrift)

在项目根目录下新建 build 目录并进入,然后执行:

mkdir build && cd build

cmake ..

make

即可生成server、client可执行文件

猜你喜欢

转载自blog.csdn.net/luoyexuge/article/details/81871430
今日推荐