TensorFlow：用 Python 创建并训练模型，导出 .pb 模型文件后，使用 C++ API 进行预测
1 #include <iostream> 2 #include <map> 3 4 #include "tensorflow/cc/ops/const_op.h" 5 #include "tensorflow/cc/ops/image_ops.h" 6 #include "tensorflow/cc/ops/standard_ops.h" 7 #include "tensorflow/core/framework/graph.pb.h" 8 #include "tensorflow/core/framework/tensor.h" 9 #include "tensorflow/core/graph/default_device.h" 10 #include "tensorflow/core/graph/graph_def_builder.h" 11 #include "tensorflow/core/lib/core/errors.h" 12 #include "tensorflow/core/lib/core/stringpiece.h" 13 #include "tensorflow/core/lib/core/threadpool.h" 14 #include "tensorflow/core/lib/io/path.h" 15 #include "tensorflow/core/lib/strings/stringprintf.h" 16 #include "tensorflow/core/platform/init_main.h" 17 #include "tensorflow/core/platform/logging.h" 18 #include "tensorflow/core/platform/types.h" 19 #include "tensorflow/core/public/session.h" 20 #include "tensorflow/core/util/command_line_flags.h" 21 22 using namespace std ; 23 using namespace tensorflow; 24 using tensorflow::Flag; 25 using tensorflow::Tensor; 26 using tensorflow::Status; 27 using tensorflow::string; 28 using tensorflow::int32; 29 30 map<int,string> int2char; 31 string s = "KDA0123456789 "; 32 for(int i=0;i<s.size();i++){ 33 int2char[i]=s[i]; 34 } 35 36 //从文件名中读取数据 37 Status ReadTensorFromImageFile(string file_name, const int input_height, 38 const int input_width, 39 vector<Tensor>* out_tensors) { 40 auto root = Scope::NewRootScope(); 41 using namespace ops; 42 43 auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name); 44 const int wanted_channels = 1; 45 Output image_reader; 46 std::size_t found = file_name.find(".png"); 47 //判断文件格式 48 if (found!=std::string::npos) { 49 image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels)); 50 } 51 else { 52 image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels)); 53 } 54 // 下面几步是读取图片并处理 55 auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, 
DT_FLOAT); 56 auto dims_expander = ExpandDims(root, float_caster, 0); 57 auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width})); 58 // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std}); 59 Transpose(root.WithOpName("transpose"),resized,{0,2,1,3}); 60 61 GraphDef graph; 62 root.ToGraphDef(&graph); 63 64 unique_ptr<Session> session(NewSession(SessionOptions())); 65 session->Create(graph); 66 session->Run({}, {"transpose"}, {}, out_tensors);//Run,获取图片数据保存到Tensor中 67 68 return Status::OK(); 69 } 70 71 int main(int argc, char* argv[]) { 72 73 string graph_path = "aov_crnn.pb"; 74 GraphDef graph_def; 75 //读取模型文件 76 if (!ReadBinaryProto(Env::Default(), graph_path, &graph_def).ok()) { 77 cout << "Read model .pb failed"<<endl; 78 return -1; 79 } 80 81 //新建session 82 unique_ptr<Session> session; 83 SessionOptions sess_opt; 84 sess_opt.config.mutable_gpu_options()->set_allow_growth(true); 85 (&session)->reset(NewSession(sess_opt)); 86 if (!session->Create(graph_def).ok()) { 87 cout<<"Create graph failed"<<endl; 88 return -1; 89 } 90 91 //读取图像到inputs中 92 int input_height = 40; 93 int input_width = 240; 94 vector<Tensor> inputs; 95 // string image_path(argv[1]); 96 string image_path("test.jpg"); 97 if (!ReadTensorFromImageFile(image_path, input_height, input_width,&inputs).ok()) { 98 cout<<"Read image file failed"<<endl; 99 return -1; 100 } 101 102 vector<Tensor> outputs; 103 string input = "inputs_sq"; 104 string output = "results_sq";//graph中的输入节点和输出节点,需要预先知道 105 106 pair<string,Tensor>img(input,inputs[0]); 107 Status status = session->Run({img},{output}, {}, &outputs);//Run,得到运行结果,存到outputs中 108 if (!status.ok()) { 109 cout<<"Running model failed"<<endl; 110 cout<<status.ToString()<<endl; 111 return -1; 112 } 113 114 115 //得到模型运行结果 116 Tensor t = outputs[0]; 117 auto tmap = t.tensor<int64, 2>(); 118 int output_dim = t.shape().dim_size(1); 119 120 121 //预测结果解码为字符串 122 string res=""; 123 
for (int j = 0; j < output_dim; j++) { 124 res+=int2char[tmap(0,j)]; 125 } 126 cout<<res<<endl; 127 128 return 0; 129 }
# Compile tf_predict.cpp (with debug info) and link it against TensorFlow
# (tensorflow_cc, tensorflow_framework) and protobuf.
# -D_GLIBCXX_USE_CXX11_ABI=0 presumably matches the libstdc++ ABI that the
# TensorFlow libraries were built with — confirm for your TensorFlow build.
g++ -g -D_GLIBCXX_USE_CXX11_ABI=0 tf_predict.cpp -o tf_predict \
    -I /usr/include/eigen3 \
    -I /usr/local/include/tf \
    -L/usr/local/lib/ \
    $(pkg-config --cflags --libs protobuf) \
    -ltensorflow_cc -ltensorflow_framework