TensorFlow Lite
TensorFlow Lite 是一种用于设备端推断的开源深度学习框架，
可帮助开发者在移动设备和 IoT 设备上部署机器学习模型。
环境
Android Studio 4.0 + Java
数字分类器
通过 TensorFlow Lite 模型对手写数字进行分类。
关于DEMO
- GitHub 官方 demo 源码 (Kotlin)
- 上面图片中 DEMO 的 GitHub 源码 (Java)
历程
刚开始并没有找到图片中 DEMO 的源码，于是自己参照 Kotlin 版 DEMO 移植了一个 Java 版本。以下是移植过程，若已熟悉 TensorFlow Lite，请自行跳过。
- 在 Android Studio (AS) 中新建 Module：DigitClassifierByTFL。
- 编译环境
- build.gradle中SDK相关配置:
// Module-level Android build settings used by this demo.
android {
compileSdkVersion 30
buildToolsVersion "30.0.2"
defaultConfig {
applicationId "com.ansondroider.digitclassifierbytfl"
minSdkVersion 16
// NOTE(review): targetSdkVersion 16 is far below compileSdkVersion 30. On API 23+ devices an app
// targeting < 23 gets legacy install-time permissions, so the Activity's runtime
// requestPermissions() call will not show the usual runtime prompt — confirm this value is intended.
targetSdkVersion 16
versionCode 1
versionName "1.0"
}
}
- 目录结构
- 等待构建完成后，需要修改以下配置：
- build.gradle：不压缩 .tflite 文件。若不添加此配置，模型文件会被压缩打包，AssetManager.openFd() 无法以文件描述符方式打开它，导致加载模型时运行出错。
// Keep .tflite assets uncompressed: the Classifier loads the model with AssetManager.openFd()
// and memory-maps it, which is not possible for a compressed asset.
aaptOptions {
noCompress "tflite"
}
- build.gradle: 增加 TensorFlow Lite依赖
dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"])
// TensorFlow Lite runtime (provides org.tensorflow.lite.Interpreter).
// NOTE(review): a nightly snapshot marked changing = true is re-resolved by Gradle, so two builds
// may use different TFLite binaries; consider pinning a released version for reproducibility.
implementation ('org.tensorflow:tensorflow-lite:0.0.0-nightly'){changing = true}
}
- 源码及说明
- layout
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: a drawing area centred horizontally at the top and the result text at the bottom. -->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent">
<!-- Finger-drawing surface; its bitmap is passed to the classifier when a drawing is finished. -->
<com.ansondroider.digitclassifierbytfl.PaintView
android:layout_width="400dp"
android:layout_height="400dp"
android:layout_centerHorizontal="true"
android:id="@+id/paintView"/>
<!-- Shows the classifier output text. -->
<TextView
android:id="@+id/tvRes"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:textColor="#FF00FF00"
android:text="Result: ?"
android:textSize="30sp"/>
</RelativeLayout>
PaintView：用于手指绘制数字。
TextView：显示识别结果。
- Activity
package com.ansondroider.digitclassifierbytfl;
import android.Manifest;
import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Build;
import android.os.Bundle;
import android.widget.TextView;
import org.tensorflow.lite.Interpreter;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
public class DigitClassifierByTFL extends Activity {
// Drawing surface the user writes a digit on.
PaintView paintView;
// Displays the text produced by the classifier.
TextView tvRes;
// TensorFlow Lite digit classifier; created in onCreate, released in onDestroy.
Classifier fier;

/**
 * Inflates the layout, hooks the drawing-finished callback up to the classifier, builds the
 * classifier, and on API 23+ asks for external-storage permissions.
 */
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);

    setContentView(R.layout.activity_digit_classifier_by_tfl);
    tvRes = (TextView) findViewById(R.id.tvRes);
    paintView = (PaintView) findViewById(R.id.paintView);

    // PaintView reports a finished drawing 500 ms after the finger is lifted;
    // classify the current bitmap as soon as that happens.
    PaintView.Callback onDrawingFinished = new PaintView.Callback() {
        @Override
        public void onWriteDone() {
            showResult(fier.classifier(paintView.getBitmap()));
        }
    };
    paintView.setCallback(onDrawingFinished);

    fier = new Classifier(getAssets());

    // NOTE(review): nothing visible in this Activity reads or writes external storage;
    // confirm these permissions are still needed.
    boolean hasRuntimePermissions = Build.VERSION.SDK_INT >= Build.VERSION_CODES.M;
    if (hasRuntimePermissions) {
        String[] wanted = {
                Manifest.permission.WRITE_EXTERNAL_STORAGE,
                Manifest.permission.READ_EXTERNAL_STORAGE
        };
        requestPermissions(wanted, 0x77);
    }
}

/** Queues {@code res} to be shown in the result TextView via its message queue. */
private void showResult(final String res) {
    tvRes.post(new Runnable() {
        @Override
        public void run() {
            tvRes.setText(res);
        }
    });
}
/**
 * Unhooks the drawing callback and releases the classifier's TensorFlow Lite interpreter.
 */
@Override
protected void onDestroy() {
    super.onDestroy();
    // PaintView posts its "write done" notification with a 500 ms delay, so one may still be
    // queued. Clearing the callback makes that notification a no-op instead of running the
    // classifier on an interpreter that has already been closed.
    if (paintView != null) {
        paintView.setCallback(null);
    }
    if (fier != null) {
        fier.close();
    }
}
- PaintView
package com.ansondroider.digitclassifierbytfl;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;
/**
 * A drawing surface that records finger strokes as thick white lines on a black bitmap.
 *
 * <p>When the finger has been lifted for {@link #WRITE_DONE_DELAY_MS} ms without a new stroke
 * starting, the registered {@link Callback} is notified and the stroke path is reset, so several
 * quick strokes form one drawing.
 */
public class PaintView extends View {
    /** Delay after the last touch-up before the drawing is reported as finished. */
    private static final long WRITE_DONE_DELAY_MS = 500;
    /** Width of the drawn stroke in pixels. */
    private static final float STROKE_WIDTH = 50;

    public PaintView(Context context) {
        super(context);
    }

    public PaintView(Context context, AttributeSet attrs) {
        super(context, attrs);
    }

    public PaintView(Context context, AttributeSet attrs, int defStyleAttr) {
        super(context, attrs, defStyleAttr);
    }

    // Offscreen bitmap holding the current drawing.
    Bitmap bm;

    /** Returns the bitmap holding the drawing, or null before the view has been sized. */
    public Bitmap getBitmap() {
        return bm;
    }

    /** (Re)creates the black offscreen bitmap to match the new view size. */
    @Override
    protected void onSizeChanged(int w, int h, int oldw, int oldh) {
        super.onSizeChanged(w, h, oldw, oldh);
        // Bitmap.createBitmap rejects non-positive sizes; keep the previous bitmap until the
        // view gets a real size.
        if (w <= 0 || h <= 0) {
            return;
        }
        if (bm != null) {
            bm.recycle();
        }
        bm = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
        cBm = new Canvas(bm);
        cBm.drawColor(Color.BLACK);
    }

    // Canvas that draws into bm.
    Canvas cBm;
    // Stroke paint; configured once here instead of on every move event.
    Paint p = new Paint(Paint.ANTI_ALIAS_FLAG);
    {
        p.setStyle(Paint.Style.STROKE);
        p.setColor(Color.WHITE);
        p.setStrokeWidth(STROKE_WIDTH);
    }
    // dx/dy: position of the last ACTION_DOWN; cx/cy: position of the current event.
    float dx, dy, cx, cy;

    /** Extends the stroke path from touch events and schedules the "write done" notification. */
    @Override
    public boolean onTouchEvent(MotionEvent event) {
        cx = event.getX();
        cy = event.getY();
        // getActionMasked strips the pointer index bits that getAction can carry.
        switch (event.getActionMasked()) {
            case MotionEvent.ACTION_DOWN:
                dx = cx;
                dy = cy;
                startWrite();
                break;
            case MotionEvent.ACTION_MOVE:
                onMove();
                break;
            case MotionEvent.ACTION_CANCEL:
            case MotionEvent.ACTION_UP:
                endWrite();
                break;
        }
        // Touch events arrive on the UI thread, so a plain invalidate() is enough.
        invalidate();
        return true;
    }

    // Strokes drawn since the last "write done" notification.
    Path path = new Path();

    /** Starts a new stroke and cancels a pending notification so multi-stroke digits stay together. */
    void startWrite() {
        removeCallbacks(writeDone);
        path.moveTo(cx, cy);
    }

    /** Adds the current point to the path and redraws the whole path onto a cleared bitmap. */
    void onMove() {
        path.lineTo(cx, cy);
        cBm.drawColor(Color.BLACK);
        cBm.drawPath(path, p);
    }

    // Notifies the callback that drawing has finished, then starts a fresh path.
    Runnable writeDone = new Runnable() {
        @Override
        public void run() {
            if (cb != null) {
                cb.onWriteDone();
            }
            path.reset();
        }
    };

    /** Schedules the "write done" notification, replacing any already pending one. */
    void endWrite() {
        removeCallbacks(writeDone);
        postDelayed(writeDone, WRITE_DONE_DELAY_MS);
    }

    /**
     * Drops a pending "write done" notification so it cannot fire after the view (and whatever
     * the callback uses, e.g. a closed classifier) has been torn down.
     */
    @Override
    protected void onDetachedFromWindow() {
        removeCallbacks(writeDone);
        super.onDetachedFromWindow();
    }

    @Override
    protected void onDraw(Canvas canvas) {
        if (bm != null && !bm.isRecycled()) {
            canvas.drawBitmap(bm, 0, 0, p);
        }
    }

    // Receiver of the "write done" notification; may be null.
    Callback cb;

    /** Sets (or clears, with null) the receiver of the drawing-finished notification. */
    public void setCallback(Callback c) {
        cb = c;
    }

    /** Receives a notification when the user has finished drawing. */
    public interface Callback {
        void onWriteDone();
    }
}
- 分类器
/**
 * Classifies hand-drawn digits with a TensorFlow Lite model stored in the app's assets.
 *
 * <p>The input image size is read from the model's first input tensor and the number of classes
 * from its first output tensor.
 */
class Classifier {
    // mnist.tflite: from the Kotlin demo; small file, low recognition rate.
    // mnist_big.tflite: from the Java demo; larger file, high recognition rate.
    final String MODEL = "mnist_big.tflite";
    // TensorFlow Lite interpreter running MODEL.
    Interpreter interpreter;
    // Size the drawing is scaled to before inference (taken from the input tensor shape).
    int bmWidth, bmHeight;
    // Used to open the .tflite asset.
    AssetManager asset;
    // Size in bytes of the input ByteBuffer: one float per pixel.
    int modelInputSize;
    // Number of scores per inference (last dimension of the output tensor shape).
    int outputClasses;

    /**
     * Loads {@link #MODEL} from assets and creates the interpreter with NNAPI enabled.
     *
     * @param asset asset manager used to open the model file
     * @throws IllegalStateException if the model asset cannot be opened or memory-mapped
     */
    Classifier(AssetManager asset) {
        this.asset = asset;
        Interpreter.Options op = new Interpreter.Options();
        op.setUseNNAPI(true);
        interpreter = new Interpreter(loadModel(), op);

        // Assumes the usual TFLite NHWC image layout [1, height, width, channels]: index 1 is the
        // height, index 2 the width (both 28 for the MNIST models, so this only matters for
        // non-square models).
        int[] shape = interpreter.getInputTensor(0).shape();
        bmHeight = shape[1];
        bmWidth = shape[2];

        int FLOAT_TYPE_SIZE = 4;
        int PIXEL_SIZE = 1;
        modelInputSize = FLOAT_TYPE_SIZE * bmWidth * bmHeight * PIXEL_SIZE;

        // The output is [1, N]; N is the number of classes.
        int[] outShape = interpreter.getOutputTensor(0).shape();
        outputClasses = outShape[outShape.length - 1];
    }

    /**
     * Memory-maps {@link #MODEL} from the assets.
     *
     * <p>The asset must be stored uncompressed (see {@code noCompress "tflite"}) for openFd to work.
     *
     * @return a read-only mapping of the model file
     * @throws IllegalStateException if the asset cannot be opened or mapped
     */
    ByteBuffer loadModel() {
        AssetFileDescriptor fd = null;
        FileInputStream is = null;
        try {
            fd = asset.openFd(MODEL);
            is = new FileInputStream(fd.getFileDescriptor());
            FileChannel channel = is.getChannel();
            long startOffset = fd.getStartOffset();
            long declaredLength = fd.getDeclaredLength();
            // A mapping stays valid after its channel is closed, so the streams can be released below.
            return channel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        } catch (IOException e) {
            // Fail with the real cause instead of handing null to the Interpreter constructor,
            // which would only report an unrelated "invalid model buffer" error.
            throw new IllegalStateException("Failed to load model: " + MODEL, e);
        } finally {
            if (is != null) {
                try {
                    is.close();
                } catch (IOException ignored) {
                    // Nothing useful to do; the model mapping is already established or failed.
                }
            }
            if (fd != null) {
                try {
                    fd.close();
                } catch (IOException ignored) {
                    // Same as above.
                }
            }
        }
    }

    /**
     * Runs the model on a drawing.
     *
     * @param bm the drawing; scaled to the model input size (e.g. 28x28) before inference
     * @return a multi-line text listing every class score followed by the best class index
     */
    String classifier(Bitmap bm) {
        Bitmap nbm = Bitmap.createScaledBitmap(bm, bmWidth, bmHeight, true);
        ByteBuffer byteBuffer = convertBitmapToByteBuffer(nbm);
        // createScaledBitmap returns the source itself when no scaling is needed; only recycle
        // the temporary copy, never the caller's bitmap.
        if (nbm != bm) {
            nbm.recycle();
        }
        // The output array must match the output tensor shape exactly ([1, N]). A mismatch such
        // as float[10] or float[2][10] fails with e.g.
        // "IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (softmax_tensor)
        //  with shape [1, 10] to a Java object with shape [2, 10]."
        float[][] result = new float[1][outputClasses];
        interpreter.run(byteBuffer, result);
        return getOutputSting(result);
    }

    /**
     * Formats the scores as "[i]=score" lines followed by "BEST: index".
     *
     * <p>The index of a score is the digit it stands for; the best class is the first index with
     * the largest score, or -1 if there are no scores.
     */
    String getOutputSting(float[][] floats) {
        StringBuilder result = new StringBuilder("Result:\n");
        float[] res = floats[0];
        // Start below any real score so negative outputs (e.g. logits) are handled too.
        float max = Float.NEGATIVE_INFINITY;
        int v = -1;
        for (int i = 0; i < res.length; i++) {
            result.append("[").append(i).append("]=").append(res[i]).append("\n");
            if (res[i] > max) {
                max = res[i];
                v = i;
            }
        }
        result.append("BEST: ").append(v);
        return result.toString();
    }

    /**
     * Converts a bitmap of the model input size into a native-order float buffer, one grayscale
     * value in [0, 1] per pixel, positioned at 0 and ready to be read by the interpreter.
     */
    ByteBuffer convertBitmapToByteBuffer(Bitmap bm) {
        ByteBuffer bf = ByteBuffer.allocateDirect(modelInputSize);
        bf.order(ByteOrder.nativeOrder());
        int[] pixels = new int[bmWidth * bmHeight];
        bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
        for (int pixel : pixels) {
            int r = (pixel >> 16) & 0xFF;
            int g = (pixel >> 8) & 0xFF;
            int b = pixel & 0xFF;
            // Average the RGB channels into a grayscale value normalised to [0, 1].
            float normalizePixelValue = (r + g + b) / 3f / 255f;
            bf.putFloat(normalizePixelValue);
        }
        // After the puts the position is at the end. When copying from a heap (non-direct) buffer,
        // TFLite copies only the bytes remaining after the position, so without rewind() the input
        // stays all zeros and every drawing gets the same result. That was the real cause of the
        // constant-output symptom once seen with ByteBuffer.allocate.
        bf.rewind();
        return bf;
    }

    /** Releases the native interpreter; the classifier must not be used afterwards. */
    void close() {
        interpreter.close();
    }
}
}