Android TensorFlow Lite 初探 数字分类器(JAVA DEMO)

TensorFlow Lite

TensorFlow Lite 是一种用于设备端推断的开源深度学习框架, 可帮助开发者在移动设备和 IoT 设备上部署机器学习模型.

环境

AndroidStudio 4.0 + JAVA

数字分类器

通过 TensorFlow Lite 模型对手写数字进行分类。
图片来自tflite-mnist-android

关于DEMO

历程

刚开始没有找到图片中 DEMO 的源码, 于是参照 Kotlin 版 DEMO 自己移植了一个 Java 版本. 以下是移植过程, 若已熟悉 TensorFlow Lite, 可直接跳过.

  1. 在AS中新建Module DigitClassifierByTFL.
  • 编译环境: (原文此处为截图)
  • build.gradle中SDK相关配置:
android {
    // SDK used to compile the app and the build-tools version used to package it.
    compileSdkVersion 30
    buildToolsVersion "30.0.2"

    defaultConfig {
        applicationId "com.ansondroider.digitclassifierbytfl"
        // Lowest Android version the app can be installed on.
        minSdkVersion 16
        // NOTE(review): targetSdkVersion 16 is far below compileSdkVersion 30. With a target below
        // 23 the runtime-permission model does not apply, so the requestPermissions() call in the
        // Activity is largely redundant. Confirm whether 30 was intended.
        targetSdkVersion 16
        versionCode 1
        versionName "1.0"
    }
}
  • 目录结构
    在这里插入图片描述
  2. 等待构建完成后, 需修改以下配置:
  • build.gradle: 禁止压缩 .tflite 文件. 模型是通过 AssetManager.openFd() 内存映射加载的, 而 openFd() 无法打开被压缩的 asset, 不加此配置会导致加载模型失败、运行出错
    // Keep .tflite assets uncompressed in the APK: the model is memory-mapped through
    // AssetManager.openFd(), which can only open assets stored uncompressed.
    aaptOptions {
        noCompress "tflite"
    }
  • build.gradle: 增加 TensorFlow Lite依赖
dependencies {
    // Any local jars dropped into libs/.
    implementation fileTree(dir: "libs", include: ["*.jar"])
    // TensorFlow Lite runtime (provides org.tensorflow.lite.Interpreter).
    // NOTE(review): a nightly snapshot marked changing = true can resolve to a different build
    // over time, so builds are not reproducible. Consider pinning a released version.
    implementation ('org.tensorflow:tensorflow-lite:0.0.0-nightly'){changing = true}
}
  3. 源码及说明
  • layout
<?xml version="1.0" encoding="utf-8"?>
<!-- Screen layout: the drawing surface sits at the top, the recognition result is pinned to the bottom. -->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent">

    <!-- Fixed 400dp x 400dp finger-drawing surface, centered horizontally. -->
    <com.ansondroider.digitclassifierbytfl.PaintView
        android:id="@+id/paintView"
        android:layout_width="400dp"
        android:layout_height="400dp"
        android:layout_centerHorizontal="true" />

    <!-- Classification output, anchored to the bottom edge of the parent. -->
    <TextView
        android:id="@+id/tvRes"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:text="Result: ?"
        android:textColor="#FF00FF00"
        android:textSize="30sp" />

</RelativeLayout>

PaintView: 手指绘画.
TextView: 显示结果.
在这里插入图片描述

  • Activity
package com.ansondroider.digitclassifierbytfl;

import android.Manifest;
import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Build;
import android.os.Bundle;
import android.widget.TextView;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;

public class DigitClassifierByTFL extends Activity {
    
    
  PaintView paintView;
  TextView tvRes;
  Classifier fier;
  /**
   * Inflates the layout, wires the drawing view to the classifier and builds the classifier.
   *
   * <p>Every time PaintView reports a finished write, the current drawing is classified and the
   * formatted result is shown in the result TextView.
   */
  @Override
  protected void onCreate(Bundle savedInstanceState) {
      super.onCreate(savedInstanceState);

      // Inflate the screen and grab the two views it contains.
      setContentView(R.layout.activity_digit_classifier_by_tfl);
      tvRes = (TextView) findViewById(R.id.tvRes);
      paintView = (PaintView) findViewById(R.id.paintView);

      // Classify as soon as PaintView signals that a write is done
      // (PaintView fires this 500 ms after the finger is lifted).
      PaintView.Callback onDone = new PaintView.Callback() {
          @Override
          public void onWriteDone() {
              final String label = fier.classifier(paintView.getBitmap());
              Runnable showLabel = new Runnable() {
                  @Override
                  public void run() {
                      tvRes.setText(label);
                  }
              };
              tvRes.post(showLabel);
          }
      };
      paintView.setCallback(onDone);

      // Build the TensorFlow Lite backed classifier.
      fier = new Classifier(getAssets());

      // Storage permissions are only requested at runtime on Android 6.0 (M) and later.
      boolean runtimePermissions = Build.VERSION.SDK_INT >= Build.VERSION_CODES.M;
      if (runtimePermissions) {
          String[] wanted = {
                  Manifest.permission.WRITE_EXTERNAL_STORAGE,
                  Manifest.permission.READ_EXTERNAL_STORAGE
          };
          requestPermissions(wanted, 0x77);
      }
  }

  /**
   * Releases the classifier when the activity is destroyed.
   *
   * <p>PaintView delivers {@code onWriteDone} 500 ms after the finger is lifted. The callback is
   * detached before the interpreter is closed so that a write finishing during teardown cannot
   * run inference on a closed interpreter.
   */
  @Override
  protected void onDestroy() {
      super.onDestroy();
      if (paintView != null) {
          paintView.setCallback(null);
      }
      if (fier != null) {
          fier.close();
      }
  }
  • PaintView
package com.ansondroider.digitclassifierbytfl;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;

/**
 * Finger-drawing surface that renders strokes in white onto an off-screen black bitmap.
 *
 * <p>When the finger is lifted (or the gesture is cancelled) and no new stroke starts within
 * 500 ms, the registered {@link Callback} is notified. The drawing can then be read with
 * {@link #getBitmap()}. The stroke path is cleared after each notification, so the next stroke
 * replaces the previous drawing.
 */
public class PaintView extends View {
  // Delay after ACTION_UP / ACTION_CANCEL before a write is reported as finished.
  private static final long WRITE_DONE_DELAY_MS = 500;
  // Stroke width, in pixels, used to draw the digit.
  private static final float STROKE_WIDTH_PX = 50;

  public PaintView(Context context) {
      super(context);
      initPaint();
  }

  public PaintView(Context context, AttributeSet attrs) {
      super(context, attrs);
      initPaint();
  }

  public PaintView(Context context, AttributeSet attrs, int defStyleAttr) {
      super(context, attrs, defStyleAttr);
      initPaint();
  }

  // Off-screen drawing buffer handed to the classifier; null until the view has a non-zero size.
  Bitmap bm;
  // Canvas that draws into bm.
  Canvas cBm;
  // Paint used both for the strokes and for blitting bm in onDraw.
  Paint p = new Paint(Paint.ANTI_ALIAS_FLAG);
  // Position of the last ACTION_DOWN (dx, dy) and of the current touch event (cx, cy).
  float dx, dy, cx, cy;
  // Strokes belonging to the current write.
  Path path = new Path();
  // Listener notified when a write is finished.
  Callback cb;

  // Configures the stroke paint once instead of on every move event.
  private void initPaint() {
      p.setStyle(Paint.Style.STROKE);
      p.setColor(Color.WHITE);
      p.setStrokeWidth(STROKE_WIDTH_PX);
  }

  /**
   * Returns the current drawing: white strokes on black, sized to this view.
   *
   * @return the drawing bitmap, or {@code null} if the view has not been laid out with a
   *     non-zero size
   */
  public Bitmap getBitmap(){
      return bm;
  }

  @Override
  protected void onSizeChanged(int w, int h, int oldw, int oldh) {
      super.onSizeChanged(w, h, oldw, oldh);
      if (bm != null) {
          bm.recycle();
          // Never leave a recycled bitmap reachable through getBitmap().
          bm = null;
          cBm = null;
      }
      // Bitmap.createBitmap throws IllegalArgumentException for a zero dimension.
      if (w <= 0 || h <= 0) {
          return;
      }
      bm = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
      cBm = new Canvas(bm);
      cBm.drawColor(Color.BLACK);
  }

  @Override
  public boolean onTouchEvent(MotionEvent event) {
      cx = event.getX();
      cy = event.getY();
      switch (event.getAction()) {
          case MotionEvent.ACTION_DOWN:
              dx = cx;
              dy = cy;
              startWrite();
              break;
          case MotionEvent.ACTION_MOVE:
              onMove();
              break;
          case MotionEvent.ACTION_CANCEL:
          case MotionEvent.ACTION_UP:
              endWrite();
              break;
      }
      // Touch events arrive on the UI thread, so a plain invalidate() is enough.
      invalidate();
      return true;
  }

  // A new stroke cancels any pending "write done" so multi-stroke digits stay together.
  void startWrite(){
      removeCallbacks(writeDone);
      path.moveTo(cx, cy);
  }

  // Extends the current stroke and redraws all strokes of this write onto a black background.
  void onMove(){
      path.lineTo(cx, cy);
      if (cBm == null) {
          return;
      }
      cBm.drawColor(Color.BLACK);
      cBm.drawPath(path, p);
  }

  // Notifies the callback (when there is a live bitmap to read) and starts a fresh path.
  Runnable writeDone = new Runnable() {
      @Override
      public void run() {
          if (cb != null && bm != null && !bm.isRecycled()) {
              cb.onWriteDone();
          }
          path.reset();
      }
  };

  // Re-arms the "write done" timer each time a stroke ends.
  void endWrite(){
      removeCallbacks(writeDone);
      postDelayed(writeDone, WRITE_DONE_DELAY_MS);
  }

  @Override
  protected void onDetachedFromWindow() {
      // Drop a pending write-done so the callback does not fire after the view left its window.
      removeCallbacks(writeDone);
      super.onDetachedFromWindow();
  }

  @Override
  protected void onDraw(Canvas canvas) {
      if (bm != null && !bm.isRecycled()) {
          canvas.drawBitmap(bm, 0, 0, p);
      }
  }

  /**
   * Registers the listener for finished writes.
   *
   * @param c listener to notify, or {@code null} to stop notifications
   */
  public void setCallback(Callback c){
      cb = c;
  }

  /** Receives a notification when the user has finished writing. Called on the UI thread. */
  public interface Callback{
      void onWriteDone();
  }
}

  • 分类器
  /**
   * Runs a TensorFlow Lite MNIST model on a hand-drawn digit bitmap.
   *
   * <p>The model asset is memory-mapped through {@link AssetManager#openFd}, so it must be
   * packaged uncompressed in the APK.
   */
  class Classifier{
      // mnist.tflite: from the Kotlin demo; small file, low recognition rate.
      // mnist_big.tflite: from the Java demo; larger file, higher recognition rate.
      final String MODEL = "mnist_big.tflite";
      // TFLite interpreter that executes the model.
      Interpreter interpreter;
      // Input image size expected by the model, read from the input tensor shape.
      int bmWidth, bmHeight;
      // Used to open the .tflite asset.
      AssetManager asset;
      // Size in bytes of one model input (one float per grayscale pixel).
      int modelInputSize;
      // Number of scores produced per input, read from the output tensor shape.
      int outputClasses;

      /**
       * Loads the model and reads the input/output tensor geometry.
       *
       * @param asset asset manager used to open {@link #MODEL}
       * @throws IllegalStateException if the model asset cannot be opened or mapped
       */
      Classifier(AssetManager asset){
          this.asset = asset;
          Interpreter.Options op = new Interpreter.Options();
          op.setUseNNAPI(true);
          interpreter = new Interpreter(loadModel(), op);

          // TFLite image inputs use [batch, height, width, channels] (NHWC): index 1 is the height
          // and index 2 the width. For the square 28x28 MNIST input the two are equal.
          int[] shape = interpreter.getInputTensor(0).shape();
          bmHeight = shape[1];
          bmWidth = shape[2];

          int FLOAT_TYPE_SIZE = 4;
          int PIXEL_SIZE = 1;
          modelInputSize = FLOAT_TYPE_SIZE * bmWidth * bmHeight * PIXEL_SIZE;

          // The output tensor is [1, N] (e.g. [1, 10] for softmax_tensor). Size result arrays from
          // the model instead of hard-coding 10.
          int[] outShape = interpreter.getOutputTensor(0).shape();
          outputClasses = outShape[outShape.length - 1];
      }

      /**
       * Memory-maps {@link #MODEL} from the APK assets.
       *
       * <p>The stream and the asset descriptor are closed before returning. The mapping stays
       * valid because a mapped buffer does not depend on the channel that created it.
       *
       * @return a read-only buffer over the model bytes
       * @throws IllegalStateException if the asset cannot be opened or mapped
       */
      ByteBuffer loadModel(){
          AssetFileDescriptor fd = null;
          FileInputStream is = null;
          try {
              fd = asset.openFd(MODEL);
              is = new FileInputStream(fd.getFileDescriptor());
              FileChannel channel = is.getChannel();
              long startOffset = fd.getStartOffset();
              long declareLength = fd.getDeclaredLength();
              return channel.map(FileChannel.MapMode.READ_ONLY, startOffset, declareLength);
          } catch (IOException e) {
              // Fail fast with the real cause rather than handing null to the Interpreter.
              throw new IllegalStateException("Failed to load model asset: " + MODEL, e);
          } finally {
              if (is != null) {
                  try {
                      is.close();
                  } catch (IOException ignored) {
                      // Best-effort cleanup.
                  }
              }
              if (fd != null) {
                  try {
                      fd.close();
                  } catch (IOException ignored) {
                      // Best-effort cleanup.
                  }
              }
          }
      }

      /**
       * Classifies a drawing.
       *
       * @param bm the drawing (white strokes on black); it is not modified or recycled
       * @return a multi-line string with every class score followed by the best class index
       */
      String classifier(Bitmap bm){
          // Scale the drawing down to the model input size (28x28 for MNIST).
          Bitmap nbm = Bitmap.createScaledBitmap(bm, bmWidth, bmHeight, true);
          ByteBuffer byteBuffer = convertBitmapToByteBuffer(nbm);
          // createScaledBitmap returns the source itself when no scaling is needed; only free a copy.
          if (nbm != bm) {
              nbm.recycle();
          }

          // The output array must match the output tensor shape exactly ([1, N]). Using float[10]
          // or float[2][10] fails with, e.g.:
          //   IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (softmax_tensor)
          //   with shape [1, 10] to a Java object with shape [2, 10].
          float[][] result = new float[1][outputClasses];
          interpreter.run(byteBuffer, result);
          return getOutputSting(result);
      }

      /**
       * Formats the scores and picks the best class.
       *
       * <p>Index i of {@code floats[0]} holds the score for digit i; the best class is the index
       * of the largest score, or -1 if there are no scores.
       *
       * @param floats model output of shape [1, N]
       * @return "Result:\n", one "[i]=score" line per class, then "BEST: index"
       */
      String getOutputSting(float[][] floats){
          StringBuilder result = new StringBuilder("Result:\n");
          float[] res = floats[0];
          // Start below any real score so that negative outputs (e.g. logits) are handled too.
          float max = Float.NEGATIVE_INFINITY;
          int v = -1;
          for (int i = 0; i < res.length; i++) {
              result.append("[").append(i).append("]=").append(res[i]).append("\n");
              if (res[i] > max) {
                  max = res[i];
                  v = i;
              }
          }
          result.append("BEST: ").append(v);
          return result.toString();
      }

      /**
       * Converts a model-sized bitmap into the float input buffer.
       *
       * <p>Each pixel becomes the average of its R, G and B channels scaled to [0, 1].
       *
       * @param bm bitmap already scaled to bmWidth x bmHeight
       * @return a direct, native-order buffer positioned at 0
       */
      ByteBuffer convertBitmapToByteBuffer(Bitmap bm){
          // Use a direct buffer in native byte order. The author observed that a heap buffer from
          // ByteBuffer.allocate made the output constant regardless of the input.
          ByteBuffer bf = ByteBuffer.allocateDirect(modelInputSize);
          bf.order(ByteOrder.nativeOrder());
          int[] pixels = new int[bmWidth * bmHeight];
          bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
          for (int pixel : pixels) {
              int r = (pixel >> 16) & 0xFF;
              int g = (pixel >> 8) & 0xFF;
              int b = pixel & 0xFF;
              bf.putFloat((r + g + b) / 3f / 255f);
          }
          // Rewind so any consumer reading from the current position sees the whole input.
          bf.rewind();
          return bf;
      }

      /** Releases the native interpreter; the classifier must not be used afterwards. */
      void close(){
          interpreter.close();
      }
  }
}

在这里插入图片描述

相关

猜你喜欢

转载自blog.csdn.net/ansondroider/article/details/108508065