Porting TensorFlow to Android

A brief overview of how a phone invokes a TF model:
 
1. Save the trained TF model.
2. Import the TF model into the Android project, along with the jar and .so files that Android needs to invoke it (they handle parsing and running the TF model).
3. Define variables, store the input data, and invoke the model through the interface the jar provides.

 
 

The Porting Process

 
We will use an image-recognition model that we trained ourselves on the MNIST dataset as the running example.
 
Step 1. In the Python definition of the TF model, give the input-layer and output-layer tensors explicit names (via the 'name' keyword argument).
 
X = tf.placeholder(tf.float32, shape=[...], name='input')   # the network's input
Y = tf.nn.softmax(tf.matmul(f, out_weights) + out_biases, name='output')   # the network's output
You can choose any names you like; pick something easy to remember, since they will come up again and again later. I used input and output.
 
Step 2. Save the trained TensorFlow model as a .pb file
 
At the point in the code where training finishes, add the following two statements to save the model as a .pb file:
 
output_graph_def = tf.graph_util.convert_variables_to_constants(
    session, session.graph_def, output_node_names=['output'])
# the output_node_names argument specifies the names of the output nodes
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:
    f.write(output_graph_def.SerializeToString())
 
The first argument of FastGFile specifies the output file's path, name, and format. I put it in the model folder at the same level as the code and named it mnist.pb.
 
The second argument, mode, specifies the file-access mode: in 'wb', w means write, and b means the data is written in binary.
 
If you do not specify 'b', the file is written as text by default. TF does not yet support parsing text-format .pb files, and you will get an error when invoking the model.
 
Notes:
1) You cannot use tf.train.write_graph() to save the model, because it saves only the model's structure and not the trained parameter values.
2) You cannot use tf.train.Saver() to save the model, because it saves only the network's parameter values and not the model's structure.
Clearly we need both the model's structure and the value of every parameter, so neither of the above fits.
 
Step 5. Add the resources to the project
 
1) Put the .pb file generated in Step 2 into the project.
Open the Project view and go to app/src/main/assets.
If the assets directory does not exist, right-click main -> New -> Folder -> Assets Folder.
 
2) Add the jar generated in Step 3.
Open the Project view and copy the jar into app/libs.
Select the jar file, right-click it, and choose Add as Library.
 
3) Add the .so file generated in Step 3.
Open the Project view and copy the .so file into app/src/main/jniLibs (create the jniLibs folder if it does not exist).
 
If anything above is unclear, you can Google "how to add and reference jar and so files in Android Studio" yourself.
 
Step 6. Create the interface and invoke the model
 
1) Import the jar and load the so file.
In the .java file that needs to invoke the model, import the interface class from the jar:
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

At the top of the Java class definition, load the so file in a static initializer:
static {
    System.loadLibrary("tensorflow_inference");
}
2) Define the variables and objects.
private static final String MODEL_FILE = "file:///android_asset/mnist.pb";  // path to the model
private static final String INPUT_NODE = "input";    // name of the model's input tensor
private static final String OUTPUT_NODE = "output";  // name of the model's output tensor
private static final int NUM_CLASSES = 10;  // number of classes in the dataset; 10 for MNIST

private static final int HEIGHT = 24;   // input image height in pixels
private static final int WIDTH = 24;    // input image width in pixels
private static final int CHANNEL = 3;   // number of input image channels (RGB)

private float[] inputs = new float[HEIGHT * WIDTH * CHANNEL];  // buffer for the model's input data
private float[] outputs = new float[NUM_CLASSES];              // buffer for the model's output data
3) Initialize the TensorFlow interface.
private TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface();  // create the interface
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);  // initialize it with the model file
 
Once these two steps are done, the model can be invoked repeatedly.
Before each call, store the input data, in order, into the inputs array (one way to fill it is sketched below), then execute the three statements in step 4).
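How you fill inputs depends on where your data comes from. As one illustration, here is a minimal sketch that copies a Bitmap into inputs in the row-major {height, width, channel} order the model expects; the bitmap variable and the 1/255 scaling are assumptions for illustration and must match your own preprocessing:

// Minimal sketch: assumes `bitmap` has already been scaled to WIDTH x HEIGHT.
int[] pixels = new int[HEIGHT * WIDTH];
bitmap.getPixels(pixels, 0, WIDTH, 0, 0, WIDTH, HEIGHT);
for (int i = 0; i < pixels.length; ++i) {
    int p = pixels[i];  // ARGB-packed pixel
    inputs[i * CHANNEL]     = ((p >> 16) & 0xFF) / 255f;  // R
    inputs[i * CHANNEL + 1] = ((p >> 8) & 0xFF) / 255f;   // G
    inputs[i * CHANNEL + 2] = (p & 0xFF) / 255f;          // B
}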
 
4) Invoke the TF model.
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs);  // feed the input data
inferenceInterface.runInference(new String[]{OUTPUT_NODE});  // run inference
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs);      // fetch the output data
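After readNodeFloat returns, outputs holds the softmax score for each of the NUM_CLASSES classes. A small sketch of turning those scores into a predicted digit (this argmax loop is my addition, not part of the library):

// Pick the class with the highest softmax score.
int bestClass = 0;
for (int i = 1; i < NUM_CLASSES; ++i) {
    if (outputs[i] > outputs[bestClass]) {
        bestClass = i;
    }
}
// bestClass is the predicted digit (0-9 for MNIST).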



Implementing the Demo

Analyzing the Source Code
 
In Android, a method marked native is an entry point into a dynamic link library. For this image-classification demo, reading through the Java code, we can find TensorFlow's three interfaces:
 // load the TensorFlow model
  public native int initializeTensorFlow(
      AssetManager assetManager,
      String model,
      String labels,
      int numClasses,
      int inputSize,
      int imageMean,
      float imageStd,
      String inputName,
      String outputName);
  // classify the image by passing in a Bitmap
  private native String classifyImageBmp(Bitmap bitmap);
  // classify the image by passing in RGB pixel values
  private native String classifyImageRgb(int[] output, int width, int height);
So we only need to learn how to use these three functions to port TensorFlow into our own project. By the way, the statement that loads the dynamic link library is shown below.
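As a sketch, the loader is a static initializer like this; the library name tensorflow_demo matches the demo's native build output, but verify it against your own build:

static {
    System.loadLibrary("tensorflow_demo");
}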
 
We can find these functions being used in TensorFlowImageListener. To use them, we first need to create a TensorFlowClassifier object:
private final TensorFlowClassifier tensorflow = new TensorFlowClassifier();
 
Then we need to load the TensorFlow model. Loading takes the following parameters; note that in TensorFlowImageClassifier I have already removed the parameters I do not need for now.
 
private static final int NUM_CLASSES = 1001;
private static final int INPUT_SIZE = 224;
private static final int IMAGE_MEAN = 117;
private static final float IMAGE_STD = 1;
private static final String INPUT_NAME = "input:0";
private static final String OUTPUT_NAME = "output:0";

private static final String MODEL_FILE =
    "file:///android_asset/tensorflow_inception_graph.pb";
private static final String LABEL_FILE =
    "file:///android_asset/imagenet_comp_graph_label_strings.txt";
Do not forget that the first argument when loading the model is assetManager, which gives access to the training artifacts (the .pb model and the .txt label file); if it is null, an exception is thrown. The remaining parameters are mostly self-explanatory from their names: the input node name, the 224x224 image size, and so on.

Once the model is initialized, we can classify images: call private native String classifyImageBmp(Bitmap bitmap) with the image's bitmap directly, as long as the image has been resized to INPUT_SIZE, as sketched below.
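Putting it together, a minimal sketch of the whole sequence inside an Activity (note that classifyImageBmp is declared private, so from outside TensorFlowClassifier you go through its public wrapper, recognizeImage, as the Activity later in this post does):

// Load the model once, e.g. in onCreate()...
tensorflow.initializeTensorFlow(
        getAssets(), MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE,
        IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME);

// ...then, for each image, scale it to INPUT_SIZE x INPUT_SIZE and classify.
Bitmap scaled = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true);
List<Classifier.Recognition> results = tensorflow.recognizeImage(scaled);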
 
Problem 2:

The app crashes while loading the model.

Check that your assets directory and the path you pass as the first argument are correct; if they are wrong, the model cannot be loaded. A quick check is sketched below.
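One way to rule this out (my own suggestion, not part of the demo): before initializing, try opening the asset by its bare file name. Note that AssetManager.open() takes the name without the file:///android_asset/ prefix:

try {
    getAssets().open("tensorflow_inception_graph.pb").close();  // bare asset name, no prefix
} catch (IOException e) {
    Log.e("TF", "model file is missing from assets/", e);
}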
Using TensorFlow in your own project
 
I emptied the main Activity of the demo above and wrote a new one. The app opens a photo from the phone's gallery, displays it on screen, and shows at the top the most likely name of the object in it (using Google's pretrained model). This project depends on the TensorflowClassifier class from the demo above.
 
Main Activity code:
package org.tensorflow.demo;
 
import java.io.FileNotFoundException;
import java.util.List;
 
import android.app.Activity;
import android.content.ContentResolver;
import android.content.Intent;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.net.Uri;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
 
public class CameraActivity extends Activity {
  /** Called when the activity is first created. */
  private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
  private static final String LABEL_FILE = "file:///android_asset/imagenet_comp_graph_label_strings.txt";
  private static final int NUM_CLASSES = 1001;
  private static final int INPUT_SIZE = 224;
  private static final int IMAGE_MEAN = 117;
  private static final float IMAGE_STD = 1;
  private static final String INPUT_NAME = "input:0";
  private static final String OUTPUT_NAME = "output:0";
  private final TensorFlowClassifier tensorflow = new TensorFlowClassifier();
  private TextView mResultText;
 
 
 
  @Override
  public void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_camera);
 
 
    // test1 load tensorflow
    final AssetManager assetManager = getAssets();
    tensorflow.initializeTensorFlow(
            assetManager, MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD,
            INPUT_NAME, OUTPUT_NAME);
 
    // test1 end
 
 
    Button button = (Button)findViewById(R.id.b01);
    button.setText("选择图片");
    button.setOnClickListener(new Button.OnClickListener(){
      @Override
      public void onClick(View v) {
        Intent intent = new Intent();
                /* Open the Pictures screen; set the type to image */
        intent.setType("image/*");
                /* Use the Intent.ACTION_GET_CONTENT action */
        intent.setAction(Intent.ACTION_GET_CONTENT);
                /* Return to this activity once a photo is picked */
        startActivityForResult(intent, 1);
      }
 
    });
  }
 
  @Override
  protected void onActivityResult(int requestCode, int resultCode, Intent data) {
    if (resultCode == RESULT_OK) {
      Uri uri = data.getData();
      Log.e("uri", uri.toString());
      ContentResolver cr = this.getContentResolver();
      try {
        Bitmap bitmap = BitmapFactory.decodeStream(cr.openInputStream(uri));
        dealPics(bitmap);
      } catch (FileNotFoundException e) {
        Log.e("Exception", e.getMessage(),e);
      }
    }
    super.onActivityResult(requestCode, resultCode, data);
  }
  private void dealPics(Bitmap bitmap) {
    ImageView imageView = (ImageView) findViewById(R.id.iv01);
                /* Set the (scaled) Bitmap on the ImageView */
 
    int width = bitmap.getWidth();
 
    int height = bitmap.getHeight();
    System.out.println(width + "&&" + height);
    float scaleWidth = ((float)INPUT_SIZE) / width;
 
    float scaleHeight = ((float) INPUT_SIZE) / height;
    Matrix matrix = new Matrix();
 
    matrix.postScale(scaleWidth, scaleHeight);
    Bitmap newbm = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
 
    imageView.setImageBitmap(newbm);
    final List<Classifier.Recognition> results = tensorflow.recognizeImage(newbm);
    for (final Classifier.Recognition result : results) {
      System.out.println("Result: " + result.getTitle());
    }
    mResultText = (TextView)findViewById(R.id.t01);
    mResultText.setText("Detected = " + results.get(0).getTitle());
 
    System.out.println(newbm.getWidth() + "&&" + newbm.getHeight());
 
 
  }
}
 

 
JNI code



#include "tensorflow/examples/android/jni/tensorflow_jni.h"


#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/bitmap.h>


#include <jni.h>
#include <pthread.h>
#include <sys/stat.h>
#include <unistd.h>
#include <queue>
#include <sstream>
#include <string>


#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/examples/android/jni/jni_utils.h"


using namespace tensorflow;


// Global variables that holds the TensorFlow classifier.
static std::unique_ptr<tensorflow::Session> session;


static std::vector<std::string> g_label_strings;
static bool g_compute_graph_initialized = false;
// static mutex g_compute_graph_mutex(base::LINKER_INITIALIZED);


static int g_tensorflow_input_size;  // The image size for the model input.
static int g_image_mean;             // The image mean.
static float g_image_std;            // The scale value for the input image.
static std::unique_ptr<std::string> g_input_name;
static std::unique_ptr<std::string> g_output_name;
static std::unique_ptr<StatSummarizer> g_stats;


// For basic benchmarking.
static int g_num_runs = 0;
static int64 g_timing_total_us = 0;
static Stat<int64> g_frequency_start;
static Stat<int64> g_frequency_end;


#ifdef LOG_DETAILED_STATS
static const bool kLogDetailedStats = true;
#else
static const bool kLogDetailedStats = false;
#endif


// Improve benchmarking by limiting runs to predefined amount.
// 0 (default) denotes infinite runs.
#ifndef MAX_NUM_RUNS
#define MAX_NUM_RUNS 0
#endif


#ifdef SAVE_STEP_STATS
static const bool kSaveStepStats = true;
#else
static const bool kSaveStepStats = false;
#endif


inline static int64 CurrentThreadTimeUs() {
  struct timeval tv;
  gettimeofday(&tv, NULL);
  return tv.tv_sec * 1000000 + tv.tv_usec;
}


JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
    JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model,
    jstring labels, jint num_classes, jint model_input_size, jint image_mean,
    jfloat image_std, jstring input_name, jstring output_name) {
  g_num_runs = 0;
  g_timing_total_us = 0;
  g_frequency_start.Reset();
  g_frequency_end.Reset();


  // MutexLock input_lock(&g_compute_graph_mutex);
  if (g_compute_graph_initialized) {
    LOG(INFO) << "Compute graph already loaded. skipping.";
    return 0;
  }


  const int64 start_time = CurrentThreadTimeUs();


  const char* const model_cstr = env->GetStringUTFChars(model, NULL);
  const char* const labels_cstr = env->GetStringUTFChars(labels, NULL);


  g_tensorflow_input_size = model_input_size;
  g_image_mean = image_mean;
  g_image_std = image_std;
  g_input_name.reset(new std::string(env->GetStringUTFChars(input_name, NULL)));
  g_output_name.reset(
      new std::string(env->GetStringUTFChars(output_name, NULL)));


  LOG(INFO) << "Loading TensorFlow.";


  LOG(INFO) << "Making new SessionOptions.";
  tensorflow::SessionOptions options;
  tensorflow::ConfigProto& config = options.config;
  LOG(INFO) << "Got config, " << config.device_count_size() << " devices";


  session.reset(tensorflow::NewSession(options));
  LOG(INFO) << "Session created.";


  tensorflow::GraphDef tensorflow_graph;
  LOG(INFO) << "Graph created.";


  AAssetManager* const asset_manager =
      AAssetManager_fromJava(env, java_asset_manager);
  LOG(INFO) << "Acquired AssetManager.";


  LOG(INFO) << "Reading file to proto: " << model_cstr;
  ReadFileToProto(asset_manager, model_cstr, &tensorflow_graph);


  g_stats.reset(new StatSummarizer(tensorflow_graph));


  LOG(INFO) << "Creating session.";
  tensorflow::Status s = session->Create(tensorflow_graph);
  if (!s.ok()) {
    LOG(FATAL) << "Could not create TensorFlow Graph: " << s;
  }


  // Clear the proto to save memory space.
  tensorflow_graph.Clear();
  LOG(INFO) << "TensorFlow graph loaded from: " << model_cstr;


  // Read the label list
  ReadFileToVector(asset_manager, labels_cstr, &g_label_strings);
  LOG(INFO) << g_label_strings.size()
            << " label strings loaded from: " << labels_cstr;
  g_compute_graph_initialized = true;


  const int64 end_time = CurrentThreadTimeUs();
  LOG(INFO) << "Initialization done in " << (end_time - start_time) / 1000
            << "ms";


  return 0;
}


namespace {
typedef struct {
  uint8 red;
  uint8 green;
  uint8 blue;
  uint8 alpha;
} RGBA;
}  // namespace


// Returns the top N confidence values over threshold in the provided vector,
// sorted by confidence in descending order.
static void GetTopN(
    const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
                           Eigen::Aligned>& prediction,
    const int num_results, const float threshold,
    std::vector<std::pair<float, int> >* top_results) {
  // Will contain top N results in ascending order.
  std::priority_queue<std::pair<float, int>,
                      std::vector<std::pair<float, int> >,
                      std::greater<std::pair<float, int> > >
      top_result_pq;


  const int count = prediction.size();
  for (int i = 0; i < count; ++i) {
    const float value = prediction(i);


    // Only add it if it beats the threshold and has a chance at being in
    // the top N.
    if (value < threshold) {
      continue;
    }


    top_result_pq.push(std::pair<float, int>(value, i));


    // If at capacity, kick the smallest value out.
    if (top_result_pq.size() > num_results) {
      top_result_pq.pop();
    }
  }


  // Copy to output vector and reverse into descending order.
  while (!top_result_pq.empty()) {
    top_results->push_back(top_result_pq.top());
    top_result_pq.pop();
  }
  std::reverse(top_results->begin(), top_results->end());
}


static int64 GetCpuSpeed() {
  string scaling_contents;
  ReadFileToString(nullptr,
                   "/sys/devices/system/cpu/cpu0/cpufreq/scaling_cur_freq",
                   &scaling_contents);
  std::stringstream ss(scaling_contents);
  int64 result;
  ss >> result;
  return result;
}


static std::string ClassifyImage(const RGBA* const bitmap_src) {
  // Force the app to quit if we've reached our run quota, to make
  // benchmarks more reproducible.
  if (MAX_NUM_RUNS > 0 && g_num_runs >= MAX_NUM_RUNS) {
    LOG(INFO) << "Benchmark complete. "
              << (g_timing_total_us / g_num_runs / 1000) << "ms/run avg over "
              << g_num_runs << " runs.";
    LOG(INFO) << "";
    exit(0);
  }


  ++g_num_runs;


  // Create input tensor
  tensorflow::Tensor input_tensor(
      tensorflow::DT_FLOAT,
      tensorflow::TensorShape(
          {1, g_tensorflow_input_size, g_tensorflow_input_size, 3}));


  auto input_tensor_mapped = input_tensor.tensor<float, 4>();


  LOG(INFO) << "TensorFlow: Copying Data.";
  for (int i = 0; i < g_tensorflow_input_size; ++i) {
    const RGBA* src = bitmap_src + i * g_tensorflow_input_size;
    for (int j = 0; j < g_tensorflow_input_size; ++j) {
      // Copy 3 values
      input_tensor_mapped(0, i, j, 0) =
          (static_cast<float>(src->red) - g_image_mean) / g_image_std;
      input_tensor_mapped(0, i, j, 1) =
          (static_cast<float>(src->green) - g_image_mean) / g_image_std;
      input_tensor_mapped(0, i, j, 2) =
          (static_cast<float>(src->blue) - g_image_mean) / g_image_std;
      ++src;
    }
  }


  std::vector<std::pair<std::string, tensorflow::Tensor> > input_tensors(
      {{*g_input_name, input_tensor}});


  VLOG(0) << "Start computing.";
  std::vector<tensorflow::Tensor> output_tensors;
  std::vector<std::string> output_names({*g_output_name});


  tensorflow::Status s;
  int64 start_time, end_time;


  if (kLogDetailedStats || kSaveStepStats) {
    RunOptions run_options;
    run_options.set_trace_level(RunOptions::FULL_TRACE);
    RunMetadata run_metadata;
    g_frequency_start.UpdateStat(GetCpuSpeed());
    start_time = CurrentThreadTimeUs();
    s = session->Run(run_options, input_tensors, output_names, {},
                     &output_tensors, &run_metadata);
    end_time = CurrentThreadTimeUs();
    g_frequency_end.UpdateStat(GetCpuSpeed());
    assert(run_metadata.has_step_stats());


    const StepStats& stats = run_metadata.step_stats();


    if (kLogDetailedStats) {
      LOG(INFO) << "CPU frequency start: " << g_frequency_start;
      LOG(INFO) << "CPU frequency end:   " << g_frequency_end;
      g_stats->ProcessStepStats(stats);
      g_stats->PrintStepStats();
    }


    if (kSaveStepStats) {
      mkdir("/sdcard/tf/", 0755);
      const string filename =
          strings::Printf("/sdcard/tf/stepstats%05d.pb", g_num_runs);
      WriteProtoToFile(filename.c_str(), stats);
    }
  } else {
    start_time = CurrentThreadTimeUs();
    s = session->Run(input_tensors, output_names, {}, &output_tensors);
    end_time = CurrentThreadTimeUs();
  }
  const int64 elapsed_time_inf = end_time - start_time;
  g_timing_total_us += elapsed_time_inf;
  VLOG(0) << "End computing. Ran in " << elapsed_time_inf / 1000 << "ms ("
          << (g_timing_total_us / g_num_runs / 1000) << "ms avg over "
          << g_num_runs << " runs)";


  if (!s.ok()) {
    LOG(FATAL) << "Error during inference: " << s;
  }


  VLOG(0) << "Reading from layer " << output_names[0];
  tensorflow::Tensor* output = &output_tensors[0];
  const int kNumResults = 5;
  const float kThreshold = 0.1f;
  std::vector<std::pair<float, int> > top_results;
  GetTopN(output->flat<float>(), kNumResults, kThreshold, &top_results);


  std::stringstream ss;
  ss.precision(3);
  for (const auto& result : top_results) {
    const float confidence = result.first;
    const int index = result.second;


    ss << index << " " << confidence << " ";


    // Write out the result as a string
    if (index < g_label_strings.size()) {
      // just for safety: theoretically, the output is under 1000 unless there
      // is some numerical issues leading to a wrong prediction.
      ss << g_label_strings[index];
    } else {
      ss << "Prediction: " << index;
    }


    ss << "\n";
  }


  LOG(INFO) << "Predictions: " << ss.str();
  return ss.str();
}


JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(classifyImageRgb)(
    JNIEnv* env, jobject thiz, jintArray image, jint width, jint height) {
  // Copy image into currFrame.
  jboolean iCopied = JNI_FALSE;
  jint* pixels = env->GetIntArrayElements(image, &iCopied);


  std::string result = ClassifyImage(reinterpret_cast<const RGBA*>(pixels));


  env->ReleaseIntArrayElements(image, pixels, JNI_ABORT);


  return env->NewStringUTF(result.c_str());
}


JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(classifyImageBmp)(JNIEnv* env,
                                                              jobject thiz,
                                                              jobject bitmap) {
  // Obtains the bitmap information.
  AndroidBitmapInfo info;
  CHECK_EQ(AndroidBitmap_getInfo(env, bitmap, &info),
           ANDROID_BITMAP_RESULT_SUCCESS);
  void* pixels;
  CHECK_EQ(AndroidBitmap_lockPixels(env, bitmap, &pixels),
           ANDROID_BITMAP_RESULT_SUCCESS);
  LOG(INFO) << "Image dimensions: " << info.width << "x" << info.height
            << " stride: " << info.stride;
  // TODO(andrewharp): deal with other formats if necessary.
  if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
    LOG(FATAL) << "Only RGBA_8888 Bitmaps are supported.";
  }


  std::string result = ClassifyImage(static_cast<const RGBA*>(pixels));


  // Finally, unlock the pixels
  CHECK_EQ(AndroidBitmap_unlockPixels(env, bitmap),
           ANDROID_BITMAP_RESULT_SUCCESS);


  return env->NewStringUTF(result.c_str());
}
 
 

Reposted from blog.csdn.net/u011808673/article/details/78573114