Android + TensorFlow Lite: explorando el clasificador de dígitos (DEMO en Java)

TensorFlow Lite

TensorFlow Lite es un marco de aprendizaje profundo de código abierto para la inferencia en el propio dispositivo, que permite ejecutar modelos de aprendizaje automático en dispositivos móviles e IoT.

Entorno

AndroidStudio 4.0 + JAVA

Clasificador de dígitos

Clasifica dígitos escritos a mano mediante el modelo TensorFlow Lite.
Imagen de tflite-mnist-android

Acerca de DEMO

Proceso

Al principio no encontré el código fuente Java de la DEMO que aparece en la imagen, así que la porté a partir de la DEMO en Kotlin. A continuación se describe el proceso de migración; si ya conoce TensorFlow Lite, puede omitir esta sección.

  1. Cree un nuevo módulo DigitClassifierByTFL en Android Studio (AS).
  • Entorno del compilador
    Inserte la descripción de la imagen aquí
  • Configuración relacionada con SDK en build.gradle:
android {
    compileSdkVersion 30
    buildToolsVersion "30.0.2"

    defaultConfig {
        applicationId "com.ansondroider.digitclassifierbytfl"
        minSdkVersion 16
        // NOTE(review): targetSdkVersion 16 is far below compileSdkVersion 30. With a target
        // below 23 the app presumably runs under the legacy install-time permission model,
        // so the runtime requestPermissions() call in the Activity likely has no effect.
        // Confirm this is intended.
        targetSdkVersion 16
        versionCode 1
        versionName "1.0"
    }
}
  • Estructura de directorios
    Inserte la descripción de la imagen aquí
  2. Una vez completada la compilación del proyecto, hay que modificar algunas configuraciones:
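  • build.gradle (dentro del bloque android { }): evite que se comprima el archivo .tflite. Si no se añade esta opción, se producirá un error en tiempo de ejecución al cargar el modelo, porque AssetManager.openFd() solo puede abrir recursos almacenados sin comprimir.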
    // Keep .tflite assets uncompressed in the APK: the Classifier loads the model with
    // AssetManager.openFd() + FileChannel.map(), which requires an uncompressed asset.
    // This block goes inside android { }.
    aaptOptions {
        noCompress "tflite"
    }
  • build.gradle: agregar dependencia de TensorFlow Lite
dependencies {
    implementation fileTree(dir: "libs", include: ["*.jar"])
    // TensorFlow Lite runtime (provides org.tensorflow.lite.Interpreter used by the Classifier).
    // NOTE(review): '0.0.0-nightly' is marked changing = true, so Gradle re-checks it for
    // updates and different builds may pull different TFLite versions; consider pinning a
    // released version for reproducible builds.
    implementation ('org.tensorflow:tensorflow-lite:0.0.0-nightly'){changing = true}
}
  3. Código fuente y descripción
  • Layout (activity_digit_classifier_by_tfl.xml)
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: a 400x400dp drawing pad centered at the top and a result line at the bottom. -->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent">
    <!-- Custom finger-painting view; the Activity classifies its bitmap when drawing finishes. -->
    <com.ansondroider.digitclassifierbytfl.PaintView
        android:layout_width="400dp"
        android:layout_height="400dp"
        android:layout_centerHorizontal="true"
        android:id="@+id/paintView"/>
    <!-- Result text, replaced by the Activity with the classifier output. -->
    <!-- NOTE(review): hard-coded text; consider moving it to a string resource. -->
    <TextView
        android:id="@+id/tvRes"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:textColor="#FF00FF00"
        android:text="Result: ?"
        android:textSize="30sp"/>

</RelativeLayout>

PaintView: Pintar con los dedos
TextView: Mostrar resultados.
Inserte la descripción de la imagen aquí

  • Actividad
package com.ansondroider.digitclassifierbytfl;

import android.Manifest;
import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Build;
import android.os.Bundle;
import android.widget.TextView;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;

public class DigitClassifierByTFL extends Activity {
    
    
  // Drawing pad where the user writes a digit (R.id.paintView).
  PaintView paintView;
  // Displays the classification result text (R.id.tvRes).
  TextView tvRes;
  // TensorFlow Lite digit classifier; created in onCreate(), closed in onDestroy().
  Classifier fier;
  /**
   * Inflates the layout, hooks the drawing pad up to the classifier and creates the
   * classifier from the bundled model. On Android 6.0+ it also asks for the external
   * storage permissions.
   */
  @Override
  protected void onCreate(Bundle savedInstanceState) {
      super.onCreate(savedInstanceState);
      setContentView(R.layout.activity_digit_classifier_by_tfl);
      tvRes = findViewById(R.id.tvRes);
      paintView = findViewById(R.id.paintView);
      // The callback reads `fier` only when it fires, so it may be registered before
      // the classifier is constructed.
      paintView.setCallback(createWriteDoneCallback());
      fier = new Classifier(getAssets());
      requestStoragePermissionsIfNeeded();
  }

  /**
   * Builds the PaintView callback: once the finger has been lifted long enough for the
   * drawing to count as finished, classify the bitmap and show the result text.
   */
  private PaintView.Callback createWriteDoneCallback() {
      return new PaintView.Callback() {
          @Override
          public void onWriteDone() {
              final String label = fier.classifier(paintView.getBitmap());
              Runnable showLabel = new Runnable() {
                  @Override
                  public void run() {
                      tvRes.setText(label);
                  }
              };
              tvRes.post(showLabel);
          }
      };
  }

  /** Requests external-storage read/write permissions on devices with runtime permissions. */
  private void requestStoragePermissionsIfNeeded() {
      if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
          String[] storagePermissions = {
                  Manifest.permission.WRITE_EXTERNAL_STORAGE,
                  Manifest.permission.READ_EXTERNAL_STORAGE
          };
          requestPermissions(storagePermissions, 0x77);
      }
  }

  /**
   * Releases the TensorFlow Lite interpreter.
   *
   * <p>PaintView reports a finished drawing through a runnable posted with a delay after the
   * finger is lifted. That runnable can still be queued when the Activity is destroyed. If it
   * ran afterwards it would call the classifier on a closed interpreter. The callback is
   * therefore cleared before the classifier is closed.
   */
  @Override
  protected void onDestroy() {
      if (paintView != null) {
          paintView.setCallback(null);
      }
      if (fier != null) {
          fier.close();
      }
      super.onDestroy();
  }
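  • PaintView (archivo aparte, PaintView.java; el listado de la Actividad continúa más abajo con la clase Clasificador)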
package com.ansondroider.digitclassifierbytfl;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;

/**
 * Finger-painting surface for handwritten digits.
 *
 * <p>Strokes are drawn white on a black off-screen bitmap the size of the view. The bitmap is
 * exposed through {@link #getBitmap()}. Once the finger has been up for
 * {@link #WRITE_DONE_DELAY_MS} ms with no new touch-down, the registered {@link Callback} is
 * notified. The current path is then reset, so the next stroke starts a new drawing.
 */
public class PaintView extends View {

    /** Delay after the last finger lift before the drawing is reported as finished. */
    private static final long WRITE_DONE_DELAY_MS = 500;
    /** Stroke width, in pixels, of the drawn digit. */
    private static final float STROKE_WIDTH_PX = 50;

    public PaintView(Context context) {
        super(context);
    }

    public PaintView(Context context, AttributeSet attrs) {
        super(context, attrs);
    }

    public PaintView(Context context, AttributeSet attrs, int defStyleAttr) {
        super(context, attrs, defStyleAttr);
    }

    /** Off-screen drawing (white strokes on black); null until the view is first laid out. */
    Bitmap bm;
    /** Canvas that draws into {@link #bm}. */
    Canvas cBm;
    /** Stroke paint, configured once rather than on every move event. */
    Paint p = createStrokePaint();
    /** Stroke(s) of the digit currently being written. */
    Path path = new Path();
    /** Touch-down position (dx, dy; currently unused) and latest touch position (cx, cy). */
    float dx, dy, cx, cy;
    /** Listener notified when a drawing is finished; may be null. */
    Callback cb;

    /** Notifies the listener that writing is done, then starts a fresh path. */
    Runnable writeDone = new Runnable() {
        @Override
        public void run() {
            if (cb != null) {
                cb.onWriteDone();
            }
            path.reset();
        }
    };

    /** Creates the anti-aliased white stroke paint used for the digit. */
    private static Paint createStrokePaint() {
        Paint paint = new Paint(Paint.ANTI_ALIAS_FLAG);
        paint.setStyle(Paint.Style.STROKE);
        paint.setColor(Color.WHITE);
        paint.setStrokeWidth(STROKE_WIDTH_PX);
        return paint;
    }

    /**
     * Returns the drawing bitmap (same size as the view), or null before the first layout.
     * The bitmap is owned by this view and is recycled when the view size changes.
     */
    public Bitmap getBitmap() {
        return bm;
    }

    @Override
    protected void onSizeChanged(int w, int h, int oldw, int oldh) {
        super.onSizeChanged(w, h, oldw, oldh);
        // Bitmap.createBitmap throws for non-positive sizes (e.g. a collapsed view);
        // keep the previous bitmap in that case.
        if (w <= 0 || h <= 0) {
            return;
        }
        if (bm != null) {
            bm.recycle();
        }
        bm = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
        cBm = new Canvas(bm);
        cBm.drawColor(Color.BLACK);
    }

    @Override
    public boolean onTouchEvent(MotionEvent event) {
        cx = event.getX();
        cy = event.getY();
        switch (event.getActionMasked()) {
            case MotionEvent.ACTION_DOWN:
                dx = cx;
                dy = cy;
                startWrite();
                break;
            case MotionEvent.ACTION_MOVE:
                onMove();
                break;
            case MotionEvent.ACTION_CANCEL:
            case MotionEvent.ACTION_UP:
                endWrite();
                break;
            default:
                // Secondary-pointer events are ignored.
                break;
        }
        // Touch events are delivered on the UI thread, so a direct invalidate() suffices.
        invalidate();
        return true;
    }

    /** Begins a stroke; cancels a pending "done" notification so multi-stroke digits stay together. */
    void startWrite() {
        removeCallbacks(writeDone);
        path.moveTo(cx, cy);
    }

    /** Extends the stroke and re-renders the whole path onto a cleared bitmap. */
    void onMove() {
        path.lineTo(cx, cy);
        if (cBm == null) {
            return; // not laid out yet; nothing to draw into
        }
        cBm.drawColor(Color.BLACK);
        cBm.drawPath(path, p);
    }

    /** Schedules the "done" notification; a new touch-down before it fires cancels it. */
    void endWrite() {
        removeCallbacks(writeDone);
        postDelayed(writeDone, WRITE_DONE_DELAY_MS);
    }

    @Override
    protected void onDetachedFromWindow() {
        // Drop a pending notification so the callback cannot run after the view (and its
        // hosting screen) has gone away.
        removeCallbacks(writeDone);
        super.onDetachedFromWindow();
    }

    @Override
    protected void onDraw(Canvas canvas) {
        if (bm != null && !bm.isRecycled()) {
            canvas.drawBitmap(bm, 0, 0, p);
        }
    }

    /** Registers the listener for finished drawings; pass null to unregister. */
    public void setCallback(Callback c) {
        cb = c;
    }

    /** Receives a notification when the user has finished writing a digit. */
    public interface Callback {
        /** Called on the UI thread once the drawing is considered finished. */
        void onWriteDone();
    }
}

  • Clasificador (clase interna de DigitClassifierByTFL: continuación del archivo de la Actividad)
  class Classifier{
    
    
     //mnist.tflite: 来自kotlin DEMO, 识别率很低, 文件小.
     //mnist_big.tflite: 来自JAVA DEMO, 识别率高,文件大
      final String MODEL = "mnist_big.tflite";
      //插入器
      Interpreter interpreter;
      //输入识别的图像尺寸
      int bmWidth, bmHeight;
      //用于读取tflite文件
      AssetManager asset;
      //用于创建ByteBuffer
      int modelInputSize;
      Classifier(AssetManager asset){
    
    
          this.asset = asset;
  		  //创建插入器.
          Interpreter.Options op = new Interpreter.Options();
          op.setUseNNAPI(true);
          interpreter = new Interpreter(loadModel(), op);
  		//获取输入信息
          int[] shape = interpreter.getInputTensor(0).shape();
          bmWidth = shape[1];
          bmHeight = shape[2];

  		//计算ByteBuffer大小.
          int FLOAT_TYPE_SIZE = 4;
          int PIXEL_SIZE = 1;
          modelInputSize = FLOAT_TYPE_SIZE * bmWidth * bmHeight * PIXEL_SIZE;

      }

  	  //加载模型
      ByteBuffer loadModel(){
    
    
          try {
    
    
              AssetFileDescriptor fd = asset.openFd(MODEL);
              FileInputStream is = new FileInputStream(fd.getFileDescriptor());
              FileChannel channel = is.getChannel();
              long startOffset = fd.getStartOffset();
              long declareLength = fd.getDeclaredLength();
              return channel.map(FileChannel.MapMode.READ_ONLY, startOffset, declareLength);
          } catch (IOException e) {
    
    
              e.printStackTrace();
          }
          return null;
      }

  	  //执行分类
      String classifier(Bitmap bm){
    
    
          //缩放图片到指定尺寸.(28*28)
          Bitmap nbm = Bitmap.createScaledBitmap(bm, bmWidth, bmHeight, true);
          ByteBuffer byteBuffer = convertBitmapToByteBuffer(nbm);

  		  //Kotlin中的代码:
  		  //  val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
  		  //  平时不用它, 看这代码头痛了好久.
  		  //若创建的数组不对, 如用float[10], 或float[2][10]
  		  //则会导致异常(Google 百度都不知道):
  		  /**  2020-09-10 10:55:18.213 14429-14429/com.ansondroider.digitclassifierbytfl E/AndroidRuntime: FATAL EXCEPTION: main
  Process: com.ansondroider.digitclassifierbytfl, PID: 14429
  java.lang.IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (softmax_tensor) with shape [1, 10] to a Java object with shape [2, 10].
      at org.tensorflow.lite.Tensor.throwIfDstShapeIsIncompatible(Tensor.java:482)
      at org.tensorflow.lite.Tensor.copyTo(Tensor.java:252)
      at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:170)
      at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:347)
      at org.tensorflow.lite.Interpreter.run(Interpreter.java:306)
      at com.ansondroider.digitclassifierbytfl.DigitClassifierByTFL$Classifier.classifier(DigitClassifierByTFL.java:98)
      at com.ansondroider.digitclassifierbytfl.DigitClassifierByTFL$1.onWriteDone(DigitClassifierByTFL.java:33)
      at com.ansondroider.digitclassifierbytfl.PaintView$1.run(PaintView.java:86)
      at android.os.Handler.handleCallback(Handler.java:883)
      at android.os.Handler.dispatchMessage(Handler.java:100)
      at android.os.Looper.loop(Looper.java:214)
      at android.app.ActivityThread.main(ActivityThread.java:7356)
      at java.lang.reflect.Method.invoke(Native Method)
      at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:492)
      at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)**/
          float[][] result = new float[1][10];
          //执行.
          interpreter.run(byteBuffer, result);
          //格式化输出
          return getOutputSting(result);
      }

      String getOutputSting(float[][] floats){
    
    
      		//Kotlin 代码:
      		//  val maxIndex = output.indices.maxBy { output[it] } ?: -1
  			//  return "Prediction Result: %d\nConfidence: %2f".format(maxIndex, output[maxIndex])
		  //在float[10]数组中, 存放了推算的结果, 下标分别对应的是[0,9]的数字.
		  //只需要遍历10个数中, 找出最大的值即可.
          StringBuilder result = new StringBuilder("Result:\n");
          float[] res = floats[0];
          float max = -1;
          int v = -1;
          for(int i = 0; i < res.length; i ++){
    
    
              result.append("[" + i + "]=" + res[i]).append("\n");
              if(max < res[i]){
    
    
                  max = res[i];
                  v = i;
              }
          }
          result.append("BEST: " + v);
          return result.toString();
      }

      ByteBuffer convertBitmapToByteBuffer(Bitmap bm){
    
    
          //刚开始, 用错了函数接口: ByteBuffer.allocate
          //这样会导致推算的结果不管输入如何变化, 都输出固定的float[10]
          //在打开后一直不变, 而在调试过程中, 也出现过多次生新运行都显示同样的结果.
          ByteBuffer bf = ByteBuffer.allocateDirect(modelInputSize);
          bf.order(ByteOrder.nativeOrder());
          int[] pixels = new int[bmWidth * bmHeight];
          bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
          for(int i = 0; i < pixels.length; i ++){
    
    
              int r = (pixels[i] >> 16) & 0xFF;
              int g = (pixels[i] >> 8) & 0xFF;
              int b = pixels[i] & 0xFF;
              float normalizePixelValue = (r + g + b) / 3f / 255f;
              bf.putFloat(normalizePixelValue);
          }
          return bf;
      }

      void close(){
    
    
          interpreter.close();
      }
  }
}

Inserte la descripción de la imagen aquí

Relacionado

Supongo que te gusta

Origen: blog.csdn.net/ansondroider/article/details/108508065
Recomendado
Clasificación