Edge Computing: A Detailed Guide to On-Device Training on Android with tflite

This article is an original post. It mainly draws on the official On-Device Training with TensorFlow Lite guide, the Fashion MNIST model personalization code, and the Muirush code. If you reprint it, be sure to include this link. Violators will be prosecuted!

This post is sure to be a hit!

on device training English official website

https://www.tensorflow.org/lite/examples/on_device_training/overview

on device training Chinese official website

https://www.tensorflow.org/lite/examples/on_device_training/overview?hl=zh-cn

At present there are few examples of incremental training on the device side with tflite. The official website currently offers only the clothing-recognition example linked above.

Google's official on-device training example: Fashion MNIST training and inference on Android

examples/lite/examples/model_personalization at master · tensorflow/examples · GitHub

Muirush linear regression prediction code

GitHub - Muirush/Model-training-with-Tensorflow-tfLite-and-android

The main problem is that the official code is an image-classification example in which training and inference are coupled with a lot of fairly complex, unrelated code. It is hard to pick out the core of tflite training and inference from it, the script that defines the signature functions is not shown, and the whole path from cloud training to on-device training and inference is never explained end to end. This makes it quite difficult for beginners.

Therefore, this article starts with cloud training, uses a DNN to fit the regression y = 2*x - 1, converts the model to a tflite model, and uses the new signature functions to run incremental training and inference on the device.

Software version: tensorflow 2.8 (on-device training is available from version 2.7 onward)

Android Studio:4.2.1
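The TensorFlow version requirement above can be checked before running any of the later steps. The following is a minimal sketch in plain Python; the helper names `parse_version` and `supports_on_device_training` are hypothetical, and in practice the string would come from `tf.__version__`.

```python
def parse_version(version: str) -> tuple:
    """Return the (major, minor) pair of a version string such as '2.8.0'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def supports_on_device_training(version: str) -> bool:
    # Assumption taken from the note above: training signatures need TF >= 2.7.
    return parse_version(version) >= (2, 7)


print(supports_on_device_training("2.8.0"))
print(supports_on_device_training("2.6.3"))
```

Comparing tuples rather than raw strings avoids the usual pitfall where "2.10" would sort before "2.7".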

Step 1: Cloud training and writing the signature functions

The difference from ordinary cloud training is that you write signature functions, which the converted tflite model exposes for inference and training. The code is as follows:

import tensorflow as tf

import numpy as np



class Model(tf.Module):

    def __init__(self):

        # Fully connected network: input dimension 1, two hidden layers of 10 units each, and a 1-unit output layer

        self.model = tf.keras.Sequential()

        self.model.add(tf.keras.layers.Dense(units=10, input_dim=1))

        self.model.add(tf.keras.layers.Dense(units=10, ))

        self.model.add(tf.keras.layers.Dense(units=1))

       

        self.model.compile(loss=tf.keras.losses.MSE,

                           optimizer=tf.keras.optimizers.SGD(learning_rate=1e-5))



    # Important: define the signature function here. Pay close attention to the input and output shapes; the inputs must be tensors

    @tf.function(input_signature=[

        tf.TensorSpec([1, 1], tf.float32),

        tf.TensorSpec([1], tf.float32),

    ])

   

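    # Note: x and y are only parameter names, but the Android code must use exactly the same names when training, otherwise an error is raised

    # Training step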

    def train(self, x, y):

        with tf.GradientTape() as tape:

            prediction = self.model(x)

            loss = self.model.loss(y, prediction)

        gradients = tape.gradient(loss, self.model.trainable_variables)

        self.model.optimizer.apply_gradients(

            zip(gradients, self.model.trainable_variables))

        result = {"loss": loss}

        return result



    # Inference step

    @tf.function(input_signature=[

        tf.TensorSpec([1], tf.float32),

    ])

    def infer(self, x):

        pred =self.model(x)

        return {

            "output": pred

        }



    # Save the new weights after training on Android

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])

    def save(self, checkpoint_path):

        tensor_names = [weight.name for weight in self.model.weights]

        tensors_to_save = [weight.read_value() for weight in self.model.weights]

        tf.raw_ops.Save(

            filename=checkpoint_path, tensor_names=tensor_names,

            data=tensors_to_save, name='save')

        return {

            "checkpoint_path": checkpoint_path

        }



    # Load the weights saved after training on Android, for inference on new data

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])

    def restore(self, checkpoint_path):

        restored_tensors = {}

        for var in self.model.weights:

            restored = tf.raw_ops.Restore(

                file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,

                name='restore')

            var.assign(restored)

            restored_tensors[var.name] = restored

        return restored_tensors


NUM_EPOCHS = 10000

BATCH_SIZE = 1

epochs = np.arange(1, NUM_EPOCHS + 1, 1)

losses = np.zeros([NUM_EPOCHS])

m = Model()


# Build the input data

x1 = np.array([[-1.0],[0.0],[1.0],[2.0],[3.0],[4.0], [5.0],[6.0],[7.0],[8.0],[9.0]], dtype = float)

y1 = np.array([-3.0,-1.0,1.0,3.0,5.0,7.0,9.0,11.0,13.0,15.0,17.0], dtype = float)


# Convert the arrays to tensors

features = tf.convert_to_tensor(x1, dtype=float)

labels = tf.convert_to_tensor(y1, dtype=float)


# Build batches

train_ds = tf.data.Dataset.from_tensor_slices((features, labels))

train_ds = train_ds.batch(BATCH_SIZE)



# Training loop

for i in range(NUM_EPOCHS):

    for x, y in train_ds:

        result = m.train(x, y)

    losses[i] = result['loss']

    if (i + 1) % 100 == 0:

        print('epochs=', i + 1, 'loss=', losses[i])

The training results are shown below:

epochs= 100 loss= 0.21976947784423828
epochs= 200 loss= 0.1585017591714859
epochs= 300 loss= 0.1464373618364334
epochs= 400 loss= 0.13536646962165833
epochs= 500 loss= 0.12510548532009125
epochs= 600 loss= 0.11560399830341339
epochs= 700 loss= 0.10680033266544342
epochs= 800 loss= 0.0986374095082283
……
epochs= 9500 loss= 4.569278098642826e-05
epochs= 9600 loss= 4.153713598498143e-05
epochs= 9700 loss= 3.7766891182400286e-05
epochs= 9800 loss= 3.464591281954199e-05
epochs= 9900 loss= 3.1359726563096046e-05
epochs= 10000 loss= 2.897171361837536e-05
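To make the update inside the train signature easier to follow, here is a plain-Python sketch of per-sample SGD on a squared-error loss. It is a simplification, not the article's network: it assumes a single weight and bias instead of three Dense layers, and it uses a learning rate of 0.005 (an assumption chosen so the toy model converges quickly) rather than the 1e-5 used above. The names `train_step` and `fit` are hypothetical.

```python
# Same data as x1 / y1 above: x from -1 to 9, y = 2*x - 1.
xs = [float(v) for v in range(-1, 10)]
ys = [2.0 * x - 1.0 for x in xs]


def train_step(w, b, x, y, lr):
    """One SGD step on a single sample; returns the updated (w, b) and the loss."""
    pred = w * x + b
    err = pred - y
    loss = err * err                 # squared error for this sample
    grad_w = 2.0 * err * x           # d(loss)/dw
    grad_b = 2.0 * err               # d(loss)/db
    return w - lr * grad_w, b - lr * grad_b, loss


def fit(epochs=1000, lr=0.005):
    """Loop over the data with batch size 1, as the article's training loop does."""
    w, b, loss = 0.0, 0.0, 0.0
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            w, b, loss = train_step(w, b, x, y, lr)
    return w, b, loss


w, b, last_loss = fit()
print(w, b, last_loss)
```

The gradient computation and the parameter update correspond to what `tf.GradientTape` and `apply_gradients` do in the train method; the fitted w and b should approach 2 and -1.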

Step 2: Save the model and convert it to a tflite model

# Save the model. Passing the signature functions here is the key step; otherwise they are missing from the tflite model generated afterwards.

SAVED_MODEL_DIR = "saved_model"


tf.saved_model.save(

    m,

    SAVED_MODEL_DIR,

    signatures={

        'train':

            m.train.get_concrete_function(),

        'infer':

            m.infer.get_concrete_function(),

        'save':

            m.save.get_concrete_function(),

        'restore':

            m.restore.get_concrete_function(),

})



# Convert the SavedModel to a tflite model

converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)

converter.target_spec.supported_ops = [

    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.

    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.

]

converter.experimental_enable_resource_variables = True

# Convert the cloud model to a tflite model; only a tflite model can be used for inference on Android

tflite_model = converter.convert()

open('linear_model_0921.tflite', 'wb').write(tflite_model)

The output is:

INFO:tensorflow:Assets written to: saved_model/assets
21168
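The number 21168 in the output appears to be the return value of write(), that is, the size of the serialized model in bytes. A slightly more careful way to write the file is with a context manager, which closes the file promptly. The sketch below uses a placeholder payload instead of a real converted model, and the helper name `write_model` is hypothetical.

```python
import tempfile
from pathlib import Path


def write_model(path: Path, model_bytes: bytes) -> int:
    """Write a serialized model to disk and return the number of bytes written."""
    with open(path, "wb") as f:      # the context manager closes the file for us
        return f.write(model_bytes)


# Placeholder standing in for converter.convert(); this is NOT a real tflite model.
fake_model = bytes(21168)

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "linear_model_0921.tflite"
    written = write_model(target, fake_model)
    size_on_disk = target.stat().st_size

print(written, size_on_disk)
```

Checking the size on disk against the returned byte count is a cheap way to confirm the file was written completely before copying it into the Android assets folder.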

Step 3: Check that the signature functions were built successfully, and inspect their inputs and outputs

This step is the key to later Android development.

# Inspect the signature functions: print the signatures from the converted model

interpreter = tf.lite.Interpreter('linear_model_0921.tflite')

signatures = interpreter.get_signature_list()

print(signatures)

The output is:

{'infer': {'inputs': ['x'], 'outputs': ['output']}, 'restore': {'inputs': 
['checkpoint_path'], 'outputs': ['dense_6/bias:0', 'dense_6/kernel:0', 'dense_7/bias:0', 
'dense_7/kernel:0', 'dense_8/bias:0', 'dense_8/kernel:0']}, 'save': {'inputs': 
['checkpoint_path'], 'outputs': ['checkpoint_path']}, 'train': {'inputs': ['x', 'y'], 
'outputs': ['loss']}}
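Since the Android code looks tensors up by these names, it can help to check the signature list automatically before shipping the model. The following is a minimal sketch in plain Python that works on a dictionary shaped like the output above (the restore outputs are left empty here for brevity). The `EXPECTED` table and the `missing_keys` helper are hypothetical names introduced for illustration.

```python
# Signature list in the shape returned by get_signature_list().
signatures = {
    "infer": {"inputs": ["x"], "outputs": ["output"]},
    "restore": {"inputs": ["checkpoint_path"], "outputs": []},
    "save": {"inputs": ["checkpoint_path"], "outputs": ["checkpoint_path"]},
    "train": {"inputs": ["x", "y"], "outputs": ["loss"]},
}

# What the Android code in this article relies on: name -> (input keys, output keys).
EXPECTED = {
    "train": (["x", "y"], ["loss"]),
    "infer": (["x"], ["output"]),
}


def missing_keys(signatures, expected):
    """Return a list of human-readable problems; an empty list means all is well."""
    problems = []
    for name, (inputs, outputs) in expected.items():
        sig = signatures.get(name)
        if sig is None:
            problems.append(f"missing signature: {name}")
            continue
        for key in inputs:
            if key not in sig["inputs"]:
                problems.append(f"{name}: missing input {key}")
        for key in outputs:
            if key not in sig["outputs"]:
                problems.append(f"{name}: missing output {key}")
    return problems


print(missing_keys(signatures, EXPECTED))
```

With a real model, the `signatures` dictionary would simply be the value returned by `interpreter.get_signature_list()`.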

Step 4: Use python to perform inference on tflite (cloud operation)

The purpose of this step is to verify whether the accuracy of the converted tflite model has declined.

interpreter = tf.lite.Interpreter('linear_model_0921.tflite')

interpreter.allocate_tensors()

infer = interpreter.get_signature_runner("infer")

x6 = np.array([13.0], dtype = float)

x7 = tf.convert_to_tensor(x6, dtype=float)

infer(x=x7)['output'][0]

The output is:

array([24.985922], dtype=float32)

Note that the above is the inference result of calling the tflite model in the cloud.

The following step runs inference with the original cloud model (the calls that start with m.), that is, the TensorFlow model that was written out as a saved model. Make sure you understand the difference:

result = m.infer(x=x7)['output']

np.array(result)[0]

The output is:

array([24.98592], dtype=float32)

The two results agree to about six significant digits, which indicates that converting the cloud model from the saved model format to tflite did not reduce its accuracy.
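Rather than comparing the two printed values by eye, the comparison can be done with a tolerance. This is a small sketch using only the standard library; the two numbers are the ones printed above, the helper name `outputs_match` is hypothetical, and the relative tolerance of 1e-6 is an assumption that suits float32 outputs.

```python
import math

tflite_result = 24.985922   # value printed by the tflite interpreter above
keras_result = 24.98592     # value printed by m.infer above


def outputs_match(a: float, b: float, rel_tol: float = 1e-6) -> bool:
    """True when two model outputs agree within the given relative tolerance."""
    return math.isclose(a, b, rel_tol=rel_tol)


print(outputs_match(tflite_result, keras_result))

# Distance from the exact value of y = 2*x - 1 at x = 13.
print(abs(tflite_result - (2 * 13.0 - 1)))
```

An exact equality test would be too strict here, because the two runtimes may round float32 results slightly differently.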

Step 5: Train the tflite model in the cloud

This shows that, after the cloud training above, the tflite model can be trained further through the Python interface. If this step succeeds, the model can also be trained through the Java interface on Android.

train = interpreter.get_signature_runner("train")

# NUM_EPOCHS = 50

# BATCH_SIZE = 100

more_epochs = np.arange(41, 441, 1)  # 400 values, matching the length of more_losses

more_losses = np.zeros([400])


BATCH_SIZE1 = 1

for i in range(400):

    for x, y in train_ds:

        result = train(x=x, y=y)

    more_losses[i] = result['loss']

    if (i + 1) % 2 == 0:

        print('epochs=', i + 1, 'more_losses=', more_losses[i])

This is a little abstract, so here is a picture. It shows the result of training the TensorFlow model in the cloud for 40 epochs (blue) and then training the tflite model in the cloud for another 400 epochs (orange). The curve shows the tflite training picking up from the cloud-trained weights, which illustrates the characteristics of transfer learning.
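The idea of continuing from already-trained weights can be reproduced in a few lines of plain Python. This is only a sketch under simplifying assumptions: `TinyLinearModel` is a hypothetical single-weight stand-in for the article's network, and the learning rate of 0.005 is chosen for the toy model rather than taken from the article.

```python
XS = [float(v) for v in range(-1, 10)]
YS = [2.0 * x - 1.0 for x in XS]


class TinyLinearModel:
    """Single-weight stand-in for the article's network (an assumption for brevity)."""

    def __init__(self):
        self.w = 0.0
        self.b = 0.0

    def train_epoch(self, lr=0.005):
        """Run one pass of per-sample SGD and return the mean squared error seen."""
        total = 0.0
        for x, y in zip(XS, YS):
            err = self.w * x + self.b - y
            total += err * err
            self.w -= lr * 2.0 * err * x
            self.b -= lr * 2.0 * err
        return total / len(XS)


model = TinyLinearModel()
first_phase = [model.train_epoch() for _ in range(40)]     # initial "cloud" epochs
second_phase = [model.train_epoch() for _ in range(400)]   # continued training, same weights

fresh_start = TinyLinearModel().train_epoch()               # loss when starting from scratch
print(first_phase[0], first_phase[-1], second_phase[0], second_phase[-1])
```

Because the second phase reuses the same object, its first loss is far below the loss of a freshly initialized model, which is the behavior the two-colored curve describes.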

**************************************************************************************************************

The following is the code for Android

Step 6: Android edge training and inference

The Android interface in this example comes from the Muirush code. As published, that code supports only Android-side inference, not Android-side training: the tflite model generated by its Model Training 1.py can only be called through the interpreter.run(input, output) method, which performs inference and cannot train. To train on the device, you must generate a new tflite model with the code above so that the interpreter can use the newer runSignature method:

Edge-side inference: interpreter.runSignature(inputs, outputs, "infer");

Edge-side training: interpreter.runSignature(inputs, outputs, "train")
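Each buffer passed to runSignature has to hold as many float32 values as the matching TensorSpec in the Python signature describes, so it is worth working out those sizes before writing the Java code. The sketch below does the arithmetic in plain Python; the shapes are copied from the input_signature definitions in the training script, while the `SIGNATURE_SHAPES` table and helper names are hypothetical.

```python
from math import prod

FLOAT32_BYTES = 4

# Input shapes from the input_signature definitions of train and infer.
SIGNATURE_SHAPES = {
    "train": {"x": [1, 1], "y": [1]},
    "infer": {"x": [1]},
}


def element_count(shape):
    """Number of scalar values a tensor of this shape holds."""
    return prod(shape)


def buffer_bytes(shape):
    """Size in bytes of a float32 buffer for a tensor of this shape."""
    return element_count(shape) * FLOAT32_BYTES


for signature, inputs in SIGNATURE_SHAPES.items():
    for name, shape in inputs.items():
        print(signature, name, element_count(shape), buffer_bytes(shape))
```

For this model every input holds a single float, which is why the Java code can wrap one-element arrays; a model with a larger batch or feature dimension would need correspondingly larger buffers.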

GitHub - Muirush/Model-training-with-Tensorflow-tfLite-and-android

Use Android Studio to open this project:

Mainly modify two parts:

Place the tflite file generated at the beginning of the article into the assets folder;

Modify MainActivity.java, the code is as follows:

Comment out the original inference method:

//    public float doInference(String val){

//        float [] input = new float[1];

//        input [0] = Float.parseFloat(val);

//

//        float [][] output = new float[1][1];

//        interpreter.run(input,output);

//        return output[0][0];

//    }

Add new inference and training methods

// Inference using the newer runSignature method

float doInference(float val[][]) {

        // Run the inference.

        FloatBuffer testImages = FloatBuffer.wrap(val[0]);

        float[] output = new float[1];

        FloatBuffer output2 = FloatBuffer.wrap(output);

        Map<String, Object> inputs = new HashMap<>();

        inputs.put("x", testImages.rewind());

        Map<String, Object> outputs = new HashMap<>();

        outputs.put("output", output2);

        interpreter.runSignature(inputs, outputs, "infer");

        return output[0];

    }



    float doTrain(float val[][]) {

        // Run one training step. For simplicity the sample is hard-coded (x = 3.5, y = 6.0) and val is not used.

        float[][] var = new float[1][1];

        var[0][0] = 3.5f;

        float[] var2 = new float[1];

        var2[0] = 6.0f;

        FloatBuffer testImages = FloatBuffer.wrap(var[0]);

        float[] loss1 = new float[1];

        FloatBuffer label2 = FloatBuffer.wrap(var2);

        FloatBuffer loss2 = FloatBuffer.wrap(loss1);

        Map<String, Object> inputs = new HashMap<>();

        inputs.put("x", testImages.rewind());

        inputs.put("y", label2.rewind());

        Map<String, Object> outputs = new HashMap<>();

        outputs.put("loss", loss2);

        interpreter.runSignature(inputs, outputs, "train");

        return loss1[0];

    }

Modify the onClick method:

public void onClick(View v) {

//                float f = doInference(ed.getText().toString());

                String var = ed.getText().toString();

                float [][] var2 = new float[1][1];

                var2[0][0] = Float.parseFloat(var);

//                Inference

//                float f = doInference(var2);

//                tv.setText(("Value of Y: "+ f));

//                Training

                float loss6 = doTrain(var2);

                tv.setText(("Loss is: "+ loss6));

            }

When performing training, click Run app:

The phone emulator interface appears. A few points to note: the cloud model was trained for 10,000 epochs, and when training on the Android edge the loss was 4.5*1E-5, which shows that on-device training starts from the cloud-trained weights and the loss keeps decreasing from there. For convenience I hard-coded a single training sample here; feeding more samples or looping over epochs works the same way and makes no essential difference:

When executing inference, the Android emulator displays the following, which shows that the cloud inference result, the cloud tflite inference result, and the Android tflite inference result are consistent. At this point everything works:

Appendix: The complete code of the modified MainActivity.java is as follows:

package com.desertlocust.tfmodel1;

import androidx.appcompat.app.AppCompatActivity;



import android.content.res.AssetFileDescriptor;

import android.os.Bundle;

import android.view.View;

import android.widget.Button;

import android.widget.EditText;

import android.widget.TextView;



import org.tensorflow.lite.Interpreter;



import java.io.FileInputStream;

import java.io.IOException;

import java.nio.MappedByteBuffer;

import java.nio.channels.FileChannel;

import java.util.HashMap;

import java.util.Map;

import java.nio.FloatBuffer;



public class MainActivity extends AppCompatActivity {

    private EditText ed;

    private TextView tv;

    private Button bt;

    private Interpreter interpreter;



    @Override

    protected void onCreate(Bundle savedInstanceState) {

        super.onCreate(savedInstanceState);

        setContentView(R.layout.activity_main);

        ed = findViewById(R.id.input);

        tv = findViewById(R.id.output);

        bt = findViewById(R.id.submit);



        try {

            interpreter = new Interpreter(loadModelFile(),null);

        }catch (IOException e){

            e.printStackTrace();

        }



        bt.setOnClickListener(new View.OnClickListener() {

            @Override

            public void onClick(View v) {

//                float f = doInference(ed.getText().toString());

                String var = ed.getText().toString();

                float [][] var2 = new float[1][1];

                var2[0][0] = Float.parseFloat(var);

//                Inference

//                float f = doInference(var2);

//                tv.setText(("Value of Y: "+ f));

//                Training

                float loss6 = doTrain(var2);

                tv.setText(("Loss is: "+ loss6));

            }

        });

    }

//    Load the tflite model

    private MappedByteBuffer loadModelFile() throws IOException{

        AssetFileDescriptor assetFileDescriptor = this.getAssets().openFd("linear_model_0921.tflite");

        FileInputStream fileInputStream = new FileInputStream(assetFileDescriptor.getFileDescriptor());

        FileChannel fileChannel = fileInputStream.getChannel();

        long startOffset = assetFileDescriptor.getStartOffset();

        long  length = assetFileDescriptor.getLength();

        return  fileChannel.map(FileChannel.MapMode.READ_ONLY,startOffset,length);

    }



//    Inference using the old run method

//    public float doInference(String val){

//        float [] input = new float[1];

//        input [0] = Float.parseFloat(val);

//

//        float [][] output = new float[1][1];

//        interpreter.run(input,output);

//        return output[0][0];

//    }



//    Inference using the newer runSignature method

    float doInference(float val[][]) {

        // Run the inference.

        FloatBuffer testImages = FloatBuffer.wrap(val[0]);

        float[] output = new float[1];

        FloatBuffer output2 = FloatBuffer.wrap(output);

        Map<String, Object> inputs = new HashMap<>();

        inputs.put("x", testImages.rewind());

        Map<String, Object> outputs = new HashMap<>();

        outputs.put("output", output2);

        interpreter.runSignature(inputs, outputs, "infer");

        return output[0];

    }

    float doTrain(float val[][]) {



        // Run one training step. For simplicity the sample is hard-coded (x = 3.5, y = 6.0) and val is not used.

        float[][] var = new float[1][1];

        var[0][0] = 3.5f;

        float[] var2 = new float[1];

        var2[0] = 6.0f;

        FloatBuffer testImages = FloatBuffer.wrap(var[0]);

        float[] loss1 = new float[1];

        FloatBuffer label2 = FloatBuffer.wrap(var2);

        FloatBuffer loss2 = FloatBuffer.wrap(loss1);

        Map<String, Object> inputs = new HashMap<>();

        inputs.put("x", testImages.rewind());

        inputs.put("y", label2.rewind());

        Map<String, Object> outputs = new HashMap<>();

        outputs.put("loss", loss2);

        interpreter.runSignature(inputs, outputs, "train");

        return loss1[0];

    }

}

Finally, for how to create an Android phone emulator in Android Studio and run the app, please refer to my previous article or other material online.

If you have followed along this far, you have already used tflite to perform incremental training and inference on the Android edge. You can follow the same steps for your own, more complex tasks.

To close, a few thoughts. Most code currently available online covers only tflite edge-side inference and uses the old run method as the example. The published examples are coupled with a lot of image-handling code, which makes them hard to follow. This example should let you grasp the essence of tflite quickly. To make discussion easier, we have created a tflite group. You are welcome to join so we can exchange ideas and make progress together. Thank you.


Origin blog.csdn.net/qq_18256855/article/details/127028071