Deploying TensorFlow Lite on Android
Following the instructions on the official TensorFlow Lite website, add the following configuration to the project's module-level build file, build.gradle:
// Module dependencies.
// NOTE: these `implementation` lines belong inside the module's dependencies { } block.
// TensorFlow Lite runtime (provides the Interpreter used by ImageClassifier).
implementation 'org.tensorflow:tensorflow-lite:2.7.0'
// GPU artifact; not referenced by the code shown in this article.
implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'
// Support and metadata artifacts; not referenced by the code shown in this article.
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
android{
// NOTE(review): newer Android Gradle Plugin versions replace aaptOptions with androidResources - confirm for your AGP version.
aaptOptions {
// Keep .tflite assets uncompressed: ImageClassifier opens the model with AssetManager.openFd()
// and memory-maps it, which only works for uncompressed assets.
noCompress "tflite"
}
defaultConfig {
ndk {
// Package native libraries only for these two ARM ABIs.
// NOTE(review): x86/x86_64 emulators would then get no native library - confirm this is intended.
abiFilters 'armeabi-v7a', 'arm64-v8a'
}
}
}
Import the model resources
Copy the model file model.tflite, created in the article "About converting TensorFlow's SavedModel model into a tflite model", into the assets directory of the Android project.
Define the basic model configuration class BaseModelConfig
/**
 * Base description of how a TensorFlow Lite image model expects its input.
 *
 * Concrete subclasses assign the configuration values in [setConfigs], which is
 * called exactly once while the object is being constructed (after all property
 * initializers below have run), and encode individual pixels in [addImgValue].
 */
public abstract class BaseModelConfig {
    /** Number of bytes processed per channel value. */
    var numBytesPerChannel: Int = 0

    /** Number of images in one batch. */
    var dimBatchSize: Int = 0

    /** Per-pixel size of the input (original comment: "number of pixels"; presumably channels per pixel - confirm). */
    var dimPixelSize: Int = 0

    /** Width of the input image in pixels. */
    var dimImgWidth: Int = 0

    /** Height of the input image in pixels. */
    var dimImgHeight: Int = 0

    /** Mean value available to [addImgValue] for normalisation. */
    var imageMean = 0

    /** Standard deviation (divisor) available to [addImgValue] for normalisation. */
    var imageSTD: Float = 0.0F

    /** Name of the model file. */
    lateinit var modelName: String

    init {
        // Placed after every property initializer so the values assigned by the
        // subclass are not overwritten by the defaults above.
        setConfigs()
    }

    /**
     * Encodes one packed [pixel] value and appends it to [buffer].
     */
    public abstract fun addImgValue(buffer: ByteBuffer, pixel: Int)

    /**
     * Assigns the configuration values; invoked once during construction.
     */
    public abstract fun setConfigs()
}
Define the FloatSavedModelConfig class
/**
 * Configuration for the float (non-quantized) model `model.tflite`.
 *
 * Input geometry: batch of 1, 28x28 image, one value per pixel, 4 bytes per value.
 * Each pixel is encoded as one float: the lowest 8 bits of the packed pixel
 * (`pixel and 0xFF`) minus [imageMean], divided by [imageSTD]. With the values
 * set here this maps 0..255 onto 0.0..1.0.
 */
class FloatSavedModelConfig: BaseModelConfig() {
    public override fun setConfigs() {
        // Model file and input tensor geometry.
        modelName = "model.tflite"
        dimBatchSize = 1
        dimImgWidth = 28
        dimImgHeight = 28
        dimPixelSize = 1
        numBytesPerChannel = 4 // one Float per value
        // Normalisation parameters consumed by addImgValue.
        imageMean = 0
        imageSTD = 255.0f
    }

    override fun addImgValue(imgData: ByteBuffer, pixel: Int) {
        // Lowest byte of the packed pixel (the blue channel for ARGB ints - confirm input format).
        val lowByte = pixel and 0xFF
        val normalised = (lowByte - imageMean) / imageSTD
        imgData.putFloat(normalised)
    }
}
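Create a factory object that returns the model configuration for a given model name. (It also refers to QuantSavedModelConfig, a configuration for a quantized model, which is not shown in this article and must be defined analogously to FloatSavedModelConfig.)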
/**
 * Builds [BaseModelConfig] instances from a model identifier string.
 */
object ModelConfigFactory {
    const val FLOAT_SAVED_MODEL = "float_saved_model"
    const val QUANT_SAVED_MODEL = "quant_saved_model"

    /** Identifier -> function creating a fresh configuration instance. */
    private val builders: Map<String, () -> BaseModelConfig> = mapOf(
        FLOAT_SAVED_MODEL to { FloatSavedModelConfig() },
        QUANT_SAVED_MODEL to { QuantSavedModelConfig() }
    )

    /**
     * Returns a newly created configuration for [model], or `null` when the
     * identifier is not one of the constants above.
     */
    fun getModelConfig(model: String): BaseModelConfig? = builders[model]?.invoke()
}
Define the image classifier ImageClassifier
/**
 * Classifies a bitmap with a TensorFlow Lite model into one of the ten labels in [labels].
 *
 * The model configuration is looked up by name through [ModelConfigFactory], and the model
 * file it names is memory-mapped from the app's assets. Call [close] once the classifier is
 * no longer needed so the interpreter's native resources are released.
 */
class ImageClassifier : AutoCloseable {
    private val TAG = "FashionMNIST"
    private val RESULTS_TO_SHOW = 3
    /** A channel must be strictly greater than this value to count as bright when binarizing. */
    private val BINARIZE_THRESHOLD = 180
    /** Opaque white (0xFFFFFFFF) as a packed ARGB int. */
    private val WHITE = -1
    /** Opaque black (0xFF000000) as a packed ARGB int. */
    private val BLACK = -16777216

    lateinit var mTFLite: Interpreter
    lateinit var mModelPath: String
    var mNumBytesPerChannel = 0
    var mDimBatchSize = 0
    var mDimPixelSize = 0
    var mDimImgWidth = 0
    var mDimImgHeight = 0
    lateinit var mModelConfig: BaseModelConfig

    /** Display names of the classes, indexed by the model's output position. */
    val labels = arrayListOf("T恤", "裤子", "帽头衫", "连衣裙", "外套", "凉鞋", "衬衫", "运动鞋", "包", "靴子")

    /** Interpreter output buffer: 1 x labels.size class scores (declared after [labels] on purpose). */
    val mLabelProbArray = Array(1) { FloatArray(labels.size) }

    /** Min-heap ordered by score, used to keep the RESULTS_TO_SHOW highest-scoring labels. */
    var mSortedLabels = PriorityQueue<Map.Entry<String, Float>>(
        RESULTS_TO_SHOW,
        compareBy<Map.Entry<String, Float>> { it.value }
    )

    /**
     * Creates a classifier for the configuration named [modelConfig] (one of the
     * [ModelConfigFactory] constants), loading the model from [activity]'s assets.
     *
     * @throws IllegalArgumentException if [modelConfig] is not a known configuration name.
     * Errors raised while opening the model asset propagate unchanged.
     */
    constructor(modelConfig: String, activity: Activity) {
        val config = requireNotNull(ModelConfigFactory.getModelConfig(modelConfig)) {
            "Unknown model configuration: $modelConfig"
        }
        initConfig(config)
        // Build the interpreter on the memory-mapped model file.
        mTFLite = Interpreter(loadModelFile(activity))
    }

    /** Copies the values from [config] into this classifier's fields. */
    private fun initConfig(config: BaseModelConfig) {
        mModelConfig = config
        mNumBytesPerChannel = config.numBytesPerChannel
        mDimBatchSize = config.dimBatchSize
        mDimPixelSize = config.dimPixelSize
        mDimImgWidth = config.dimImgWidth
        mDimImgHeight = config.dimImgHeight
        mModelPath = config.modelName
    }

    /**
     * Memory-maps the asset named by [mModelPath] read-only.
     *
     * `AssetManager.openFd` only works for assets stored uncompressed. The descriptor and
     * stream are closed before returning; the returned mapping stays valid after its
     * channel is closed.
     */
    private fun loadModelFile(activity: Activity): MappedByteBuffer =
        activity.assets.openFd(mModelPath).use { fileDescriptor ->
            FileInputStream(fileDescriptor.fileDescriptor).use { inputStream ->
                inputStream.channel.map(
                    FileChannel.MapMode.READ_ONLY,
                    fileDescriptor.startOffset,
                    fileDescriptor.declaredLength
                )
            }
        }

    /**
     * Scales [bitmap] to the model input size, binarizes it and encodes every pixel
     * into a direct, native-order ByteBuffer via [BaseModelConfig.addImgValue].
     * The returned buffer is rewound to position 0.
     */
    protected fun convertBitmapToByteBuffer(bitmap: Bitmap?): ByteBuffer {
        val prepared = binarized(scaleBitmap(bitmap))
        val intValues = IntArray(mDimImgWidth * mDimImgHeight)
        prepared.getPixels(intValues, 0, prepared.width, 0, 0, prepared.width, prepared.height)

        val imgData = ByteBuffer.allocateDirect(
            mNumBytesPerChannel * mDimBatchSize * mDimImgWidth * mDimImgHeight * mDimPixelSize
        )
        imgData.order(ByteOrder.nativeOrder())
        // Pixels are visited in the row-major order produced by getPixels; the config decides the encoding.
        for (value in intValues) {
            mModelConfig.addImgValue(imgData, value)
        }
        // Reset the position so the buffer reads from the start.
        imgData.rewind()
        return imgData
    }

    /**
     * Returns a new ARGB_8888 bitmap in which every pixel of [bmp] is either opaque white or
     * opaque black. A pixel becomes white only if it is fully opaque and its red, green and blue
     * channels all exceed [BINARIZE_THRESHOLD]; every other pixel (including translucent ones)
     * becomes black.
     */
    fun binarized(bmp: Bitmap): Bitmap {
        val width = bmp.width
        val height = bmp.height
        val pixels = IntArray(width * height)
        bmp.getPixels(pixels, 0, width, 0, 0, width, height)
        for (index in pixels.indices) {
            val argb = pixels[index]
            val alpha = argb ushr 24
            val red = (argb shr 16) and 0xFF
            val green = (argb shr 8) and 0xFF
            val blue = argb and 0xFF
            val isWhite = alpha == 0xFF &&
                red > BINARIZE_THRESHOLD &&
                green > BINARIZE_THRESHOLD &&
                blue > BINARIZE_THRESHOLD
            pixels[index] = if (isWhite) WHITE else BLACK
        }
        val newBmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
        newBmp.setPixels(pixels, 0, width, 0, 0, width, height)
        return newBmp
    }

    /**
     * Scales [bmp] to the configured input size (mDimImgWidth x mDimImgHeight) with filtering.
     *
     * @throws IllegalArgumentException if [bmp] is null.
     */
    fun scaleBitmap(bmp: Bitmap?): Bitmap {
        val source = requireNotNull(bmp) { "Bitmap to classify must not be null" }
        return Bitmap.createScaledBitmap(source, mDimImgWidth, mDimImgHeight, true)
    }

    /**
     * Runs the model on [bitmap], logs the inference time and returns the
     * formatted top results (see [printTopKLabels]).
     */
    fun doClassify(bitmap: Bitmap?): String? {
        val imgData = convertBitmapToByteBuffer(bitmap)
        val startTime = System.nanoTime()
        mTFLite.run(imgData, mLabelProbArray)
        val endTime = System.nanoTime()
        Log.i(TAG, String.format(
            "运行识别的时间: %f ms",
            (endTime - startTime).toFloat() / 1000000.0f
        ))
        return printTopKLabels()
    }

    /**
     * Formats the RESULTS_TO_SHOW highest scores in [mLabelProbArray] as
     * newline-prefixed "label score" lines, highest score first.
     * Leaves [mSortedLabels] empty afterwards.
     */
    fun printTopKLabels(): String? {
        for ((index, label) in labels.withIndex()) {
            mSortedLabels.add(AbstractMap.SimpleEntry(label, mLabelProbArray[0][index]))
            // Evict the current lowest score once more than RESULTS_TO_SHOW entries are held.
            if (mSortedLabels.size > RESULTS_TO_SHOW) {
                mSortedLabels.poll()
            }
        }
        // The heap yields ascending scores; prepending each one produces descending order.
        val textToShow = StringBuilder()
        while (mSortedLabels.isNotEmpty()) {
            val entry = mSortedLabels.poll()
            textToShow.insert(0, String.format("\n%s %4.8f", entry.key, entry.value))
        }
        return textToShow.toString()
    }

    /** Releases the native resources held by the TensorFlow Lite interpreter. */
    override fun close() {
        mTFLite.close()
    }
}
Define the main activity MainActivity
The main activity performs the following operations:
(1) Select a picture from the gallery.
(2) Use the image classifier to detect the content of the picture and determine which Fashion-MNIST label it belongs to.
(3) Display the detection results in the app's user interface.
/**
 * Main screen: lets the user pick an image from the gallery, classifies it with
 * [ImageClassifier] and shows the top results.
 */
class MainActivity : AppCompatActivity() {
    private lateinit var binding: ActivityMainBinding
    val RequestCameraCode = 1
    val TAG = "FashionMNIST"

    /** The most recently selected image, if any. */
    private var bitmap: Bitmap? = null

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        // Inflate the view binding and use its root as the content view.
        binding = ActivityMainBinding.inflate(layoutInflater)
        setContentView(binding.root)

        // Tapping the image view opens a picker for any image type.
        binding.imageView.setOnClickListener {
            val intent = Intent()
            intent.type = "image/*"
            intent.action = Intent.ACTION_GET_CONTENT
            startActivityForResult(intent, RequestCameraCode)
        }

        val spinnerAdapter = ArrayAdapter<String>(this, android.R.layout.simple_spinner_item, getChoices())
        binding.typeSpinner.adapter = spinnerAdapter
        binding.typeSpinner.onItemSelectedListener = object : OnItemSelectedListener {
            override fun onItemSelected(
                parent: AdapterView<*>?,
                // The framework may pass a null view, so this parameter must be nullable.
                view: View?,
                position: Int,
                id: Long
            ) {
                // Position 0 selects the float model; any other position the quantized one.
                mIsFloat = position == 0
            }

            override fun onNothingSelected(parent: AdapterView<*>?) {}
        }
    }

    /**
     * Handles the gallery result: decodes the picked image, shows it, classifies it
     * and displays the result. A fresh classifier is created per image and closed afterwards.
     */
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (resultCode != RESULT_OK || requestCode != RequestCameraCode) return

        val uri = data?.data
        if (uri == null) {
            Log.d(TAG, "没有找到指定的图像文件")
            return
        }
        try {
            // Decode the picked image; the input stream is closed afterwards.
            val picked = contentResolver.openInputStream(uri)?.use { BitmapFactory.decodeStream(it) }
            if (picked == null) {
                Log.d(TAG, "Could not decode the selected image")
                return
            }
            bitmap = picked
            binding.imageView.setImageBitmap(picked)

            val config = if (mIsFloat) {
                ModelConfigFactory.FLOAT_SAVED_MODEL
            } else {
                ModelConfigFactory.QUANT_SAVED_MODEL
            }
            val classifier = ImageClassifier(config, this)
            try {
                binding.labelTxt.text = classifier.doClassify(picked)
                binding.tipTxt.visibility = View.GONE
            } finally {
                // Release the interpreter's native memory; nothing else holds this classifier.
                classifier.mTFLite.close()
            }
        } catch (e: FileNotFoundException) {
            Log.d(TAG, "没有找到指定的图像文件", e)
        } catch (e: IOException) {
            Log.e(TAG, "初始化图像识别器失败", e)
        }
    }

    /**
     * Returns the names of the available models (string-array resource `model_names`).
     */
    private fun getChoices() = resources.getStringArray(R.array.model_names)

    companion object {
        /** true selects the float model, false the quantized one; updated from the spinner. */
        var mIsFloat = true
    }
}
References
Li Xihan et al., Concise TensorFlow 2, People's Posts and Telecommunications Press, Beijing, pp. 91–96.