Official recommendation | Using the TensorFlow Lite plugin to implement text classification in Flutter

If you want a simple, efficient and flexible way to integrate TensorFlow models into Flutter apps, don't miss tflite_flutter, the new plugin we are introducing today. The plugin was developed by Amish Garg, a Google Summer of Code (GSoC) intern.

Core features of the tflite_flutter plugin:

  • The plugin provides a Dart API similar to the TFLite Java and Swift APIs, so it offers the same flexibility as those APIs;

  • The plugin binds directly to the TensorFlow Lite C API through dart:ffi, so it is more efficient than integrations that go through platform channels;

  • No platform-specific code is required;

  • Hardware acceleration is supported through delegates: NNAPI and the GPU delegate on Android, and the Metal delegate on iOS.

In this article, we will build a text classification Flutter app with tflite_flutter to show you how the plugin works. First, create a new Flutter project named text_classification_app.

Initial configuration

Linux and Mac users

Copy install.sh to the root directory of your app and run sh install.sh from that directory; in this example, the root directory is text_classification_app/.

Windows users

Copy install.bat to the root directory of your app and run install.bat from that directory; in this example, the root directory is text_classification_app/.

The script automatically downloads the latest binaries from the releases of the GitHub repository and places them in the required directories.

See the README for more information about the initial configuration.

  • tflite_flutter GitHub repository
    https://github.com/am15h/tflite_flutter_plugin

Get the plugin

Add tflite_flutter to pubspec.yaml:

dependencies:
  tflite_flutter: ^<latest_version>

  • For the latest version, see the plugin's page on pub:
    https://pub.flutter-io.cn/packages/tflite_flutter

Download model

To run a trained TensorFlow model on mobile, the model must be in the .tflite format. To learn how to convert a trained TensorFlow model to the .tflite format, see the official guide.

Here we use the pre-trained text classification model from the TensorFlow website. The download links are listed at the end of this section.

The pre-trained model predicts whether the sentiment of a paragraph is positive or negative. It was trained on the Large Movie Review Dataset v1.0 from Maas et al., which consists of IMDB movie reviews labeled as positive or negative. See the model page for more information.

Copy the text_classification.tflite and text_classification_vocab.txt files to the text_classification_app/assets/ directory.

Add the assets/ directory to pubspec.yaml:

flutter:
  assets:
    - assets/

Now that everything is ready, we can start writing code.

  • TensorFlow Lite converter Python API guide
    https://tensorflow.google.cn/lite/convert/python_api

  • Pre-trained text classification model download (text_classification.tflite)
    https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification.tflite

  • Vocabulary file download (text_classification_vocab.txt)
    https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification_vocab.txt

Implement the classifier

Preprocessing

As noted on the text classification model page, you can classify a paragraph with the model using the following steps:

  1. Tokenize the paragraph text, then convert the tokens into word IDs using a predefined vocabulary;

  2. Feed the generated word IDs into the TensorFlow Lite model;

  3. Read the probabilities that the paragraph is positive or negative from the model's output.
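To make the three steps concrete, here is a minimal pure-Dart sketch that runs without Flutter. The vocabulary, the toyModel function and its rule are hypothetical stand-ins; in the real app, step 2 is performed by the TensorFlow Lite interpreter.

```dart
/// Step 1: split the text on spaces and map each word to its vocabulary ID,
/// falling back to the ID of the unknown-word token.
List<int> toIds(String text, Map<String, int> vocab, String unk) {
  return text.split(' ').map((word) => vocab[word] ?? vocab[unk]!).toList();
}

/// Step 2 (hypothetical stand-in for the TFLite model): returns
/// [negativeProbability, positiveProbability].
List<double> toyModel(List<int> ids) {
  const positiveId = 5; // hypothetical ID of the word "good"
  final positive = ids.contains(positiveId) ? 0.9 : 0.2;
  return [1.0 - positive, positive];
}

/// Step 3: pick the label with the larger probability (1 = positive, 0 = negative).
int decide(List<double> probs) => probs[1] >= probs[0] ? 1 : 0;
```

For example, with the toy vocabulary {'<UNKNOWN>': 2, 'a': 4, 'good': 5, 'movie': 6}, decide(toyModel(toIds('a good movie', vocab, '<UNKNOWN>'))) yields 1.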

We first write a method that tokenizes the raw string, using text_classification_vocab.txt as the vocabulary.

Create a new file classifier.dart in the lib/ folder.

First, write the code that loads text_classification_vocab.txt into a dictionary.

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';
  
  Map<String, int> _dict;

  Classifier() {
    _loadDictionary();
  }

  Future<void> _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      // Each line has the form "<word> <id>"; skip blank lines such as a trailing newline
      final line = vocabList[i].trim();
      if (line.isEmpty) continue;
      var entry = line.split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }
  
}

△ Loading the dictionary
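Because rootBundle only works inside a Flutter app, one option is to move the parsing into a plain function that can be tested on its own; _loadDictionary would then call it with the loaded string. This is a minimal sketch with a hypothetical function name (parseVocab), assuming each line of the vocabulary file has the form "<word> <id>", as the code above does. Calling trim() also removes a trailing \r left by Windows line endings.

```dart
/// Parses vocabulary text where each line is "<word> <id>", e.g. "<PAD> 0".
/// Blank lines are skipped, so a trailing newline does not cause a RangeError.
/// Assumption: words contain no spaces and IDs are integers.
Map<String, int> parseVocab(String contents) {
  final dict = <String, int>{};
  for (final line in contents.split('\n')) {
    final trimmed = line.trim(); // also strips '\r' from CRLF files
    if (trimmed.isEmpty) continue;
    final parts = trimmed.split(' ');
    dict[parts[0]] = int.parse(parts[1]);
  }
  return dict;
}
```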

Now let's write a function that tokenizes the raw string.

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';

  // Maximum length of a single sentence
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;
  
  List<List<double>> tokenizeInputText(String text) {
    
    // Split into tokens on spaces
    final toks = text.split(' ');
    
    // Create a list of length _sentenceLen, filled with the dictionary value of <PAD>
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // For each word in the sentence, look up its index in the map
    for (var tok in toks) {
      // Stop once all _sentenceLen slots are filled (>= avoids a RangeError)
      if (index >= _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // Return a List<List<double>> in the [1, 256] shape required by the interpreter's input tensor
    return [vec];
  }
}


△ Tokenization code
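The same logic can be written as a standalone function (hypothetical name tokenize) that takes the dictionary and sizes as parameters, which makes the padding and truncation easy to check outside Flutter. This sketch assumes, like the code above, that <PAD> and <UNKNOWN> are present in the dictionary.

```dart
/// Produces one row of [sentenceLen] IDs: an optional <START> ID, then one ID
/// per space-separated word (<UNKNOWN> for unseen words), padded with <PAD>
/// and truncated once all slots are used.
List<List<double>> tokenize(
  String text,
  Map<String, int> dict, {
  int sentenceLen = 256,
  String start = '<START>',
  String pad = '<PAD>',
  String unk = '<UNKNOWN>',
}) {
  final vec = List<double>.filled(sentenceLen, dict[pad]!.toDouble());
  var index = 0;
  if (dict.containsKey(start)) {
    vec[index++] = dict[start]!.toDouble();
  }
  for (final tok in text.split(' ')) {
    // `>=` stops before writing past the end of the list.
    if (index >= sentenceLen) break;
    vec[index++] = (dict[tok] ?? dict[unk]!).toDouble();
  }
  return [vec]; // shape [1, sentenceLen]
}
```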

Running inference with tflite_flutter

This is the core of the article: here we look at how to use the tflite_flutter plugin.

Inference here means running a TensorFlow Lite model on the device to make predictions from input data. To run inference with a TensorFlow Lite model, we need an interpreter; see the TensorFlow Lite documentation to learn more.

Creating the interpreter and loading the model

tflite_flutter provides a method that creates an interpreter directly from an asset:

static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})

Since our model is in the assets/ folder, we create the interpreter with the method above. For details on InterpreterOptions, see the plugin documentation.

import 'package:flutter/services.dart';

// Import tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // Name of the model file
  final _modelFile = 'text_classification.tflite';

  // TensorFlow Lite interpreter object
  Interpreter _interpreter;

  Classifier() {
    // Load the model once the classifier is created
    _loadModel();
  }

  Future<void> _loadModel() async {
    
    // Create the interpreter with Interpreter.fromAsset
    // (the plugin version used here looks the file up under assets/)
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

}

△ Code for creating the interpreter

If you do not want to put the model in the assets/ directory, tflite_flutter also provides factory constructors for creating the interpreter; see the plugin's README for more information.

Let's run inference!

Run inference with the following method:

void run(Object input, Object output);

Note that this method is the same as the one in the Java API.

Object input and Object output must be lists with the same dimensions as the input tensor and the output tensor, respectively.

To view the dimensions of the input and output tensors, use the following code:
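One way to check this requirement before calling run is to compute the shape of a nested Dart list. The sketch below uses hypothetical helpers (shapeOf, sameShape) that are not part of tflite_flutter; it assumes the lists are rectangular, so it only inspects the first element at each level.

```dart
/// Returns the shape of a nested rectangular list, e.g. [[0.0, 0.0]] -> [1, 2].
List<int> shapeOf(Object? value) {
  final shape = <int>[];
  var current = value;
  while (current is List) {
    shape.add(current.length);
    if (current.isEmpty) break;
    current = current.first; // assumes all sibling lists have the same length
  }
  return shape;
}

/// Element-wise comparison of two shapes.
bool sameShape(List<int> a, List<int> b) {
  if (a.length != b.length) return false;
  for (var i = 0; i < a.length; i++) {
    if (a[i] != b[i]) return false;
  }
  return true;
}
```

For this model, an input built as [vec] with 256 values gives [1, 256], while passing vec alone gives [256] and would not match the input tensor.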

_interpreter.allocateTensors();
// Print the list of input tensors
print(_interpreter.getInputTensors());
// Print the list of output tensors
print(_interpreter.getOutputTensors());

For the text_classification model in this example, the output is as follows:

InputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data:  1024]
OutputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data:  8]

Now we implement the classify method, which returns 1 for positive and 0 for negative.

  int classify(String rawText) {
    
    // tokenizeInputText returns a List<List<double>> of shape [1, 256]
    List<List<double>> input = tokenizeInputText(rawText);
   
    // Output of shape [1, 2]
    var output = List<double>.filled(2, 0.0).reshape([1, 2]);
    
    // run() performs inference and stores the output values
    _interpreter.run(input, output);

    var result = 0;
    // If the first output element is larger than the second, the sentence is negative
    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

△ Inference code

The reshape call above, along with a few related helpers, comes from the extension ListShape on List defined in tflite_flutter:
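If you also want the probability behind a decision, the [1, 2] output can be decoded into a label plus a confidence. This is a minimal sketch with hypothetical names (Prediction, decode), assuming index 0 holds the negative score and index 1 the positive score, as in classify(). The plugin's output list is dynamic, so you may need to convert it to List<List<double>> first.

```dart
/// A label (1 = positive, 0 = negative) and the probability the model gave it.
class Prediction {
  final int label;
  final double confidence;
  const Prediction(this.label, this.confidence);
}

/// Decodes a [1, 2] output [[negative, positive]] using the same rule as
/// classify(): negative only when output[0][0] is strictly larger.
Prediction decode(List<List<double>> output) {
  final negative = output[0][0];
  final positive = output[0][1];
  return negative > positive
      ? Prediction(0, negative)
      : Prediction(1, positive);
}
```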

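// Reshapes the list into the given shape; the total number of elements must stay the same
// Usage: List(400).reshape([2,10,20])
// Returns List<dynamic>

List reshape(List<int> shape)
// Returns the shape of the list
List<int> get shape
// Returns the number of elements in a list of any shape
int get computeNumElements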
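To show what a reshape does to the data, here is an illustrative pure-Dart re-implementation with hypothetical names (reshapeFlat, numElements). It is not the plugin's code, only the same idea: split a flat list into nested lists whose dimensions multiply to the element count.

```dart
/// Product of all dimensions, e.g. [2, 10, 20] -> 400.
int numElements(List<int> shape) => shape.fold(1, (a, b) => a * b);

/// Splits [flat] into nested lists with the given [shape].
List<dynamic> reshapeFlat(List<dynamic> flat, List<int> shape) {
  if (shape.isEmpty) {
    throw ArgumentError('shape must not be empty');
  }
  if (numElements(shape) != flat.length) {
    throw ArgumentError('Cannot reshape ${flat.length} elements into $shape');
  }
  if (shape.length == 1) return List<dynamic>.of(flat);
  // Number of flat elements in each sub-list along the first dimension.
  final chunk = numElements(shape.sublist(1));
  return [
    for (var i = 0; i < shape.first; i++)
      reshapeFlat(flat.sublist(i * chunk, (i + 1) * chunk), shape.sublist(1)),
  ];
}
```

For example, reshapeFlat([1, 2, 3, 4, 5, 6], [2, 3]) gives [[1, 2, 3], [4, 5, 6]], and a flat list of two zeros reshaped to [1, 2] matches the output tensor of this model.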

The final classifier.dart should look like this:

import 'package:flutter/services.dart';

// Import tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // Names of the model and vocabulary files
  final _modelFile = 'text_classification.tflite';
  final _vocabFile = 'text_classification_vocab.txt';

  // Maximum length of a sentence
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;

  // TensorFlow Lite interpreter object
  Interpreter _interpreter;

  Classifier() {
    // Load the model and dictionary when the classifier is created.
    // Both loads are asynchronous, so call classify() only after they finish.
    _loadModel();
    _loadDictionary();
  }

  Future<void> _loadModel() async {
    // Create the interpreter with Interpreter.fromAsset
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

  Future<void> _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      // Each line has the form "<word> <id>"; skip blank lines such as a trailing newline
      final line = vocabList[i].trim();
      if (line.isEmpty) continue;
      var entry = line.split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }

  int classify(String rawText) {
    // tokenizeInputText returns a List<List<double>> of shape [1, 256]
    List<List<double>> input = tokenizeInputText(rawText);

    // Output matrix of shape [1, 2]
    var output = List<double>.filled(2, 0.0).reshape([1, 2]);

    // run() performs inference and stores the result in output
    _interpreter.run(input, output);

    var result = 0;
    // If the first output element is larger than the second, the sentence is negative
    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

  List<List<double>> tokenizeInputText(String text) {
    // Split into tokens on spaces
    final toks = text.split(' ');

    // Create a list of length _sentenceLen, filled with the dictionary value of <PAD>
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // For each word in the sentence, look up its index in dict
    for (var tok in toks) {
      // Stop once all _sentenceLen slots are filled (>= avoids a RangeError)
      if (index >= _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // Return a List<List<double>> in the [1, 256] shape required by the interpreter's input tensor
    return [vec];
  }
}

Now you can implement the UI however you like; using the classifier is simple. Because the model and vocabulary load asynchronously, call classify only after loading has finished.

// Create a Classifier object
Classifier _classifier = Classifier();
// Call classify with the target sentence as the argument
_classifier.classify("I liked the movie");
// returns 1 (positive)
_classifier.classify("I didn't liked the movie");
// returns 0 (negative)
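Because the constructor starts loading asynchronously and returns immediately, calling classify right after Classifier() can fail while _interpreter or _dict is still null. One common alternative is an async factory that awaits the loading before handing out an instance. The sketch below uses hypothetical names (ReadyClassifier, create, idFor) and replaces rootBundle.loadString and Interpreter.fromAsset with a stub so it runs in plain Dart.

```dart
/// Async-factory pattern: callers only get an instance after loading is done.
class ReadyClassifier {
  final Map<String, int> _dict;

  ReadyClassifier._(this._dict);

  /// In a real app, this would also await Interpreter.fromAsset(...).
  static Future<ReadyClassifier> create() async {
    final dict = await _loadDictionaryStub();
    return ReadyClassifier._(dict);
  }

  /// Hypothetical stub standing in for rootBundle.loadString + parsing.
  static Future<Map<String, int>> _loadDictionaryStub() async {
    await Future<void>.delayed(const Duration(milliseconds: 10));
    return {'<PAD>': 0, '<START>': 1, '<UNKNOWN>': 2};
  }

  /// Safe to call: the dictionary is guaranteed to be loaded.
  int idFor(String word) => _dict[word] ?? _dict['<UNKNOWN>']!;
}
```

In a widget, you would typically call create() once (for example from initState) and keep the resulting instance once the Future completes.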

See the complete code here: Text Classification Example app with UI
https://github.com/am15h/tflite_flutter_plugin/tree/master/example/

△ Text classification example app

To learn more about the tflite_flutter plugin, visit the GitHub repo: am15h/tflite_flutter_plugin.

Q & A

Q: What are the differences between tflite_flutter and tflite v1.0.5?

tflite v1.0.5 focuses on providing high-level features for specific use cases, such as image classification and object detection. The new tflite_flutter provides the same features and flexibility as the Java API, works with any tflite model, and also supports delegates.

Because it uses dart:ffi (dart ↔️ (ffi) ↔️ C), tflite_flutter is very fast (low latency), whereas tflite uses platform integration (dart ↔️ platform-channel ↔️ (Java/Swift) ↔️ JNI ↔️ C).

Q: How can I use tflite_flutter to build an image classification app? Is there a package similar to the TensorFlow Lite Android Support Library?

Update (07/01/2020): The TFLite Flutter Helper library has been released.

The TensorFlow Lite Flutter Helper Library provides an easy-to-use architecture for processing and manipulating the inputs and outputs of TFLite models. Its API design and documentation follow the TensorFlow Lite Android Support Library. For more information, see the TFLite Flutter Helper library's GitHub repository.

That's all for this article. We welcome feedback on the tflite_flutter plugin; please report bugs or submit feature requests on the plugin's issue tracker.

  • TFLite Flutter Helper library GitHub repository
    https://github.com/am15h/tflite_flutter_helper

  • Suggestions and feedback to the tflite_flutter plugin
    https://github.com/am15h/tflite_flutter_plugin/issues

Thanks for reading, and thanks to Michael Thomsen of the Flutter team.

Thanks

  • Translator: Yuan, Gu Chuang Subtitle Group

  • Reviewer: Xinlei, Lynn Wang, Alex, CFUG community.

This article is jointly published on the TensorFlow online discussion area (discuss.tf.wiki), 101.dev, the Flutter Chinese documentation site (flutter.cn), and the online channels of the Flutter community.

To learn more about TensorFlow and Google AI, follow the official WeChat accounts TensorFlow (TensorFlow_official) and Google Developers (Google_Developers).

  • To follow the links in this article, read the original article at the URL below
    https://flutter.cn/community/tutorials/text-classification-using-tensorflow-lite-plugin-for-flutter


Origin blog.csdn.net/weixin_43459071/article/details/108633632