TensorFlow C++ API：加载 .pb 模型文件并对图片进行预测

先用 TensorFlow Python 接口创建并训练模型，导出 .pb 模型文件，然后用 C++ API 加载该模型进行预测。

  1 #include <iostream>
  2 #include <map>
  3 
  4 #include "tensorflow/cc/ops/const_op.h"
  5 #include "tensorflow/cc/ops/image_ops.h"
  6 #include "tensorflow/cc/ops/standard_ops.h"
  7 #include "tensorflow/core/framework/graph.pb.h"
  8 #include "tensorflow/core/framework/tensor.h"
  9 #include "tensorflow/core/graph/default_device.h"
 10 #include "tensorflow/core/graph/graph_def_builder.h"
 11 #include "tensorflow/core/lib/core/errors.h"
 12 #include "tensorflow/core/lib/core/stringpiece.h"
 13 #include "tensorflow/core/lib/core/threadpool.h"
 14 #include "tensorflow/core/lib/io/path.h"
 15 #include "tensorflow/core/lib/strings/stringprintf.h"
 16 #include "tensorflow/core/platform/init_main.h"
 17 #include "tensorflow/core/platform/logging.h"
 18 #include "tensorflow/core/platform/types.h"
 19 #include "tensorflow/core/public/session.h"
 20 #include "tensorflow/core/util/command_line_flags.h"
 21 
 22 using namespace std ;
 23 using namespace tensorflow;
 24 using tensorflow::Flag;
 25 using tensorflow::Tensor;
 26 using tensorflow::Status;
 27 using tensorflow::string;
 28 using tensorflow::int32;
 29 
 30 map<int,string> int2char;
 31 string s = "KDA0123456789 ";
 32 for(int i=0;i<s.size();i++){
 33     int2char[i]=s[i];
 34 }
 35 
 36 //从文件名中读取数据
 37 Status ReadTensorFromImageFile(string file_name, const int input_height,
 38                                const int input_width,
 39                                vector<Tensor>* out_tensors) {
 40     auto root = Scope::NewRootScope();
 41     using namespace ops;
 42 
 43     auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name);
 44     const int wanted_channels = 1;
 45     Output image_reader;
 46     std::size_t found = file_name.find(".png");
 47     //判断文件格式
 48     if (found!=std::string::npos) {
 49         image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels));
 50     } 
 51     else {
 52         image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels));
 53     }
 54     // 下面几步是读取图片并处理
 55     auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
 56     auto dims_expander = ExpandDims(root, float_caster, 0);
 57     auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width}));
 58     // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std});
 59     Transpose(root.WithOpName("transpose"),resized,{0,2,1,3});
 60 
 61     GraphDef graph;
 62     root.ToGraphDef(&graph);
 63 
 64     unique_ptr<Session> session(NewSession(SessionOptions()));
 65     session->Create(graph);
 66     session->Run({}, {"transpose"}, {}, out_tensors);//Run,获取图片数据保存到Tensor中
 67 
 68     return Status::OK();
 69 }
 70 
 71 int main(int argc, char* argv[]) {
 72 
 73     string graph_path = "aov_crnn.pb";
 74     GraphDef graph_def;
 75     //读取模型文件
 76     if (!ReadBinaryProto(Env::Default(), graph_path, &graph_def).ok()) {
 77         cout << "Read model .pb failed"<<endl;
 78         return -1;
 79     }
 80 
 81     //新建session
 82     unique_ptr<Session> session;
 83     SessionOptions sess_opt;
 84     sess_opt.config.mutable_gpu_options()->set_allow_growth(true);
 85     (&session)->reset(NewSession(sess_opt));
 86     if (!session->Create(graph_def).ok()) {
 87         cout<<"Create graph failed"<<endl;
 88         return -1;
 89     }
 90 
 91     //读取图像到inputs中
 92     int input_height = 40;
 93     int input_width = 240;
 94     vector<Tensor> inputs;
 95     // string image_path(argv[1]);
 96     string image_path("test.jpg");
 97     if (!ReadTensorFromImageFile(image_path, input_height, input_width,&inputs).ok()) {
 98         cout<<"Read image file failed"<<endl;
 99         return -1;
100     }
101 
102     vector<Tensor> outputs;
103     string input = "inputs_sq";
104     string output = "results_sq";//graph中的输入节点和输出节点,需要预先知道
105 
106     pair<string,Tensor>img(input,inputs[0]);
107     Status status = session->Run({img},{output}, {}, &outputs);//Run,得到运行结果,存到outputs中
108     if (!status.ok()) {
109         cout<<"Running model failed"<<endl;
110         cout<<status.ToString()<<endl;
111         return -1;
112     }
113 
114 
115     //得到模型运行结果
116     Tensor t = outputs[0];        
117     auto tmap = t.tensor<int64, 2>(); 
118     int output_dim = t.shape().dim_size(1); 
119 
120 
121     //预测结果解码为字符串
122     string res="";
123     for (int j = 0; j < output_dim; j++) {
124         res+=int2char[tmap(0,j)];
125     }
126     cout<<res<<endl;
127 
128     return 0;
129 }

g++ -g -D_GLIBCXX_USE_CXX11_ABI=0 tf_predict.cpp -o tf_predict -I /usr/include/eigen3 -I /usr/local/include/tf  -L/usr/local/lib/ `pkg-config --cflags --libs protobuf`  -ltensorflow_cc  -ltensorflow_framework

猜你喜欢

转载自 https://www.cnblogs.com/buyizhiyou/p/10412967.html