tensorflow移植到Android端,实现物体检测自动拍照

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_27063119/article/details/79926289
tensorflow移植到Android端实现物体检测

一. 说明

1. tensorflow是什么:

    是谷歌基于DistBelief进行研发的第二代人工智能学习系统。

2. 为什么要使用tensorflow在手机端进行物体检测:

    前一篇博客中讲到,将opencv移植到android中,检测到指定物体自动拍照,虽然说是功能确实可以实现,但是opencv毕竟比较落后了,识别的错误率还是很大的,于是便引入了tensorflow进行物体分类检测,在很大程度上提高了识别率以及正确率。

本篇博客主要讲解一下,tensorflow是怎样移植到手机端的,以及自定义的模型是怎样训练的。

二. 准备工作

1. 下载tensorflow项目(地址:https://github.com/tensorflow/tensorflow)

2. python环境(我是用的是python3.5的)

3. tensorflow安装(有两个版本,cpu版,gpu版,显然gpu训练模型时要快很多,这个视电脑配置而定吧,一般安装anaconda,直接执行:conda install tensorflow安装即可,我使用了tensorflow-gpu版本,需要nvidia显卡支持,命令:conda install tensorflow-gpu)

4. Android Studio (我使用了2.3.3版本)

5. 下载 libtensorflow_inference.so 以及 libandroid_tensorflow_inference_java.jar文件,(这两个文件可以使用源码进行编译生成),链接:https://pan.baidu.com/s/1tN_nNqfy6JC272J17VaWTg 密码:boat

三. 训练自定义的tensorflow模型

1. 准备数据集

使用tensorflow训练模型,该分类的类别数必须大于等于2的(没有背景这一类别),举例:

如果需要进行识别人和狗两种类别,那么:需要准备图片(只有狗在里面的)100来张,放入dogs文件夹,同时准备图片(只有人在里面的)100来张,放入peoples文件夹,图片越多训练出来的模型越精确,每个类别100来张只能说勉强够用,

将两个类别的文件夹放置:

tensorflow_master/tensorflow/examples/image_retraining/data/train 中

data/train文件夹没有的话,自行新建

2. 准备预训练模型

训练模型需要用到imagenet预训练权重,4个文件(classify_image_graph_def.pb,imagenet_2012_challenge_label_map_proto.pbtxt,imagenet_synset_to_human_label_map.txt,inception-2015-12-05.tgz),下载链接:链接:https://pan.baidu.com/s/1JlDbYy4NHD7qy3Or5lDtSg 密码:i3jo

提前下载拷贝至 model文件夹下,没有该文件夹请自行新建,否则会自动下载很慢的

3. 开始训练

cd 进入tensorflow_master\tensorflow\examples\image_retraining文件夹:

执行命令:

python retrain.py --bottleneck_dir bottleneck --how_many_training_steps 4000 --model_dir model/ --output_graph output_graph.pb --output_labels output_labels.txt --image_dir data/train/

执行完毕会在tensorflow_master\tensorflow\examples\image_retraining文件夹下生成两个文件:

output_graph.pb  以及    output_labels.txt

4. 上一步骤中生成的模型不能直接放置到Android中,需要一步转化:官方的解释:

To use v3 Inception model, strip the DecodeJpeg Op from your retrained
  // model first:

cd 进入tensorflow_master\tensorflow\python\tools文件夹,将上步中生成的 output_graph.pb 文件复制到改目录下,执行命令:

python strip_unused.py --input_graph=output_graph.pb --output_graph=output.pb --input_node_names="Mul" --output_node_names="final_result" --input_binary=true

即可在改目录下生成 output.pb 文件。

至此,模型训练完毕。

四. 整合Android项目

1. 新建项目后,在\app\src\main目录下 新建assets以及jniLibs两个目录,将之前生成的 output.pb 以及 output_labels.txt文件拷贝至assets文件夹下

2. 在jniLibs文件夹下新建armeabi-v7a 文件夹,将 libtensorflow_inference.so 拷贝至 jniLibs\armeabi-v7a 文件夹下

3. 将libandroid_tensorflow_inference_java.jar 添加至项目中,不会的直接搜索 Android Studio添加jar。

4. 新建一个类(Classifier.Java):

import android.graphics.Bitmap;
import android.graphics.RectF;

import java.util.List;

/**
 * Created by amitshekhar on 06/03/17.
 */

/**
 * Generic interface for interacting with different recognition engines.
 */
public interface Classifier {
    /**
     * An immutable result returned by a Classifier describing what was recognized.
     */
    public class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance of
         * the object.
         */
        private final String id;

        /**
         * Display name for the recognition.
         */
        private final String title;

        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private final Float confidence;

        /**
         * Optional location within the source image for the location of the recognized object.
         */
        private RectF location;

        public Recognition(
                final String id, final String title, final Float confidence, final RectF location) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.location = location;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        public RectF getLocation() {
            return new RectF(location);
        }

        public void setLocation(RectF location) {
            this.location = location;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }

            if (title != null) {
                resultString += title + " ";
            }

            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }

            if (location != null) {
                resultString += location + " ";
            }

            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);

    void enableStatLogging(final boolean debug);

    String getStatString();

    void close();
}

5. 新建识别实现类 ( TensorFlowImageClassifier.Java)
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.support.v4.os.TraceCompat;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

import www.demo04.com.util.tensorflow.Classifier;

/**
 * Created by amitshekhar on 06/03/17.
 */

/**
 * A classifier specialized to label images using TensorFlow.
 */
public class TensorFlowImageClassifier implements Classifier {

    private static final String TAG = "ImageClassifier";

    // Only return this many results with at least this confidence.
    private static final int MAX_RESULTS = 2;
    private static final float THRESHOLD = 0.1f;

    // Config values.
    private String inputName;
    private String outputName;
    private int inputSize;
    private int imageMean;
    private float imageStd;

    // Pre-allocated buffers.
    private Vector<String> labels = new Vector<String>();
    private int[] intValues;
    private float[] floatValues;
    private float[] outputs;
    private String[] outputNames;

    private TensorFlowInferenceInterface inferenceInterface;

    private boolean runStats = false;

    private TensorFlowImageClassifier() {
    }

    /**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager  The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize     The input size. A square image of inputSize x inputSize is assumed.
     * @param imageMean     The assumed mean of the image values.
     * @param imageStd      The assumed std of the image values.
     * @param inputName     The label of the image input node.
     * @param outputName    The label of the output node.
     * @throws IOException
     */
    public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            int imageMean,
            float imageStd,
            String inputName,
            String outputName)
            throws IOException {
        TensorFlowImageClassifier c = new TensorFlowImageClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory.
        // TODO(andrewharp): make this handle non-assets.
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.i(TAG, "Reading labels from: " + actualFilename);
        BufferedReader br = null;
        br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
        String line;
        while ((line = br.readLine()) != null) {
            c.labels.add(line);
        }
        br.close();

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
        // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
        int numClasses =
                (int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
        Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

        // Ideally, inputSize could have been retrieved from the shape of the input operation.  Alas,
        // the placeholder node for input in the graphdef typically used does not specify a shape, so it
        // must be passed in as a parameter.
        c.inputSize = inputSize;
        c.imageMean = imageMean;
        c.imageStd = imageStd;

        // Pre-allocate buffers.
        c.outputNames = new String[]{outputName};
        c.intValues = new int[inputSize * inputSize];
        c.floatValues = new float[inputSize * inputSize * 3];
        c.outputs = new float[numClasses];


        /*if(c.inferenceInterface != null && c.inferenceInterface.graph() != null && c.inferenceInterface.graph().operations()!=null){
            Iterator<Operation> operations = c.inferenceInterface.graph().operations();
            while(operations.hasNext()){
                Log.e("operation : ",""+operations.next().name());
            }
        }*/


        return c;
    }

    @Override
    public List<Recognition> recognizeImage(final Bitmap bitmap) {
        // Log this method so that it can be analyzed with systrace.
        TraceCompat.beginSection("recognizeImage");

        TraceCompat.beginSection("preprocessBitmap");
        // Preprocess the image data from 0-255 int to normalized float based
        // on the provided parameters.
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
        }
        TraceCompat.endSection();

        // Copy the input data into TensorFlow.
        TraceCompat.beginSection("feed");
        inferenceInterface.feed(
                inputName, floatValues, new long[]{1, inputSize, inputSize, 3});
        TraceCompat.endSection();

        // Run the inference call.
        TraceCompat.beginSection("run");
        inferenceInterface.run(outputNames, runStats);
        TraceCompat.endSection();

        // Copy the output Tensor back into the output array.
        TraceCompat.beginSection("fetch");
        inferenceInterface.fetch(outputName, outputs);
        TraceCompat.endSection();

        // Find the best classifications.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(
                        3,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(
                        new Recognition(
                                "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
            }
        }
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        TraceCompat.endSection(); // "recognizeImage"
        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean debug) {
        runStats = debug;
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}


5. 在识别的Activity中

定义一些变量:

    private static final int INPUT_SIZE = 299;
    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 128;

    private static final String INPUT_NAME = "Mul";
    private static final String OUTPUT_NAME = "final_result";
    private static final String MODEL_FILE = "file:///android_asset/output.pb";
    private static final String LABEL_FILE ="file:///android_asset/output_labels.txt";


添加初始化 tensorflow 方法:

private void initTensorFlowAndLoadModel() {
        executor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    classifier = TensorFlowImageClassifier.create(
                            getAssets(),
                            MODEL_FILE,
                            LABEL_FILE,
                            INPUT_SIZE,
                            IMAGE_MEAN,
                            IMAGE_STD,
                            INPUT_NAME,
                            OUTPUT_NAME);
                } catch (final Exception e) {
                    throw new RuntimeException("Error initializing TensorFlow!", e);
                }
            }
        });
    }


这里的图片官方说法是使用299 * 299的,其他规格大小试了几个都有问题,有的大了,有的提示不是2048的倍数,总之不想一直纠结,可以将图片裁剪一下,一句话代码:

rightBitmap = Bitmap.createScaledBitmap(rightBitmap, 299, 299, true);

开始识别,直接调用即可:

final List<Classifier.Recognition> results = classifier.recognizeImage(rightBitmap);

返回的 results 是一个List集合,存放有预测物体的名称,以及预测的准确率

可以发现比之前的opencv 准确多了。

Android程序截图如下,与之前的Android项目类似,只是替换了识别地方的代码:

    


最后附上Android源码的下载地址(由于项目过大,因此不含tensorflow的模型):

https://download.csdn.net/download/qq_27063119/10346591






猜你喜欢

转载自blog.csdn.net/qq_27063119/article/details/79926289