一键抠图3:Android实现人像抠图 (Portrait Matting)

一键抠图3:Android实现人像抠图 (Portrait Matting)

目录

一键抠图3:Android实现人像抠图 (Portrait Matting)

1. 前言

2. 抠图算法

3. 模型Android部署

(1) 将Pytorch模型转换ONNX模型

(2) 将ONNX模型转换为TNN模型

(3) Android端上部署模型

(4) Android测试效果 

(5) 运行APP闪退:dlopen failed: library "libomp.so" not found

4.Android项目源码下载

5.人像抠图C++版本

6.人像抠图Python版本


1. 前言

这是一键抠图项目系列之《Android实现人像抠图 (Portrait Matting)》;本篇主要分享将Python训练后的matting模型部署到Android平台。我们将开发一个简易的、可实时运行的人像抠图Android Demo。Android版本人像抠图模型推理支持CPU和GPU加速,在GPU(OpenCL)加速下,可以达到头发细致级别的人像抠图效果,为了方便后续模型工程化和Android平台部署,项目提供高精度版本人像抠图和轻量化快速版人像抠图,并开发了Python/C++/Android多个版本;

先展示一下Android版本一键抠图效果:

模型选择 原图 高精度人像抠图 视频抠图

 Android Demo APP下载地址:https://download.csdn.net/download/guyuealian/63228759

尊重原创,转载请注明出处https://blog.csdn.net/guyuealian/article/details/134801795 


更多项目《一键抠图》系列文章请参考:

  1. 一键抠图1:Python实现人像抠图 (Portrait Matting)https://blog.csdn.net/guyuealian/article/details/134784803
  2. 一键抠图2:C/C++实现人像抠图 (Portrait Matting)https://blog.csdn.net/guyuealian/article/details/134790532
  3. 一键抠图3:Android实现人像抠图 (Portrait Matting)https://blog.csdn.net/guyuealian/article/details/134801795


2. 抠图算法

基于深度学习的Matting分为两大类:

  • 一种是基于辅助信息输入。即除了原图和标注图像外,还需要输入其他的信息辅助预测。最常见的辅助信息是Trimap,即将图片划分为前景,背景及过度区域三部分。另外也有以背景或交互点作为辅助信息。

  • 一种是不依赖任何辅助信息,直接对Alpha进行预测。如本博客复现的MODNet

第一种方法,需要加入辅助信息,而辅助信息一般较难获取,这也限制其应用,为了提升Matting的应用性,针对Portrait Matting领域MODNet摒弃了辅助信息,直接实现Alpha预测,实现了实时Matting,极大提升了基于深度学习Matting的应用价值。

更多抠图算法(Matting),请参考我的一篇博客《图像抠图Image Matting算法调研》:

图像抠图Image Matting算法调研_image matting调研-CSDN博客文章浏览阅读4.3k次,点赞8次,收藏68次。1.Trimap和StrokesTrimap和Strokes都是一种静态图像抠图算法,现有静态图像抠图算法均需对给定图像添加手工标记以增加抠图问题的额外约束。Trimap,三元图,是对给定图像的一种粗略划分,即将给定图像划分为前景、背景和待求未知区域Strokes则采用涂鸦的方式在图像上随意标记前景和背景区域,剩余未标记部分则为待求的未知区域Trimap是最常用的先验知识,多数抠图算法采用了Trimap作为先验知识,顾名思义Trimap是一个三元图,每个像素取值为{0,128,..._image matting调研https://blog.csdn.net/guyuealian/article/details/119648686可能,有小伙伴搞不清楚分割(segmentation)和抠图(matting)有什么区别,我这里简单说明一下:

  •  分割(segmentation):从深度学习的角度来说,分割本质是像素级别的分类任务,其损失函数最简单的莫过于是交叉熵CrossEntropyLoss(当然也可以是Focal Loss,IOU Loss,Dice Loss等);对于前景和背景分割任务,输出Mask的每个像素要么是0,要么是1。如果拿去直接做图像融合,就很不自然,Mask边界很生硬,这时就需要使用抠图算法了
  •  抠图(matting): 而抠图本质是一种回归任务,其损失函数可以是MSE Loss,L1 Loss,L2 Loss等,对于前景和背景抠图任务,输出Mask的每个像素是0~1之间的连续值,可看作是对图像透明通道(Alpha)的回归预测。可以用公式表示为C = αF + (1-α)B ,其中α(不透明度)、F(前景色)和B(背景色),alpha是[0, 1]之间的连续值,可以理解为像素属于前景的概率。在人像分割任务中,alpha只能取0或1,而抠图任务中,alpha可取[0, 1]之间的连续值,
  • 本质上就是一句话:分割是分类任务,而抠图是回归任务。

3. 模型Android部署

目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署。部署流程可分为四步:训练模型->将模型转换ONNX模型->将ONNX模型转换为TNN模型->Android端上部署TNN模型。

(1) 将Pytorch模型转换ONNX模型

训练好模型后,你需要先将Pytorch模型转换为ONNX模型,并使用onnx-simplifier简化网络结构,Python版本的已经提供了ONNX转换脚本,终端输入命令如下:

# 导出ONNX模型
python export.py --model_type "modnet" --model_file "work_space/modnet_416/model/best_model.pth"

​GitHub: https://github.com/daquexian/onnx-simplifier
Install:  pip3 install onnx-simplifier 

(2) 将ONNX模型转换为TNN模型

目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署

TNN转换工具:

​​

转换成功后,会生成两个文件(*.tnnproto和*.tnnmodel) ,下载下来后面会用到

(3) Android端上部署模型

项目Android部署框架采用TNN,支持多线程CPU和GPU加速推理,在普通手机上可以实时处理。项目Android源码,核心算法均采用C++实现,上层通过JNI接口调用。

如果你想在这个Android Demo部署你自己训练的模型,你可将训练好的Pytorch模型转换ONNX ,再转换成TNN模型,然后把TNN模型代替你模型即可。 

  • 这是项目Android源码JNI接口 ,Java部分

matting接口:实现基本的人像构图Matting功能
fusion接口:实现人像构图Matting,并与背景图进行融合
mattingFusion接口:人像构图Matting,并与背景图进行融合(会返回mask)

package com.cv.tnn.model;
 
import android.graphics.Bitmap;
 
public class Detector {
 
    static {
        System.loadLibrary("tnn_wrapper");
    }
 
 
    /***
     * 初始化检测模型
     * @param proto: TNN *.tnnproto文件文件名(含后缀名)
     * @param model: TNN *.tnnmodel文件文件名(含后缀名)
     * @param root:模型文件的根目录,放在assets文件夹下
     * @param model_type:模型类型
     * @param num_thread:开启线程数
     * @param useGPU:是否使用GPU
     */
    public static native void init(String proto, String model, String root, int model_type, int num_thread, boolean useGPU);
 
    /***
     * 缩放图片
     * @param bitmap
     * @param resize_width
     * @param resize_height
     * @return
     */
    public static Bitmap resizeBitmap(Bitmap bitmap, int resize_width, int resize_height) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        if (resize_width <= 0 && resize_height <= 0) {
            return bitmap;
        } else if (resize_height <= 0) {
            resize_height = height * resize_width / width;
        } else if (resize_width <= 0) {
            resize_width = width * resize_height / height;
        }
        Bitmap dst = Bitmap.createScaledBitmap(bitmap, resize_width, resize_height, false);
        return dst;
    }
 
 
    /***
     * 人像构图Matting
     * @param bitmap 输入图像(bitmap),ARGB_8888格式
     * @param mask   输出Mask图像(bitmap),ARGB_8888格式,调用前需要createBitmap初始化大小,如
     *               Bitmap.createBitmap(Width, Height, Bitmap.Config.ARGB_8888);
     * @return
     */
    public static native void matting(Bitmap bitmap, Bitmap mask);
 
 
    /***
     * 人像构图Matting,并与背景图进行融合
     * @param bitmap 输入图像(bitmap),ARGB_8888格式
     * @param bgmap  输入背景图像(bitmap),ARGB_8888格式,可任意大小的图像
     * @param fusion 输出与背景融合后图像,调用前需要createBitmap初始化大小,ARGB_8888格式
     */
    public static native void fusion(Bitmap bitmap, Bitmap bgmap, Bitmap fusion);
 
    /***
     * 人像构图Matting,并与背景图进行融合
     * @param bitmap 输入图像(bitmap),ARGB_8888格式
     * @param bgmap  输入背景图像(bitmap),ARGB_8888格式,可任意大小的图像
     * @param fusion 输出与背景融合后图像,调用前需要createBitmap初始化大小,ARGB_8888格式
     * @param mask   输出Mask图像(bitmap),调用前需要createBitmap初始化大小,ARGB_8888格式
     * @return
     */
    public static native void mattingFusion(Bitmap bitmap, Bitmap bgmap, Bitmap fusion, Bitmap mask);
 
 
}
  • 这是Android项目源码JNI接口 ,C++部分
#include <jni.h>
#include <string>
#include <fstream>
#include "src/segment.h"
#include "src/object_detection.h"
#include "src/Types.h"
#include "debug.h"
#include "android_utils.h"
#include "opencv2/opencv.hpp"

using namespace dm;
using namespace vision;

static Segment *segment = nullptr;
static ObjectDetection *detector = nullptr;


JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *reserved) {
    return JNI_VERSION_1_6;
}

JNIEXPORT void JNI_OnUnload(JavaVM *vm, void *reserved) {

}


extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_init(JNIEnv *env,
                                    jclass clazz,
                                    jstring proto,
                                    jstring model,
                                    jstring root,
                                    jint model_type,
                                    jint num_thread,
                                    jboolean use_gpu) {
    if (segment != nullptr) {
        delete segment;
        segment = nullptr;
    }
    std::string parent = env->GetStringUTFChars(root, 0);
    std::string proto_file = parent + env->GetStringUTFChars(proto, 0);
    std::string model_file = parent + env->GetStringUTFChars(model, 0);
    DeviceType device = use_gpu ? GPU : CPU;
    LOGW("parent     : %s", parent.c_str());
    LOGW("useGPU     : %d", use_gpu);
    LOGW("device_type: %d", device);
    LOGW("model_type : %d", model_type);
    LOGW("num_thread : %d", num_thread);
    SegmentParam model_param = SEG_MODEL_TYPE[model_type];//模型参数
    segment = new Segment(model_file,
                          proto_file,
                          model_param,
                          num_thread,
                          device);

}


extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_matting(JNIEnv *env, jclass clazz, jobject bitmap,
                                       jobject out_mask) {
    cv::Mat image;//bgr
    cv::Mat bg;//bgr
    BitmapToMatrix(env, bitmap, image);
    cv::Mat mask;
    cv::Mat fusion;
    // 检测人像分割
    segment->detect(image, mask);
    // 返回Mask
    MatrixToBitmap(env, mask, out_mask);
}



extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_fusion(JNIEnv *env, jclass clazz,
                                      jobject bitmap,
                                      jobject bgmap,
                                      jobject out_fusion) {
    cv::Mat image;//bgr
    cv::Mat bg;//bgr
    BitmapToMatrix(env, bitmap, image);
    BitmapToMatrix(env, bgmap, bg);
    cv::Mat mask;
    cv::Mat fusion;
    // 检测人像分割
    segment->detect(image, mask);
    // 将matte与背景bg进行融合fusion
    image_fusion(image, mask, fusion, bg);
    // 融合fusion图像
    MatrixToBitmap(env, fusion, out_fusion);
}



extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_mattingFusion(JNIEnv *env, jclass clazz,
                                             jobject bitmap,
                                             jobject bgmap,
                                             jobject out_fusion,
                                             jobject out_mask) {
    cv::Mat image;//bgr
    cv::Mat bg;//bgr
    BitmapToMatrix(env, bitmap, image);
    BitmapToMatrix(env, bgmap, bg);
    cv::Mat mask;
    cv::Mat fusion;
    // 检测人像分割
    segment->detect(image, mask);
    // 将matte与背景bg进行融合fusion
    image_fusion(image, mask, fusion, bg);
    // 融合fusion图像
    MatrixToBitmap(env, fusion, out_fusion);
    MatrixToBitmap(env, mask, out_mask);
}



extern "C"
JNIEXPORT jobjectArray JNICALL
Java_com_cv_tnn_model_Detector_detect(JNIEnv *env, jclass clazz, jobject bitmap,
                                      jfloat score_thresh, jfloat iou_thresh) {
    cv::Mat bgr;
    BitmapToMatrix(env, bitmap, bgr);
    int src_h = bgr.rows;
    int src_w = bgr.cols;
    // 检测区域为整张图片的大小
    FrameInfo resultInfo;
    // 开始检测
    if (detector != nullptr) {
        detector->detect(bgr, &resultInfo, score_thresh, iou_thresh);
    } else {
        ObjectInfo objectInfo;
        objectInfo.x1 = 0;
        objectInfo.y1 = 0;
        objectInfo.x2 = 84;
        objectInfo.y2 = 84;
        objectInfo.label = 0;
        resultInfo.info.push_back(objectInfo);
    }

    int nums = resultInfo.info.size();
    LOGW("object nums: %d\n", nums);

    auto BoxInfo = env->FindClass("com/cv/tnn/model/FrameInfo");
    auto init_id = env->GetMethodID(BoxInfo, "<init>", "()V");
    auto box_id = env->GetMethodID(BoxInfo, "addBox", "(FFFFIF)V");
    auto ky_id = env->GetMethodID(BoxInfo, "addKeyPoint", "(FFF)V");
    jobjectArray ret = env->NewObjectArray(resultInfo.info.size(), BoxInfo, nullptr);
    for (int i = 0; i < nums; ++i) {
        auto info = resultInfo.info[i];
        env->PushLocalFrame(1);
        //jobject obj = env->AllocObject(BoxInfo);
        jobject obj = env->NewObject(BoxInfo, init_id);
        // set bbox
        //LOGW("rect:[%f,%f,%f,%f] label:%d,score:%f \n", info.rect.x,info.rect.y, info.rect.w, info.rect.h, 0, 1.0f);
        env->CallVoidMethod(obj, box_id, info.x1, info.y1, info.x2 - info.x1, info.y2 - info.y1,
                            info.label, info.score);
        // set keypoint
        for (const auto &kps : info.landmarks) {
            //LOGW("point:[%f,%f] score:%f \n", lm.point.x, lm.point.y, lm.score);
            env->CallVoidMethod(obj, ky_id, (float) kps.x, (float) kps.y, 1.0f);
        }
        obj = env->PopLocalFrame(obj);
        env->SetObjectArrayElement(ret, i, obj);
    }
    return ret;
}

(4) Android测试效果 

实际使用中,建议你:

  • 背景越单一,抠图的效果越好,背景越复杂,抠图效果越差;建议你实际使用中,找一比较单一的背景,如墙面,天空等
  • 上半身抠图的效果越好,下半身或者全身抠图效果较差;本质上这是数据的问题,因为训练数据70%都是只有上半身的
  • 白种人抠图的效果越好,黑人和黄种人抠图效果较差;这也是数据的问题,因为训练数据大部分都是隔壁的老外

下图是高精度版本人像抠图和快速人像构图的测试效果,相对而言,高精度版本人像抠图可以精细到发丝级别的抠图效果;而快速人像构图目前仅能实现基本的抠图效果:

原图  Mask图像  融合图像

​​ ​​ ​​

(5) 运行APP闪退:dlopen failed: library "libomp.so" not found

参考解决方法:
解决dlopen failed: library “libomp.so“ not found_PKing666666的博客-CSDN博客_dlopen failed

 Android SDK和NDK相关版本信息,请参考: 

 


4.Android项目源码下载

 Android Demo APP下载地址:https://download.csdn.net/download/guyuealian/63228759

Android项目源码下载地址:一键抠图Portrait Matting人像抠图 (C++和Android源码)

整套Android项目源码内容包含:

  1.  Android版本人像抠图算法,支持CPU和GPU
  2. 提供高精度版本人像抠图,可以达到精细到发丝级别的抠图效果(Android GPU 150ms,  CPU 500ms左右)
  3. 提供轻量化快速版人像抠图,满足基本的人像抠图效果,可以在Android达到实时的抠图效果(Android GPU 60ms,  CPU 140ms左右)
  4. Android Demo支持图片,视频,摄像头测试
  5. 所有依赖库都已经配置好,可直接build运行,若运行出现闪退,请参考dlopen failed: library “libomp.so“ not found 解决。


5.人像抠图C++版本

一键抠图2:C/C++实现人像抠图 (Portrait Matting)https://blog.csdn.net/guyuealian/article/details/134790532


6.人像抠图Python版本

一键抠图1:Python实现人像抠图 (Portrait Matting)https://blog.csdn.net/guyuealian/article/details/134784803

猜你喜欢

转载自blog.csdn.net/guyuealian/article/details/134801795