版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/luoyexuge/article/details/81871430
前面几篇博客已经介绍了如何编译 TensorFlow 以供 C++ 接口调用,也介绍了如何用 Thrift 搭建一个包含客户端和服务端的服务。把两者结合起来其实并不难,下面看一个简单的实现。首先编写 tensorflow.thrift,内容非常简单:
// RPC service exposed by the TensorFlow model server.
// getresult(): runs the model once on the server's built-in example input
// and returns a single int32 value from the prediction.
service Serv {
  i32 getresult();
}
然后执行 `thrift -r --gen cpp tensorflow.thrift` 生成 C++ 代码(`--gen` 后必须指定目标语言,生成结果默认输出到 gen-cpp 目录),再把生成的代码拷贝到项目下的 thrifttensorflow 目录即可。下面分别编写 server 端和 client 端:
server部分:
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "thrifttensorflow/Serv.h"
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TServerSocket.h>
#include <thrift/transport/TBufferTransports.h>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
using namespace tensorflow;
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;
using namespace std;
class ServHandler : virtual public ServIf {
public:
Session* session;
Status status;
GraphDef graphDef;
ServHandler(const string &path) {
status=NewSession(SessionOptions(),&session);
if(!status.ok()){
cout<<status.ToString()<<"\n";
}
status =ReadBinaryProto(Env::Default(),path,&graphDef);
if(!status.ok()){
cout<<status.ToString()<<"\n";
}
status=session->Create(graphDef);
if(!status.ok()){
cout<<status.ToString()<<"\n";
}
}
int32_t getresult() {
printf("string=%s", "start....");
vector<int> vec={7997, 1945, 8471, 14127, 17565, 7340, 20224, 17529, 3796, 16033, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int ndim=vec.size();
Tensor x(tensorflow::DT_INT32, tensorflow::TensorShape({1, ndim})); // New Tensor shape [1, ndim]
auto x_map = x.tensor<int, 2>();
for (int j = 0; j < ndim; j++) {
x_map(0, j) = vec[j];
}
std::vector<std::pair<string, tensorflow::Tensor>> inputs;
inputs.push_back(std::pair<std::string, tensorflow::Tensor>("input_x", x));
Tensor keep_prob(tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
keep_prob.vec<float>()(0) = 1.0f;
inputs.push_back(std::pair<std::string, tensorflow::Tensor>("keep_prob", keep_prob));
Tensor tensor_out(tensorflow::DT_INT32, TensorShape({1,ndim}));
std::vector<tensorflow::Tensor> outputs={{ tensor_out }};
status = session->Run(inputs, {"crf_pred/ReverseSequence_1"}, {}, &outputs);
// for(int i=0;i<40;++i) {
// printf("%d %s", outputs[0].matrix<int>()(0,i)," ");
// }
return outputs[0].matrix<int>()(0,0);
}
};
int main(int argc, char **argv) {
int port = 9090;
if (argc<2){
cout<<"请传入模型路径,用于加载模型";
return 1;
}else{
modelpath=argv[1];
}
stdcxx::shared_ptr<ServHandler> handler(new ServHandler(modelpath));
stdcxx::shared_ptr<TProcessor> processor(new ServProcessor(handler));
stdcxx::shared_ptr<TServerTransport> serverTransport(new TServerSocket(port));
stdcxx::shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
stdcxx::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
server.serve();
return 0;
}
client部分:
#include "thrifttensorflow/Serv.h"
#include <iostream>
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TSocket.h>
#include <thrift/transport/TBufferTransports.h>
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;
using namespace std;
int main(int argc, char **argv)
{
stdcxx::shared_ptr<TSocket> socket(new TSocket("localhost", 9090));
stdcxx::shared_ptr<TTransport> transport(new TBufferedTransport(socket));
stdcxx::shared_ptr<TProtocol> protocol(new TBinaryProtocol(transport));
transport->open();
//调用server服务
ServClient client(protocol);
int result=client.getresult();
transport->close();
cout<<"请求结束,结果是:"<<result<<endl;
return 0;
}
CMakeLists.txt部分:
cmake_minimum_required(VERSION 3.10)
project(credit_thrifttensorflow)
# Thrift and Boost headers are included as <thrift/...> and <boost/...>,
# so the include root is the parent directory, not the package directories.
include_directories(/usr/local/include)
link_directories(/usr/local/lib)
# TensorFlow built from source with Bazel: library dir plus source, generated
# (protobuf) and Eigen header roots.
link_directories(/Users/zhoumeixu/Documents/tensorflow/bazel-bin/tensorflow)
include_directories(
        /Users/zhoumeixu/Documents/tensorflow
        /Users/zhoumeixu/Documents/tensorflow/bazel-genfiles
        /Users/zhoumeixu/Documents/tensorflow/bazel-bin/tensorflow
        /Users/zhoumeixu/Downloads/eigen3)
# Thrift-generated sources only. server.cpp must not be in this library: it
# defines its own main() and depends on TensorFlow, which the client does not.
set(thrifttensorflow_SOURCES
        thrifttensorflow/Serv.cpp
        thrifttensorflow/tensorflow_constants.cpp
        thrifttensorflow/tensorflow_types.cpp
        )
set(CMAKE_CXX_STANDARD 11)
# Static library with the generated Thrift stubs, shared by server and client.
add_library(thrifttensorflow STATIC ${thrifttensorflow_SOURCES})
target_link_libraries(thrifttensorflow thrift)
# Server executable: Thrift stubs + Thrift runtime + TensorFlow C++ libraries.
add_executable(server server.cpp)
target_link_libraries(server thrifttensorflow thrift tensorflow_cc tensorflow_framework)
# Client executable: links only the Thrift stubs library and the Thrift runtime.
add_executable(client client.cpp)
target_link_libraries(client thrifttensorflow thrift)
# Out-of-source build: `cmake ..` expects to run from a subdirectory of the project.
mkdir -p build && cd build
cmake ..
make
即可生成 server 和 client 两个可执行文件。