Android 機械学習モデルの軽量フレームワークである TensorFlow Lite

TensorFlow Lite の紹介

TensorFlow Lite は、モバイル デバイス、組み込みデバイス、IoT デバイスで機械学習モデルを実行するための軽量フレームワークです。TensorFlow をモバイル分野に拡張したものであり、携帯電話などのデバイスで機械学習を行うための限られたコンピューティング リソースの問題を解決することを目的としています。TensorFlow Lite は、モデルのサイズ、量子化を最適化し、特定のデバイス要件に合わせてカーネルを含めることにより、モデルを効率的に実行する機能を実現します。

TensorFlow Lite は、Java、C++、Python などを含む複数の言語の開発をサポートしています。TensorFlow モデルを Lite モデル形式に変換でき、開発者が使用できる豊富な API インターフェースを提供します。さらに、TensorFlow Lite はアクセラレータ ハードウェア (GPU、DSP など) の使用もサポートし、モデル推論の効率をさらに向上させます。

TensorFlow Lite には、スマート ホームでの音声認識、画像分類、オブジェクト検出、スマート ヘルスケアでの病気の診断と患者のモニタリング、自動運転での車両制御など、幅広いアプリケーション シナリオがあります。その高い効率性と移植性により、TensorFlow Lite は、携帯電話などの組み込みデバイスで機械学習を実行するための主流フレームワークの 1 つになりました。

TensorFlow Lite の公式ドキュメント アドレスはhttps://www.tensorflow.org/liteです。この Web サイトでは、TensorFlow Lite の使用ガイド、API ドキュメント、サンプル コード、およびモバイル デバイスと組み込みでの TensorFlow Lite の使用に関する情報を見つけることができます。システムなどに機械学習モデルをデプロイするためのベスト プラクティス。

TensorFlow Lite の統合

TensorFlow Lite を Android アプリケーションに統合するには、次の手順に従います。

  1. TensorFlow Lite ライブラリをアプリケーションの Gradle ビルド ファイルに追加します。build.gradle (Module: app) ファイルに次の依存関係を追加します。
dependencies {
    
    
    implementation 'org.tensorflow:tensorflow-lite:2.5.0'
}
  1. モデル ファイル (.tflite) をアプリケーションの "assets" ディレクトリにコピーします。

  2. モデルをアプリケーションにロードします。次のコードでモデルをロードします。

private Interpreter tflite;
tflite = new Interpreter(loadModelFile(), null);
    
private MappedByteBuffer loadModelFile() throws IOException {
    
    
    AssetFileDescriptor fileDescriptor = this.getAssets().openFd("model.tflite");
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
  1. TensorFlow Lite インタープリターを使用して推論を実行します。入力を準備して出力を取得する方法については、TensorFlow Lite のドキュメントを参照してください。

TensorFlow Lite セルフ トレーニング モデル

  1. まず、アプリケーションのニーズに適した機械学習モデルを選択してトレーニングする必要があります。モデルは、TensorFlow、PyTorch などの一般的な深層学習ライブラリを使用してトレーニングできます。

  2. トレーニング後、モデルを TensorFlow Lite プラットフォームでサポートされている形式に変換する必要があります。変換プロセス中にモデルを最適化したり、量子化などの手法によってモデルのサイズを縮小したりできるため、モデルがモバイル デバイスでの展開により適したものになります。TensorFlow が公式に提供する TFLite Converter または TensorFlow Hub を使用して、モデルの変換を完了できます。

  3. 変換が成功すると、TensorFlow Lite モデル ファイル (通常は .tflite ファイル) を取得できます。ファイルは、ローカル ディスクに保存するか、アプリケーションの assets ディレクトリに直接パッケージ化できます。

これらの手順が、TensorFlow Lite モデル ファイルを正常に取得して使用するのに役立つことを願っています。

TensorFlow Lite モデル ファイル

TensorFlow Lite モデル ファイルの Google 公式コレクションは、TensorFlow Hub Web サイトにあります。この Web サイトの検索バーに「TensorFlow Lite」などのキーワードを入力し、Enter キーを押すと、検索に関連するモデルを見つけることができます。

検索結果ページから、分類、オブジェクト検出、画像セグメンテーションなど、さまざまなタイプのモデルを参照してフィルター処理できます。各モデルには、モデルの使用方法とそのパフォーマンス メトリックに関する情報を含む、独自の紹介とドキュメントがあります。興味のあるモデルが見つかった場合は、リンクをクリックしてモデルの詳細ページに移動すると、ダウンロード可能な事前トレーニング済みの重みまたは変換された TensorFlow Lite モデル ファイルが提供される場合があります。

TensorFlow ハブのウェブサイトにアクセスしてください: https://tfhub.dev/

TensorFlow Lite の例

公式の TensorFlow GitHub リポジトリで、TensorFlow Lite を使用した Android の公式の例を見つけることができます。この例では、TensorFlow Lite を使用して画像内のオブジェクトを認識し、結果をアプリケーションに表示する方法を示します。

サンプルには、完全なプロジェクト コード、Gradle ファイル、モデル ファイルなどのリソースが含まれています。サンプル アプリケーションを直接ダウンロードして実行するか、リファレンスとして使用して独自の TensorFlow Lite Android アプリケーションを構築できます。

サンプル プロジェクトの GitHub ウェアハウス アドレスは次のとおりです:
https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android

以下は、公式の TensorFlow Lite モデル ファイルを使用したオブジェクトの検出と認識のサンプル コードです。

  1. TensorFlow Lite ライブラリをインポートする

    implementation 'org.tensorflow:tensorflow-lite:+'
    
  2. モデルファイルを読み込む

    private MappedByteBuffer loadModelFile(Activity activity, String modelPath) throws IOException {
          
          
        AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelPath);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
    
  3. 前処理

    private Bitmap preprocess(Bitmap bitmap) {
          
          
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        int inputSize = 300;
    
        Matrix matrix = new Matrix();
        float scaleWidth = ((float) inputSize) / width;
        float scaleHeight = ((float) inputSize) / height;
        matrix.postScale(scaleWidth, scaleHeight);
    
        Bitmap resizedBitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, false);
    
        return resizedBitmap;
    }
    
  4. 推論を行う

    private void runInference(Bitmap bitmap) {
          
          
        try {
          
          
            // 加载模型文件
            MappedByteBuffer modelFile = loadModelFile(this, "detect.tflite");
    
            // 初始化解析器
            Interpreter.Options options = new Interpreter.Options();
            options.setNumThreads(4);
            Interpreter tflite = new Interpreter(modelFile, options);
    
            // 获取输入和输出 Tensor
            int[] inputs = tflite.getInputIds();
            int[] outputs = tflite.getOutputIds();
            int inputSize = tflite.getInputTensor(inputs[0]).shape()[1];
    
            // 进行预处理
            Bitmap resizedBitmap = preprocess(bitmap);
            ByteBuffer inputBuffer = convertBitmapToByteBuffer(resizedBitmap, inputSize);
    
            // 执行推理,并获取输出结果
            Object[] inputArray = {
          
          inputBuffer};
            Map<Integer, Object> outputMap = new HashMap<>();
            float[][][] locations = new float[1][100][4];
            float[][] classes = new float[1][100];
            float[][] scores = new float[1][100];
            float[] numDetections = new float[1];
            outputMap.put(outputs[0], locations);
            outputMap.put(outputs[1], classes);
            outputMap.put(outputs[2], scores);
            outputMap.put(outputs[3], numDetections);
            tflite.runForMultipleInputsOutputs(inputArray, outputMap);
    
            // 输出识别结果
            for (int i = 0; i < 100; ++i) {
          
          
                if (scores[0][i] > THRESHOLD) {
          
          
                    int id = (int) classes[0][i];
                    String label = labels[id + 1];
                    float score = scores[0][i];
                    RectF location = new RectF(
                            locations[0][i][1] * bitmap.getWidth(),
                            locations[0][i][0] * bitmap.getHeight(),
                            locations[0][i][3] * bitmap.getWidth(),
                            locations[0][i][2] * bitmap.getHeight()
                    );
                    Log.d(TAG, "Label: " + label + ", Confidence: " + score + ", Location: " + location);
                }
            }
    
            // 释放资源
            tflite.close();
        } catch (Exception e) {
          
          
            e.printStackTrace();
        }
    }
    
    private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap, int inputSize) {
          
          
        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(inputSize * inputSize * 3);
        byteBuffer.order(ByteOrder.nativeOrder());
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true);
        for (int y = 0; y < inputSize; ++y) {
          
          
            for (int x = 0; x < inputSize; ++x) {
          
          
                int pixelValue = resizedBitmap.getPixel(x, y);
                byteBuffer.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                byteBuffer.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                byteBuffer.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
            }
        }
        return byteBuffer;
    }
    

上記のコード例は、TensorFlow Lite が公式に提供するオブジェクト検出モデルに適用できます. 具体的なモデルの使用法と入出力 Tensor は、実際の状況に応じて調整できます.

おすすめ

転載: blog.csdn.net/weixin_44008788/article/details/130286827