公式の推奨事項| TensorFlow Liteプラグインを使用してFlutterにテキスト分類を実装する

TensorFlowモデルをFlutterアプリケーションに統合するシンプルで効率的かつ柔軟な方法が必要な場合は、本日紹介した新しいプラグインtflite_flutterをお見逃しなくこのプラグインの開発者は、Google Summer of Code(GSoC)のインターンであるAmish Gargです。

tflite_flutterプラグインのコア機能:

  • プラグインはTFLite JavaやSwift APIと同様のDart APIを提供するため、その柔軟性とこれらのプラットフォームへの影響はまったく同じです。

  • プラグインはdart:ffiを介してTensorFlow Lite C APIに直接バインドされているため、他のプラットフォーム統合方法よりも効率的です。

  • 特定のプラットフォームコードを記述する必要はありません。

  • NNAPIを介してアクセラレーションサポートを提供し、AndroidでGPUデリゲートを使用し、iOSでメタルデリゲートを使用します。

このペーパーでは、tflite_flutterプラグインを体験するために、tflite_flutterを使用してテキスト分類Flutterアプリケーション作成します。まず、Flutterから新しいプロジェクトをtext_classification_app開始します。 

初期構成

LinuxおよびMacユーザー

されるinstall.shアプリケーションのルートディレクトリにコピーして、ルートディレクトリに実行しsh install.sh、この場合にはカタログです、text_classification_app/

Windowsユーザー

install.batファイルをアプリケーションのルートディレクトリにコピーし、バッチファイルをルートディレクトリinstall.bat(この例ではディレクトリ)で実行しますtext_classification_app/

GitHubリポジトリのリリースから 最新のバイナリリソースを自動的にダウンロードし、指定したディレクトリに配置します。

初期構成の詳細については、READMEファイルをクリックして表示してください

  • tflite_flutterのGitHubウェアハウスアドレス
    https://github.com/am15h/tflite_flutter_plugin

プラグインを入手する

pubspec.yaml追加tflite_flutter: ^<latest_version>

  • 最新バージョンについては、プラグインのリリースアドレスを参照して
    ください。https://pub.flutter-io.cn/packages/tflite_flutter

モデルをダウンロード

TensorFlowトレーニングモデルをモバイルで実行するには、この.tflite 形式を使用する必要があります。TensorFlowモデルをトレーニング.tflite形式に変換する方法を知る必要がある場合は公式ガイドを参照してください

ここでは、TensorFlow公式ウェブサイトで事前トレーニングされたテキスト分類モデルを使用します。モデルダウンロードリンクについては、このセクションのリンクアノテーションテキストをご覧ください。

事前学習済みモデルは、現在の段落の感情がポジティブかネガティブかを予測できます。Mass et alのLarge Movie Review Dataset v1.0データセットに基づいてトレーニングされてい  ます。データセットは、IMDBの映画レビューに基づいた正または負のタグで構成されます。詳細については、を参照してください

text_classification.tfliteそしてtext_classification_vocab.txtファイルがtext_classification_app /資産/ディレクトリにコピーされます。

pubspec.yaml追加ファイルassets/

assets:    
  - assets/

すべての準備が整ったので、コードの作成を開始できます。????

  • モデルコンバーター(コンバーター)Python APIガイド
    https://tensorflow.google.cn/lite/convert/python_api

  • 事前トレーニング済みのテキスト分類モデルのダウンロード(text_classification.tflite)
    https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification.tflite

  • データセットファイルのダウンロード(text_classification_vocab.txt)
    https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification_vocab.txt

分類子を実装する

前処理

記載されているテキスト分類モデルのページとして次の手順に従って、モデルを使用して段落を分類できます。

  1. 段落テキストをセグメント化し、事前定義された語彙セットを使用して、それを語彙IDのセットに変換します。

  2. 生成された単語IDをTensorFlow Liteモデルに入力します。

  3. 現在の段落が正であるか負であるかの確率値をモデルの出力から取得します。

まずtext_classification_vocab.txt、語彙を使用して、元の単語文字列にメソッドを記述します。

lib/フォルダに新しいファイル作成し  ますclassifier.dart

ここではtext_classification_vocab.txt、辞書に読み込まれるコードを最初に記述します

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';
  
  Map<String, int> _dict;

  Classifier() {
    _loadDictionary();
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }
  
}

△辞書を読み込む

次に、元の文字列をセグメント化する関数を記述しましょう。

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';

  // 单句的最大长度
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;
  
  List<List<double>> tokenizeInputText(String text) {
    
    // 使用空格进行分词
    final toks = text.split(' ');
    
    // 创建一个列表,它的长度等于 _sentenceLen,并且使用 <pad> 的对应的字典值来填充
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // 对于句子里的每个单词,在映射里找到相应的索引值
    for (var tok in toks) {
      if (index > _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // 按照我们的解释器输入 tensor 所需的格式 [1, 256] 返回 List<List<double>>
    return [vec];
  }
}


△単語分割コード

分析にはtflite_flutterを使用します

これがこの記事の主要部分です。ここでは、tflite_flutterプラグインの目的について説明します。

ここでの分析とは、デバイスの入力データに基づくTensorFlow Liteモデルの処理を指します。TensorFlow Liteのモデル分析を使用するには、我々は必要な通訳を、それを実行するために、より多くを学ぶために

インタープリターの作成、モデルの読み込み

tflite_flutterは、リソースから直接インタープリターを作成する方法を提供します。

static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})

モデルはassets/フォルダー内にあるため、上記の方法を使用してパーサーを作成する必要があります。InterpreterOptionsの手順について、こちらを参照してください

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // 模型文件的名称
  final _modelFile = 'text_classification.tflite';

  // TensorFlow Lite 解释器对象
  Interpreter _interpreter;

  Classifier() {
    // 当分类器初始化以后加载模型
    _loadModel();
  }

  void _loadModel() async {
    
    // 使用 Interpreter.fromAsset 创建解释器
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

}

△通訳コード作成

あなたがモデル入れたくない場合はassets/、ディレクトリを、tflite_flutterも、通訳を作成するために、植物のコンストラクタを提供してより多くの情報のため

分析を始めましょう!

次の方法で分析を開始します。

void run(Object input, Object output);

ここでのメソッドはJava APIのメソッドと同じであることに注意してください。

Object inputまたObject output、同じリストの入力テンソルと出力テンソルの次元でなければなりません。

入力テンソルと出力テンソルの次元を表示するには、次のコードを使用できます。

_interpreter.allocateTensors();
// 打印 input tensor 列表
print(_interpreter.getInputTensors());
// 打印 output tensor 列表
print(_interpreter.getOutputTensors());

この例のtext_classificationモデルの出力は次のとおりです。

InputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data:  1024]
OutputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data:  8]

次に、分類メソッドを実装します。これは、正の場合は1の値を返し、負の場合は0の戻り値を返します。

int classify(String rawText) {
    
    //  tokenizeInputText 返回形状为 [1, 256] 的 List<List<double>>
    List<List<double>> input = tokenizeInputText(rawText);
   
    // [1,2] 形状的输出
    var output = List<double>(2).reshape([1, 2]);
    
    // run 方法会运行分析并且存储输出的值
    _interpreter.run(input, output);

    var result = 0;
    // 如果输出中第一个元素的值比第二个大,那么句子就是消极的
    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

△分析コード

使用される一部の拡張機能は、tflite_flutterのListの拡張機能ListShapeで定義されています。

// 将提供的列表进行矩阵变形,输入参数为元素总数并保持相等 
// 用法:List(400).reshape([2,10,20]) 
// 返回 List<dynamic>

List reshape(List<int> shape)
// 返回列表的形状
List<int> get shape
// 返回列表任意形状的元素数量
int get computeNumElements

最終classifier.dartは次のようになります。

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // 模型文件的名称
  final _modelFile = 'text_classification.tflite';
  final _vocabFile = 'text_classification_vocab.txt';

  // 语句的最大长度
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;

  // TensorFlow Lite 解释器对象
  Interpreter _interpreter;

  Classifier() {
    // 当分类器初始化的时候加载模型
    _loadModel();
    _loadDictionary();
  }

  void _loadModel() async {
    // 使用 Intepreter.fromAsset 创建解析器
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }

  int classify(String rawText) {
    // tokenizeInputText  返回形状为 [1, 256] 的 List<List<double>>
    List<List<double>> input = tokenizeInputText(rawText);

    //输出形状为 [1, 2] 的矩阵
    var output = List<double>(2).reshape([1, 2]);

    // run 方法会运行分析并且将结果存储在 output 中。
    _interpreter.run(input, output);

    var result = 0;
    // 如果第一个元素的输出比第二个大,那么当前语句是消极的
    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

  List<List<double>> tokenizeInputText(String text) {
    // 用空格分词
    final toks = text.split(' ');

    // 创建一个列表,它的长度等于 _sentenceLen,并且使用 <pad> 对应的字典值来填充
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // 对于句子中的每个单词,在 dict 中找到相应的 index 值
    for (var tok in toks) {
      if (index > _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // 按照我们的解释器输入 tensor 所需的形状 [1,256] 返回 List<List<double>>
    return [vec];
  }
}

これで、好みに応じてUIコードを実装でき、分類子の使用法は比較的簡単です。

// 创建 Classifier 对象
Classifer _classifier = Classifier();
// 将目标语句作为参数,调用 classify 方法
_classifier.classify("I liked the movie");
// 返回 1 (积极的)
_classifier.classify("I didn't liked the movie");
// 返回 0 (消极的)

ここで完全なコードを確認してください:UIを備えたテキスト分類のサンプルアプリ
https://github.com/am15h/tflite_flutter_plugin/tree/master/example/

△テキスト分類の適用例

tflite_flutterプラグインの詳細については、GitHubリポジトリのam15h / tflite_flutter_pluginにアクセスしてください

質疑応答

Q:tflite_flutterと  tflite v1.0.5 違いは何ですか?

tflite v1.0.5画像の分類、オブジェクトの検出など、特定のアプリケーションシナリオに高度な機能を提供することに重点を置いています。新しいtflite_flutterは、Java APIと同じ機能と柔軟性を提供し、任意のtfliteモデルで使用でき、デリゲートもサポートします。

dart:ffi(dart↔️(ffi)↔️C)を使用しているため、tflite_flutterは非常に高速です(低遅延)。そして、tfliteはプラットフォーム統合(dart↔️platform-channel↔️(Java / Swift)↔️JNI↔️C)を使用します。

Q:tflite_flutterを使用して画像分類アプリケーションを作成するにはどうすればよいですか?TensorFlow Lite Android Support Libraryに似た依存パッケージはありますか?

アップデート(07/01/2020):TFLite Flutter Helper開発ライブラリがリリースされました。

TensorFlow Lite Flutterヘルパーライブラリは、入出力TFLiteモデルを処理および制御するための使いやすいアーキテクチャを提供します。APIの設計とドキュメントは、TensorFlow Lite Androidサポートライブラリと同じです。詳細については、倉庫のTFLite Flutter Helper開発ライブラリGItHubアドレス参照してください

上記はこの記事の内容全体です。tflite_flutterプラグインに関するフィードバックを歓迎しますここでは、バグまたは機能のリクエストを報告してください

  • TFLite Flutter Helper開発ライブラリGItHubウェアハウスアドレス
    https://github.com/am15h/tflite_flutter_helper

  • tflite_flutterプラグインへの提案とフィードバック
    https://github.com/am15h/tflite_flutter_plugin/issues

ご清聴ありがとうございました。FlutterチームのMichael Thomsenに感謝します。

ありがとう

  • 翻訳者:Yuan、Gu Chuang字幕グループ

  • レビュアー:Xinlei、Lynn Wang、Alex、CFUGコミュニティ。

この記事は、TensorFlowオンラインディスカッションエリア(discuss.tf.wiki)、101.dev、Flutter中国語ドキュメント(flutter.cn)、およびFlutterコミュニティのオンラインチャネルで共同で公開されています。

TensorFlowとGoogle AIについて詳しく知る必要がある場合は、Googleの公式パブリックアカウントTensorFlow(TensorFlow_official)とGoogle Developers(Google_Developers)に従ってください

  • 記事のリンクを読むには、クリックして元のテキストまたは以下のURLを読んで
    ください。https://flutter.cn/community/tutorials/text-classification-using-tensorflow-lite-plugin-for-flutter

おすすめ

転載: blog.csdn.net/weixin_43459071/article/details/108633632