Exploring MediaPipe's portrait segmentation

MediaPipe is Google's open-source computer vision framework, and its models are trained with TensorFlow. Its image segmentation module supports portrait segmentation, hair segmentation, and multi-class segmentation. This article focuses on portrait segmentation, which in turn serves as the basis for background replacement and background blurring.

Table of contents

1. Configuration parameters and models

1.1 Configuration parameters

1.2 Segmentation models

1.2.1 Portrait Segmentation Model

1.2.2 Hair Segmentation Model

1.2.3 Multi-Class Segmentation Model

2. Project configuration

3. Initialization

3.1 Initialize portrait segmentation

3.2 Initialize the camera

4. Portrait Segmentation

4.1 Running portrait segmentation

4.2 Drawing the segmentation result

5. Segmentation result

1. Configuration parameters and models

1.1 Configuration parameters

The image segmenter options are: running mode, whether to output a category mask, whether to output confidence masks, the language of label names, and the result callback, as shown in the following table:

| Option | Description | Value range | Default |
|---|---|---|---|
| running_mode | Running mode. IMAGE: a single image; VIDEO: decoded video frames; LIVE_STREAM: a live stream such as camera input | {IMAGE, VIDEO, LIVE_STREAM} | IMAGE |
| output_category_mask | Whether to output a category mask | Boolean | False |
| output_confidence_masks | Whether to output confidence masks | Boolean | True |
| display_names_locale | Language of label names | Locale code | en |
| result_callback | Listener that receives results asynchronously (LIVE_STREAM mode only) | N/A | N/A |
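To make the defaults and the LIVE_STREAM-only callback concrete, the table can be mirrored in a plain Kotlin data class. This is a hypothetical sketch: `SegmenterConfig` and `SegmentationMode` are illustrative names, not MediaPipe types (the real options are built with `ImageSegmenter.ImageSegmenterOptions.builder()`), and the callback type is simplified to a raw mask plus a timestamp.

```kotlin
// Hypothetical mirror of the options table above; not the MediaPipe API.
enum class SegmentationMode { IMAGE, VIDEO, LIVE_STREAM }

data class SegmenterConfig(
    val runningMode: SegmentationMode = SegmentationMode.IMAGE,
    val outputCategoryMask: Boolean = false,
    val outputConfidenceMasks: Boolean = true,
    val displayNamesLocale: String = "en",
    // Simplified result listener: receives the mask bytes and the frame timestamp.
    val resultCallback: ((mask: ByteArray, timestampMs: Long) -> Unit)? = null,
) {
    init {
        // The table lists result_callback for LIVE_STREAM mode only, so reject it elsewhere.
        require(resultCallback == null || runningMode == SegmentationMode.LIVE_STREAM) {
            "resultCallback is only supported in LIVE_STREAM mode"
        }
    }
}
```

The live-camera setup used later in this article would correspond to `runningMode = SegmentationMode.LIVE_STREAM`, `outputCategoryMask = true`, `outputConfidenceMasks = false` plus a callback, while attaching a callback in the default IMAGE mode fails immediately.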

1.2 Segmentation models

The image segmentation models are deeplabv3, hair_segmenter, selfie_multiclass, and selfie_segmenter; selfie (portrait) segmentation uses selfie_segmenter. The models are shown in the figure below:

1.2.1 Portrait Segmentation Model

The portrait segmentation model outputs two classes: 0 for background and 1 for person. It comes in two input shapes, square (256x256) and landscape (144x256), as shown in the figure below:

 
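Because the mask assigns each pixel either background or person, background replacement reduces to a per-pixel choice between the camera frame and a replacement image. A minimal sketch, assuming the category mask has already been resized to the frame's dimensions and that both images are ARGB pixel arrays of equal size; `replaceBackground` is a hypothetical helper, not a MediaPipe API. The person label is a parameter (default 1, per the description above) so it can be adjusted if the mask you receive encodes classes differently.

```kotlin
/**
 * Hypothetical helper: keeps [frame] pixels labeled as person and takes all
 * other pixels from [replacement]. Arrays are row-major, one entry per pixel.
 */
fun replaceBackground(
    frame: IntArray,
    replacement: IntArray,
    mask: ByteArray,
    personLabel: Int = 1,
): IntArray {
    require(frame.size == replacement.size && frame.size == mask.size) {
        "frame, replacement and mask must have the same number of pixels"
    }
    return IntArray(frame.size) { i ->
        // Read the label as unsigned so values above 127 are not sign-extended.
        val label = mask[i].toInt() and 0xFF
        if (label == personLabel) frame[i] else replacement[i]
    }
}
```

Background blurring follows the same pattern: pass a blurred copy of the frame as `replacement`.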

1.2.2 Hair Segmentation Model

The hair segmentation model also outputs two classes: 0 for background and 1 for hair. Once the hair region is identified, we can recolor it or apply special effects.
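Recoloring can be as simple as blending a tint into the pixels labeled 1. The sketch below does a linear blend per RGB channel and keeps each pixel's alpha; `tintHair` and its `strength` parameter are hypothetical, and it uses the hard category mask, whereas a smoother effect could weight the blend by the confidence mask instead.

```kotlin
/** Hypothetical helper: blends [tint] into pixels whose mask label is 1 (hair). */
fun tintHair(pixels: IntArray, mask: ByteArray, tint: Int, strength: Float = 0.5f): IntArray {
    require(pixels.size == mask.size) { "pixels and mask must be the same size" }
    require(strength in 0f..1f) { "strength must be in [0, 1]" }

    // Linear interpolation of one 8-bit channel toward the tint's channel.
    fun mix(src: Int, dst: Int): Int = (src + (dst - src) * strength).toInt().coerceIn(0, 255)

    return IntArray(pixels.size) { i ->
        val p = pixels[i]
        if ((mask[i].toInt() and 0xFF) != 1) {
            p // Not hair: leave the pixel unchanged.
        } else {
            val a = p ushr 24
            val r = mix((p shr 16) and 0xFF, (tint shr 16) and 0xFF)
            val g = mix((p shr 8) and 0xFF, (tint shr 8) and 0xFF)
            val b = mix(p and 0xFF, tint and 0xFF)
            (a shl 24) or (r shl 16) or (g shl 8) or b
        }
    }
}
```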

1.2.3 Multi-Class Segmentation Model

The multi-class segmentation model distinguishes six classes: background, hair, body skin, face skin, clothes, and others (such as accessories). The label values are as follows:

0 - background
1 - hair
2 - body-skin
3 - face-skin
4 - clothes
5 - others (accessories)
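When consuming the multi-class mask, it helps to give these raw values names. A minimal sketch, assuming the category mask arrives as one byte per pixel holding the label values listed above; `SelfieClass` and `classHistogram` are illustrative names, not MediaPipe types.

```kotlin
/** The six classes of the multi-class selfie model, keyed by mask label value. */
enum class SelfieClass(val label: Int) {
    BACKGROUND(0), HAIR(1), BODY_SKIN(2), FACE_SKIN(3), CLOTHES(4), OTHERS(5);

    companion object {
        private val byLabel = values().associateBy { it.label }

        /** Returns the class for a raw mask byte, or null if the value is out of range. */
        fun fromMaskByte(value: Byte): SelfieClass? = byLabel[value.toInt() and 0xFF]
    }
}

/** Counts how many pixels of a category mask fall into each class; unknown values are skipped. */
fun classHistogram(mask: ByteArray): Map<SelfieClass, Int> =
    mask.mapNotNull { SelfieClass.fromMaskByte(it) }
        .groupingBy { it }
        .eachCount()
```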

2. Project configuration

Taking Android as an example, first add the MediaPipe Tasks vision dependency:

implementation 'com.google.mediapipe:tasks-vision:0.10.0'

Next, set the directory where the models are saved and apply the script that downloads them:

project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'

apply from: 'download_models.gradle'

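There are four image segmentation models, and each can be downloaded on demand. For example, the following task downloads the selfie segmentation model into the assets directory before the build starts: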

task downloadSelfieSegmenterModelFile(type: Download) {
    src 'https://storage.googleapis.com/mediapipe-models/image_segmenter/' +
            'selfie_segmenter/float16/1/selfie_segmenter.tflite'
    dest project.ext.ASSET_DIR + '/selfie_segmenter.tflite'
    overwrite false
}

preBuild.dependsOn downloadSelfieSegmenterModelFile

3. Initialization

3.1 Initialize portrait segmentation

Initializing portrait segmentation involves choosing the processing delegate (CPU or GPU), loading the corresponding segmentation model, and configuring the segmenter options. The sample code is as follows:

    fun setupImageSegmenter() {
        val baseOptionsBuilder = BaseOptions.builder()
        // Select the processing delegate (CPU or GPU)
        when (currentDelegate) {
            DELEGATE_CPU -> {
                baseOptionsBuilder.setDelegate(Delegate.CPU)
            }
            DELEGATE_GPU -> {
                baseOptionsBuilder.setDelegate(Delegate.GPU)
            }
        }
        // Load the corresponding segmentation model
        when (currentModel) {
            MODEL_DEEPLABV3 -> { // DeepLab V3
                baseOptionsBuilder.setModelAssetPath(MODEL_DEEPLABV3_PATH)
            }
            MODEL_HAIR_SEGMENTER -> { // hair segmentation
                baseOptionsBuilder.setModelAssetPath(MODEL_HAIR_SEGMENTER_PATH)
            }
            MODEL_SELFIE_SEGMENTER -> { // portrait (selfie) segmentation
                baseOptionsBuilder.setModelAssetPath(MODEL_SELFIE_SEGMENTER_PATH)
            }
            MODEL_SELFIE_MULTICLASS -> { // multi-class segmentation
                baseOptionsBuilder.setModelAssetPath(MODEL_SELFIE_MULTICLASS_PATH)
            }
        }

        try {
            val baseOptions = baseOptionsBuilder.build()
            val optionsBuilder = ImageSegmenter.ImageSegmenterOptions.builder()
                .setRunningMode(runningMode)
                .setBaseOptions(baseOptions)
                .setOutputCategoryMask(true)
                .setOutputConfidenceMasks(false)
            // In LIVE_STREAM mode, results and errors are delivered asynchronously via listeners
            if (runningMode == RunningMode.LIVE_STREAM) {
                optionsBuilder.setResultListener(this::returnSegmentationResult)
                    .setErrorListener(this::returnSegmentationHelperError)
            }

            val options = optionsBuilder.build()
            imagesegmenter = ImageSegmenter.createFromOptions(context, options)
        } catch (e: IllegalStateException) {
            imageSegmenterListener?.onError(
                "Image segmenter failed to init, error:${e.message}")
        } catch (e: RuntimeException) {
            // Reported with GPU_ERROR, e.g. when the GPU delegate cannot be initialized
            imageSegmenterListener?.onError(
                "Image segmenter failed to init. error:${e.message}", GPU_ERROR)
        }
    }
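The second catch block reports failures with GPU_ERROR, which suggests a natural recovery: retry on the CPU delegate. The pattern can be sketched in plain Kotlin; `Backend` and `createWithFallback` are hypothetical names, and in a real app `factory` would build `BaseOptions` with the matching delegate and call `ImageSegmenter.createFromOptions`.

```kotlin
/** Processing backends, mirroring the CPU/GPU choice in setupImageSegmenter(). Hypothetical. */
enum class Backend { CPU, GPU }

/**
 * Creates a resource on [preferred]. If that throws a RuntimeException and
 * [preferred] is GPU, reports the error via [onFallback] and retries once on CPU.
 * Returns the resource together with the backend that actually succeeded.
 */
fun <T> createWithFallback(
    preferred: Backend,
    factory: (Backend) -> T,
    onFallback: (RuntimeException) -> Unit = {},
): Pair<T, Backend> =
    try {
        factory(preferred) to preferred
    } catch (e: RuntimeException) {
        // Only a GPU failure is retried; a CPU failure has nothing to fall back to.
        if (preferred != Backend.GPU) throw e
        onFallback(e)
        factory(Backend.CPU) to Backend.CPU
    }
```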

3.2 Initialize the camera

Camera initialization involves obtaining the CameraProvider, setting the preview aspect ratio, configuring the image analysis use case, binding the camera to the lifecycle, and attaching the SurfaceProvider.

    private fun setUpCamera() {
        val cameraProviderFuture =
            ProcessCameraProvider.getInstance(requireContext())
        cameraProviderFuture.addListener(
            {
                // Obtain the CameraProvider
                cameraProvider = cameraProviderFuture.get()
                // Build and bind the camera use cases
                bindCamera()
            }, ContextCompat.getMainExecutor(requireContext())
        )
    }

    private fun bindCamera() {
        val cameraProvider = cameraProvider ?: throw IllegalStateException("Camera init failed.")
        val cameraSelector = CameraSelector.Builder().requireLensFacing(cameraFacing).build()

        // Set the preview aspect ratio to 4:3
        preview = Preview.Builder().setTargetAspectRatio(AspectRatio.RATIO_4_3)
            .setTargetRotation(fragmentCameraBinding.viewFinder.display.rotation)
            .build()

        // Configure image analysis: 4:3, keep only the latest frame, RGBA_8888 output
        imageAnalyzer =
            ImageAnalysis.Builder().setTargetAspectRatio(AspectRatio.RATIO_4_3)
                .setTargetRotation(fragmentCameraBinding.viewFinder.display.rotation)
                .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
                .setOutputImageFormat(ImageAnalysis.OUTPUT_IMAGE_FORMAT_RGBA_8888)
                .build()
                .also {
                    // Each frame is passed to the segmenter on a background executor
                    it.setAnalyzer(backgroundExecutor!!) { image ->
                        imageSegmenterHelper.segmentLiveStreamFrame(image,
                            cameraFacing == CameraSelector.LENS_FACING_FRONT)
                    }
                }

        // Unbind any previous use cases before rebinding
        cameraProvider.unbindAll()

        try {
            // Bind the use cases to the lifecycle
            camera = cameraProvider.bindToLifecycle(
                this, cameraSelector, preview, imageAnalyzer)
            // Attach the viewfinder's SurfaceProvider to the preview
            preview?.setSurfaceProvider(fragmentCameraBinding.viewFinder.surfaceProvider)
        } catch (exc: Exception) {
            Log.e(TAG, "Use case binding failed", exc)
        }
    }
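STRATEGY_KEEP_ONLY_LATEST means that when segmentation is slower than the camera, stale frames are dropped instead of queued, so the analyzer always works on the newest frame. The idea (not CameraX's actual implementation) fits in a single-slot buffer where a new frame replaces, and releases, any frame still waiting. `LatestFrameSlot` is a hypothetical name used only for illustration.

```kotlin
import java.util.concurrent.atomic.AtomicReference

/**
 * Holds at most one pending frame. Offering a new frame replaces the pending one,
 * which is passed to [onDrop] so its resources can be released (as an ImageProxy
 * would need to be closed).
 */
class LatestFrameSlot<T : Any>(private val onDrop: (T) -> Unit = {}) {
    private val pending = AtomicReference<T?>(null)

    /** Called by the producer (camera) thread for every new frame. */
    fun offer(frame: T) {
        pending.getAndSet(frame)?.let(onDrop)
    }

    /** Called by the consumer (analyzer) thread; returns the newest frame, or null if none. */
    fun take(): T? = pending.getAndSet(null)
}
```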

4. Portrait Segmentation

4.1 Running portrait segmentation

Taking camera input in LIVE_STREAM mode as an example: first copy the frame's pixel data into a Bitmap, then rotate it (and mirror it for the front camera), convert the Bitmap to an MPImage, and finally run portrait segmentation. The sample code is as follows:

    fun segmentLiveStreamFrame(imageProxy: ImageProxy, isFrontCamera: Boolean) {
        val frameTime = SystemClock.uptimeMillis()
        val bitmapBuffer = Bitmap.createBitmap(imageProxy.width,
            imageProxy.height, Bitmap.Config.ARGB_8888)
        // Read the rotation before the frame is closed
        val rotationDegrees = imageProxy.imageInfo.rotationDegrees

        // Copy the pixel data; use {} closes the ImageProxy afterwards,
        // which lets CameraX deliver the next frame
        imageProxy.use {
            bitmapBuffer.copyPixelsFromBuffer(it.planes[0].buffer)
        }

        val matrix = Matrix().apply {
            // Rotate the image upright
            postRotate(rotationDegrees.toFloat())
            // Front camera frames need to be mirrored horizontally
            if (isFrontCamera) {
                postScale(-1f, 1f, bitmapBuffer.width.toFloat(), bitmapBuffer.height.toFloat())
            }
        }

        val rotatedBitmap = Bitmap.createBitmap(bitmapBuffer, 0, 0,
            bitmapBuffer.width, bitmapBuffer.height, matrix, true)
        // Convert the Bitmap to an MPImage
        val mpImage = BitmapImageBuilder(rotatedBitmap).build()
        // Run segmentation asynchronously; results arrive via the result listener
        imagesegmenter?.segmentAsync(mpImage, frameTime)
    }
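To see what the Matrix does, here is the same transformation on a raw row-major pixel array: a 90° clockwise rotation (the case where rotationDegrees is 90) followed by a horizontal mirror for the front camera. These helpers are hypothetical and only illustrate the geometry; on Android, `Bitmap.createBitmap` with a `Matrix` performs it for you.

```kotlin
/** Rotates a row-major width x height image 90 degrees clockwise; the result is height x width. */
fun rotate90Clockwise(pixels: IntArray, width: Int, height: Int): IntArray {
    require(pixels.size == width * height) { "pixel count must equal width * height" }
    val out = IntArray(pixels.size)
    val newWidth = height
    for (y in 0 until height) {
        for (x in 0 until width) {
            // Source (x, y) lands at column (height - 1 - y), row x of the rotated image.
            out[x * newWidth + (height - 1 - y)] = pixels[y * width + x]
        }
    }
    return out
}

/** Mirrors a row-major image left to right, as done for front-camera frames. */
fun mirrorHorizontally(pixels: IntArray, width: Int, height: Int): IntArray {
    require(pixels.size == width * height) { "pixel count must equal width * height" }
    return IntArray(pixels.size) { i ->
        val y = i / width
        val x = i % width
        pixels[y * width + (width - 1 - x)]
    }
}
```

For a front-camera frame, the two calls are chained in the same order as the matrix: rotate first, then mirror the rotated image using its new width and height.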

4.2 Drawing the segmentation result

In the overlay view, first convert each label in the category mask to a color, then compute the scale factor, and finally call invalidate() to trigger a redraw:

    fun setResults(
        byteBuffer: ByteBuffer,
        outputWidth: Int,
        outputHeight: Int) {
        val pixels = IntArray(byteBuffer.capacity())
        for (i in pixels.indices) {
            // DeepLab uses 0 for background and positive labels for other classes; wrap labels modulo 20 to index the 20-color palette
            val index = byteBuffer.get(i).toUInt() % 20U
            val color = ImageSegmenterHelper.labelColors[index.toInt()].toAlphaColor()
            pixels[i] = color
        }
        val image = Bitmap.createBitmap(pixels, outputWidth, outputHeight, Bitmap.Config.ARGB_8888)
        // Compute the scale factor
        val scaleFactor = when (runningMode) {
            RunningMode.IMAGE,
            RunningMode.VIDEO -> {
                min(width * 1f / outputWidth, height * 1f / outputHeight)
            }
            RunningMode.LIVE_STREAM -> {
                max(width * 1f / outputWidth, height * 1f / outputHeight)
            }
        }

        val scaleWidth = (outputWidth * scaleFactor).toInt()
        val scaleHeight = (outputHeight * scaleFactor).toInt()

        scaleBitmap = Bitmap.createScaledBitmap(image, scaleWidth, scaleHeight, false)
        invalidate()
    }

Finally, invalidate() causes the view's draw() to run, which draws the scaled bitmap onto the canvas:

    override fun draw(canvas: Canvas) {
        super.draw(canvas)
        scaleBitmap?.let {
            canvas.drawBitmap(it, 0f, 0f, null)
        }
    }
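The scale-factor logic in setResults() is easy to check in isolation. In IMAGE and VIDEO mode, the smaller ratio (min) fits the whole mask inside the view; in LIVE_STREAM mode, the larger ratio (max) makes the mask cover the view like a full-screen camera preview, with the overflow cropped when drawn. A pure-Kotlin sketch with hypothetical names, where FIT corresponds to IMAGE/VIDEO and FILL to LIVE_STREAM:

```kotlin
import kotlin.math.max
import kotlin.math.min

/** Whether the mask should fit inside the view or cover it completely. */
enum class ScaleMode { FIT, FILL }

/** Returns the scaled mask size (width to height) for a view of viewWidth x viewHeight. */
fun scaledMaskSize(
    viewWidth: Int, viewHeight: Int,
    maskWidth: Int, maskHeight: Int,
    mode: ScaleMode,
): Pair<Int, Int> {
    val sx = viewWidth * 1f / maskWidth
    val sy = viewHeight * 1f / maskHeight
    // FIT: the smaller factor keeps the whole mask visible.
    // FILL: the larger factor covers the view; anything past the view edge is cropped.
    val scale = if (mode == ScaleMode.FIT) min(sx, sy) else max(sx, sy)
    return (maskWidth * scale).toInt() to (maskHeight * scale).toInt()
}
```

For a 256x256 mask in a 1080x1440 view, FIT yields 1080x1080 while FILL yields 1440x1440.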

5. Segmentation result

The essence of portrait segmentation is separating the person from the background. The final result is shown below:

 


Source: blog.csdn.net/u011686167/article/details/131400803