Developing Deep Learning Applications for Android with the TensorFlow Lite Framework (1)

Deploying TensorFlow Lite on Android

Following the instructions on the official website, add the following configuration to the module-level build file (build.gradle) of the project:

	// TensorFlow Lite artifacts: core runtime, GPU artifact, support library and metadata library.
	implementation 'org.tensorflow:tensorflow-lite:2.7.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'

android{
   aaptOptions {
        // Store .tflite files uncompressed in the APK; presumably required so the
        // model asset can be opened with openFd() and memory-mapped — confirm.
        noCompress "tflite"
    }
  defaultConfig {
        ndk {
            // Package native libraries only for these two ARM ABIs.
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
 }

Import the model resource

Take the model file model.tflite created in the article "About converting TensorFlow's SavedModel model into a tflite model" and import it into the assets directory of the Android project.

Define the model basic configuration class BaseModelConfig

/**
 * Base class for the basic configuration of a model.
 *
 * A subclass supplies the concrete values in [setConfigs], which is invoked once
 * while the instance is being constructed, and defines in [addImgValue] how a
 * single pixel is written into the input buffer.
 */
abstract class BaseModelConfig {
    /** Number of bytes processed per channel. */
    var numBytesPerChannel: Int = 0

    /** Batch size. */
    var dimBatchSize: Int = 0

    /** Pixel size dimension. */
    var dimPixelSize: Int = 0

    /** Image width. */
    var dimImgWidth: Int = 0

    /** Image height. */
    var dimImgHeight: Int = 0

    /** Mean value of the image data. */
    var imageMean: Int = 0

    /** Standard deviation of the image data. */
    var imageSTD: Float = 0.0F

    /** Name of the model. */
    lateinit var modelName: String

    // Kept below every property declaration: the property initializers above must
    // have run before the subclass assigns its values, otherwise they would be reset.
    init {
        setConfigs()
    }

    /**
     * Converts a pixel value and appends it to the byte buffer.
     *
     * @param buffer buffer that receives the converted value
     * @param pixel pixel value to convert
     */
    abstract fun addImgValue(buffer: ByteBuffer, pixel: Int)

    /**
     * Assigns the configuration values; called during construction.
     */
    abstract fun setConfigs()
}

Define the FloatSavedModelConfig class

/**
 * Configuration for the float model stored as `model.tflite`:
 * a batch of one 28x28 image whose values are written as 4-byte floats.
 */
class FloatSavedModelConfig : BaseModelConfig() {
    /**
     * Assigns the settings of the float model.
     */
    override fun setConfigs() {
        modelName = "model.tflite"

        // Input geometry.
        dimBatchSize = 1
        dimImgWidth = 28
        dimImgHeight = 28
        dimPixelSize = 1

        // Every value is written with putFloat, i.e. 4 bytes each.
        numBytesPerChannel = 4

        // Normalisation constants applied in addImgValue.
        imageMean = 0
        imageSTD = 255.0f
    }

    /**
     * Normalises the lowest byte of [pixel] as `(byte - imageMean) / imageSTD`
     * and appends the result to [imgData] as a float.
     */
    override fun addImgValue(imgData: ByteBuffer, pixel: Int) {
        val lowByte = pixel and 0xFF
        val normalised = (lowByte - imageMean) / imageSTD
        imgData.putFloat(normalised)
    }
}

Create a factory class that configures model parameters

/**
 * Factory that maps a model identifier to its configuration object.
 */
object ModelConfigFactory {
    const val FLOAT_SAVED_MODEL = "float_saved_model"
    const val QUANT_SAVED_MODEL = "quant_saved_model"

    /**
     * Creates the configuration registered for [model].
     *
     * @param model one of [FLOAT_SAVED_MODEL] or [QUANT_SAVED_MODEL]
     * @return a new configuration instance, or `null` when [model] matches neither identifier
     */
    fun getModelConfig(model: String): BaseModelConfig? {
        if (model == FLOAT_SAVED_MODEL) {
            return FloatSavedModelConfig()
        }
        if (model == QUANT_SAVED_MODEL) {
            // QuantSavedModelConfig is declared elsewhere in the project.
            return QuantSavedModelConfig()
        }
        return null
    }
}

Define image classifier

/**
 * Image classifier backed by a TensorFlow Lite [Interpreter].
 *
 * The constructor resolves a [BaseModelConfig] through [ModelConfigFactory], copies its
 * settings into this object, memory-maps the model file from the app assets and creates
 * the interpreter. [doClassify] then scales a bitmap to the configured size, binarises
 * it, runs the interpreter and returns the best-scoring labels as text.
 *
 * The interpreter holds resources of its own; call [close] once the classifier is no
 * longer needed.
 */
class ImageClassifier {
    lateinit var mTFLite: Interpreter

    lateinit var mModelPath: String
    var mNumBytesPerChannel = 0

    var mDimBatchSize = 0
    var mDimPixelSize = 0

    var mDimImgWidth = 0
    var mDimImgHeight = 0

    lateinit var mModelConfig: BaseModelConfig

    /** Display names of the classes, indexed by the position of their score in the model output. */
    val labels = arrayListOf("T恤","裤子","帽头衫","连衣裙","外套","凉鞋","衬衫","运动鞋","包","靴子")

    /** Output buffer handed to the interpreter: one row with one score per label. */
    val mLabelProbArray = Array(1) { FloatArray(labels.size) }

    /**
     * Priority queue ordered by ascending score. Its head is always the weakest entry,
     * which makes it cheap to evict while collecting the best [RESULTS_TO_SHOW] results.
     */
    var mSortedLabels = PriorityQueue<Map.Entry<String, Float>>(
        RESULTS_TO_SHOW,
        compareBy<Map.Entry<String, Float>> { it.value }
    )

    /**
     * Copies the settings of [config] into the fields of this classifier.
     */
    private fun initConfig(config: BaseModelConfig) {
        mModelConfig = config
        mNumBytesPerChannel = config.numBytesPerChannel
        mDimBatchSize = config.dimBatchSize
        mDimPixelSize = config.dimPixelSize
        mDimImgWidth = config.dimImgWidth
        mDimImgHeight = config.dimImgHeight
        mModelPath = config.modelName
    }

    /**
     * @param modelConfig identifier understood by [ModelConfigFactory.getModelConfig]
     * @param activity activity whose assets contain the model file
     * @throws IllegalArgumentException if [modelConfig] is not a known identifier
     */
    constructor(modelConfig: String, activity: Activity) {
        // Initialise the classifier parameters; fail with a readable message on an unknown id.
        val config = requireNotNull(ModelConfigFactory.getModelConfig(modelConfig)) {
            "Unknown model configuration: $modelConfig"
        }
        initConfig(config)

        // Create the interpreter from the memory-mapped model.
        mTFLite = Interpreter(loadModelFile(activity))
    }

    /**
     * Releases the interpreter. Safe to call when construction failed before the
     * interpreter was created.
     */
    fun close() {
        if (::mTFLite.isInitialized) {
            mTFLite.close()
        }
    }

    /**
     * Memory-maps the model file stored in the assets.
     *
     * The asset descriptor and the stream are closed before returning; the returned
     * mapping stays valid because a mapping does not depend on the channel that created it.
     */
    private fun loadModelFile(activity: Activity): MappedByteBuffer =
        activity.assets.openFd(mModelPath).use { descriptor ->
            FileInputStream(descriptor.fileDescriptor).use { stream ->
                stream.channel.map(
                    FileChannel.MapMode.READ_ONLY,
                    descriptor.startOffset,
                    descriptor.declaredLength
                )
            }
        }

    /**
     * Scales and binarises [bitmap], then writes its pixels into a direct [ByteBuffer]
     * in native byte order using [BaseModelConfig.addImgValue].
     *
     * @throws IllegalArgumentException if [bitmap] is null
     */
    protected fun convertBitmapToByteBuffer(bitmap: Bitmap?): ByteBuffer {
        // Resize to the configured input size, then reduce to black and white.
        val prepared = binarized(scaleBitmap(bitmap))

        val intValues = IntArray(mDimImgWidth * mDimImgHeight)
        prepared.getPixels(intValues, 0, prepared.width, 0, 0, prepared.width, prepared.height)

        val imgData = ByteBuffer.allocateDirect(
            mNumBytesPerChannel * mDimBatchSize * mDimImgWidth * mDimImgHeight * mDimPixelSize
        )
        imgData.order(ByteOrder.nativeOrder())
        imgData.rewind()

        // Convert every pixel, in the order getPixels delivered them.
        for (pixel in intValues) {
            mModelConfig.addImgValue(imgData, pixel)
        }
        return imgData
    }

    /**
     * Converts [bmp] into a two-colour image.
     *
     * A pixel becomes opaque white only when it is fully opaque and its red, green and
     * blue components all exceed [BINARIZE_THRESHOLD]; every other pixel becomes opaque black.
     *
     * @param bmp source bitmap; it is not modified
     * @return a new ARGB_8888 bitmap of the same size
     */
    fun binarized(bmp: Bitmap): Bitmap {
        val width = bmp.width
        val height = bmp.height
        val pixels = IntArray(width * height)
        bmp.getPixels(pixels, 0, width, 0, 0, width, height)

        for (index in pixels.indices) {
            val argb = pixels[index]
            val alpha = argb ushr 24
            val red = (argb shr 16) and 0xFF
            val green = (argb shr 8) and 0xFF
            val blue = argb and 0xFF
            val isWhite = alpha == 0xFF &&
                red > BINARIZE_THRESHOLD &&
                green > BINARIZE_THRESHOLD &&
                blue > BINARIZE_THRESHOLD
            pixels[index] = if (isWhite) OPAQUE_WHITE else OPAQUE_BLACK
        }

        val result = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
        result.setPixels(pixels, 0, width, 0, 0, width, height)
        return result
    }

    /**
     * Scales [bmp] to the configured input width and height.
     *
     * @throws IllegalArgumentException if [bmp] is null
     */
    fun scaleBitmap(bmp: Bitmap?): Bitmap {
        val source = requireNotNull(bmp) { "bitmap must not be null" }
        return Bitmap.createScaledBitmap(source, mDimImgWidth, mDimImgHeight, true)
    }

    /**
     * Classifies [bitmap] and returns the formatted top results.
     *
     * @throws IllegalArgumentException if [bitmap] is null
     */
    fun doClassify(bitmap: Bitmap?): String? {
        // Convert the bitmap into the ByteBuffer layout the interpreter reads.
        val imgData = convertBitmapToByteBuffer(bitmap)

        // Run the interpreter and log how long it took.
        val startTime = System.nanoTime()
        mTFLite.run(imgData, mLabelProbArray)
        val endTime = System.nanoTime()
        Log.i(TAG, String.format(
                "运行识别的时间: %f ms",
                (endTime - startTime).toFloat() / 1000000.0f
            )
        )

        // Build and return the result text.
        return printTopKLabels()
    }

    /**
     * Formats the [RESULTS_TO_SHOW] best-scoring labels currently held in
     * [mLabelProbArray], best first, one per line. [mSortedLabels] is left empty.
     */
    fun printTopKLabels(): String? {
        val scores = mLabelProbArray[0]
        for (i in scores.indices) {
            mSortedLabels.add(AbstractMap.SimpleEntry(labels[i], scores[i]))
            // Evict the weakest entry as soon as the queue holds more than is shown.
            if (mSortedLabels.size > RESULTS_TO_SHOW) {
                mSortedLabels.poll()
            }
        }

        // The queue yields the weakest entry first, so prepending puts the best one on top.
        val textToShow = StringBuilder()
        while (mSortedLabels.isNotEmpty()) {
            val label = mSortedLabels.poll()
            textToShow.insert(0, String.format("\n%s   %4.8f", label.key, label.value))
        }
        return textToShow.toString()
    }

    private companion object {
        const val TAG = "FashionMNIST"

        /** Number of labels included in the result text. */
        const val RESULTS_TO_SHOW = 3

        /** A colour component must be greater than this value to count as white. */
        const val BINARIZE_THRESHOLD = 180

        /** ARGB 0xFFFFFFFF. */
        const val OPAQUE_WHITE = -1

        /** ARGB 0xFF000000. */
        const val OPAQUE_BLACK = -0x1000000
    }
}

Define the main activity MainActivity

In the main activity, the following operations are mainly processed:
(1) Select a picture from the gallery
(2) Use the image classifier to detect the content in the picture and determine which label it is from the FashionMnist data set
(3) Display the detection results in the GUI on the mobile device.

/**
 * Main screen: lets the user pick a picture, classifies it and shows the result.
 *
 * Tapping the image view opens a content picker for images. The spinner selects
 * between the float model (first entry) and the quantised model. When a picture is
 * returned it is displayed, classified with an [ImageClassifier] and the result text
 * is written to the label view.
 */
class MainActivity : AppCompatActivity() {
    private lateinit var binding: ActivityMainBinding
    val RequestCameraCode = 1
    val TAG = "FashionMNIST"

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)

        // Create the view binding and install its root as the content view.
        binding = ActivityMainBinding.inflate(layoutInflater)
        setContentView(binding.root)

        // Tapping the image opens a picker restricted to images.
        binding.imageView.setOnClickListener {
            val intent = Intent().apply {
                type = "image/*"
                action = Intent.ACTION_GET_CONTENT
            }
            startActivityForResult(intent, RequestCameraCode)
        }

        binding.typeSpinner.adapter =
            ArrayAdapter<String>(this, android.R.layout.simple_spinner_item, getChoices())

        binding.typeSpinner.onItemSelectedListener = object : OnItemSelectedListener {
            // `view` is declared nullable: the framework may pass null here, and a
            // non-null parameter type would make the generated null check throw.
            override fun onItemSelected(
                parent: AdapterView<*>?,
                view: View?,
                position: Int,
                id: Long
            ) {
                // The first spinner entry selects the float model.
                mIsFloat = position == 0
            }

            override fun onNothingSelected(parent: AdapterView<*>?) {}
        }
    }

    /**
     * Handles the picture returned by the picker: shows it, classifies it and
     * displays the result. Results for other requests or unsuccessful picks are ignored.
     */
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (resultCode != RESULT_OK || requestCode != RequestCameraCode) {
            return
        }

        val uri = data?.data
        if (uri == null) {
            Log.d(TAG, "The picker returned no image URI")
            return
        }

        try {
            // Read the picture; `use` closes the stream once decoding has finished.
            val picked = contentResolver.openInputStream(uri)?.use { stream ->
                BitmapFactory.decodeStream(stream)
            }
            if (picked == null) {
                Log.d(TAG, "The selected file could not be decoded as an image")
                return
            }

            // Show the picture in the image view.
            binding.imageView.setImageBitmap(picked)

            // Choose the model type.
            val config = if (mIsFloat) {
                ModelConfigFactory.FLOAT_SAVED_MODEL
            } else {
                ModelConfigFactory.QUANT_SAVED_MODEL
            }

            // A classifier is created for every picture, so its interpreter is
            // released again as soon as the classification is done.
            val classifier = ImageClassifier(config, this)
            val result = try {
                classifier.doClassify(picked)
            } finally {
                classifier.mTFLite.close()
            }

            binding.labelTxt.text = result
            binding.tipTxt.visibility = View.GONE
        } catch (e: FileNotFoundException) {
            Log.d(TAG, "没有找到指定的图像文件", e)
        } catch (e: IOException) {
            Log.e(TAG, "初始化图像识别器失败", e)
        }
    }

    /**
     * Returns the names of the available models.
     */
    private fun getChoices() = resources.getStringArray(R.array.model_names)

    companion object {
        /** `true` while the float model is selected, `false` for the quantised model. */
        var mIsFloat = true
    }
}

references

Li Xihan et al., "Concise TensorFlow 2", People's Posts and Telecommunications Press, Beijing, pp. 91–96.

Guess you like

Origin blog.csdn.net/userhu2012/article/details/121531888