シンプルなデモでは、tensorflow2.x のカスタム keras モデルを tflite 形式に変換し、Android にデプロイします。

環境:

ウィンドウズ10

CUDA 10.1

非表示 7.6.4

テンソルフローGPU 2.1

アンドロイドスタジオ3.6

基本的には現時点では比較的新しい環境です。

tensorflow 2.0 以降、私は特に keras を使用してモデルをカスタマイズするのが好きなので、デプロイメント用にモデルを保存する方法を見つけたいと思っています。pb形式モデルのどのメソッドが保存されるのか、それとも全てのメソッドが保存されるのかがよく分からないので、当面はpbモデル展開の利用は検討しません。tflite ははるかに単純で、call メソッドの下にプロセスを保存するだけです。変換プロセスは非常に簡単で、トピックに直接進むだけです。

 

まず Python 部分:

1. 複数の入力と複数の出力を備えた単純なモデルをカスタマイズします。

class test_model2(tf.keras.Model):
    def __init__(self, name="test_model2"):
        super(test_model2, self).__init__(name=name)
        self.conv1 = tf.keras.layers.Conv2D(filters=1, kernel_size=2, kernel_initializer=tf.ones, name=self.name + "/conv1")

    @tf.function
    def call(self, inputs):
        output1 = self.conv1(inputs[0])
        output1 = tf.squeeze(output1)
        output1 = tf.reshape(output1, (1,))
        output2 = self.conv1(inputs[1])
        output2 = tf.squeeze(output2)
        output2 = tf.reshape(output2, (1,))
        return output1, output2
model = test_model2()
test_input1 = tf.ones((1, 2, 2, 1))
test_input2 = tf.zeros((1, 2, 2, 1))
input_list = [test_input1, test_input2]
test_output1, test_output2 = model(input_list)
print(test_output1)
print(test_output2)

実行後、印刷されます

tf.Tensor([4.], shape=(1,), dtype=float32)
tf.Tensor([0.], shape=(1,), dtype=float32)

これはかなりシンプルなモデルです。

2. 以下は tflite 形式への変換モデルです。

fit() 関数を使用する代わりにカスタム トレーニング ループを使用する場合は、モデルの入力サイズを手動で設定する必要があります。この例では、カスタマイズされたトレーニング プロセスとみなし、入力サイズを 1 回設定するだけです。主に形状が一致する場合、入力値はランダムにすることができます。呼び出し関数はデフォルトで変換されるため、トレーニングとテストが関数を通過しない場合は、トレーニング関数に呼び出しと同じ名前を付けないことをお勧めします。

test_input1 = tf.ones((1, 2, 2, 1))
test_input2 = tf.zeros((1, 2, 2, 1))
input_list = [test_input1, test_input2]
model._set_inputs(input_list)

最後にモデルを変換して保存します。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("./save/converted_model.tflite", "wb").write(tflite_model)

これで、保存フォルダーに保存された tflite ファイルが見つかります。

 

その後に AndroidStudio 部分が続きます。

1. 新しいプロジェクトを作成する

2. build.gradle を変更し、次の内容を追加します

android {
    ...
    defaultConfig {
        ...
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
        ...
    }
    aaptOptions {
        noCompress "tflite"
    }
    ...
}
dependencies {
    ...
    implementation 'org.tensorflow:tensorflow-lite:2.1.0'
}

3.tfliteファイルを入れて読み込む

app\src\main にアセット フォルダーを作成し、その中に Converted_model.tflite を置きます。このファイル パスは一意ではなく、読み取りの便宜のためにアセットに配置されます。

読み取りコードは次のようになります。

String MODEL_FILE = "converted_model.tflite";
Interpreter tfLite = null;
try {
    tfLite = new Interpreter(loadModelFile(getAssets(), MODEL_FILE));
}catch(IOException e){
    e.printStackTrace();
}

loadModelFile 関数は次のとおりです。

MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
            throws IOException {
        AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
        FileInputStream inputStream = new                 
        FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

4. 入力テンソルの作成

tflite の入力は ByteBuffer 形式です。

int net_input_sz = 2;

ByteBuffer inputData1;
inputData1 = ByteBuffer.allocateDirect(net_input_sz * net_input_sz * 4);//4表示一个浮点占4byte
inputData1.order(ByteOrder.nativeOrder());
inputData1.rewind();
inputData1.putFloat(1.0f);
inputData1.putFloat(1.0f);
inputData1.putFloat(1.0f);
inputData1.putFloat(1.0f);

ByteBuffer inputData2;
inputData2 = ByteBuffer.allocateDirect(net_input_sz * net_input_sz * 4);//4表示一个浮点占4byte
inputData2.order(ByteOrder.nativeOrder());
inputData2.rewind();
inputData2.putFloat(0.0f);
inputData2.putFloat(0.0f);
inputData2.putFloat(0.0f);
inputData2.putFloat(0.0f);

Object[] inputArray = {inputData1, inputData2};
 

ByteBuffer によって開かれるスペースのサイズは、ネットワーク入力サイズの合計に、精度によって占有されるバイト数を乗算したものです。たとえば、この例で設定されている入力形状は 1x2x2x1 であるため、4 となり、浮動小数点数は 4 バイトを占有します, つまり、サイズは4x4です。

5. 出力テンソルを構築する

float[] output1, output2;
output1 = new float[1];
output2 = new float[1];
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, output1);
outputMap.put(1, output2);

この例では、ネットワークの出力形状は [1,] であるため、ここでサイズ 1 の浮動小数点配列を直接構築できます。[2,3,4] などの 2 次元または 3 次元の出力がある場合は、多次元配列 new float[2][3][4] を構築する必要があります。ただし、私は多次元配列を構築する方法が好きではありません。処理のためにネイティブ層に渡すのが不便なので、通常は出力を reshape(output_tensor,[-1]) します。

6. 推論を実行して出力を印刷する

tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Log.e("1111","output1:" + output1[0]);
Log.e("1111","output2:" + output2[0]);

一文で解決できます。次の出力を取得できます

2020-02-25 16:27:55.569 22585-22585/com.stars.tflite_test E/1111: output1:4.0
2020-02-25 16:27:55.569 22585-22585/com.stars.tflite_test E/1111: output2:0.0

 

この時点で、モデル全体の変換とデプロイが完了します。

ねえ、それは非常に単純ではありませんか?

 

Java パーツのパッケージ名を添付します。どれが必要なのか忘れてしまったので、すべて追加します。

package com.stars.tflite_test;

import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.os.Bundle;

import com.google.android.material.floatingactionbutton.FloatingActionButton;
import com.google.android.material.snackbar.Snackbar;

import androidx.appcompat.app.AppCompatActivity;
import androidx.appcompat.widget.Toolbar;

import android.util.Log;
import android.view.View;
import android.view.Menu;
import android.view.MenuItem;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.Map;

遭遇した落とし穴:

1. 公式 BN レイヤーを tflite に移行できなかったので(理由はわかりませんが)、公式 BN レイヤーからあまり機能のない BN レイヤーに魔法のように改造して、tflite に変換することに成功しました。

2. 変換した tflite モデルを GPU 上で実行したときに得られる結果が正しくありません。解決方法がわかりません。github に質問をしましたが、まだ回答がありません。問題のアドレスはhttps://github.com/tensorflow/tensorflow/issues/38825です。それについて何か知っている場合は、教えてください。GPU で正しい結果を得るためにセッション モードで書かれたモデルを試しましたが、 2. x モードで書かれたモデルは動作しません。

 

おすすめ

転載: blog.csdn.net/qq_19313495/article/details/104498442