美国高通 Snapdragon Neural Processing Engine SDK (SNPE) 系列 (1):用户自定义层JNI实现

转自:https://blog.csdn.net/guvcolie/article/details/77937786

        Snapdragon Neural Processing Engine SDK是美国高通公司出品的神经网络处理引擎(SNPE),可运行于搭载了高通Zeroth机器智能平台的820芯片处理器,开发者可以在SNPE上搭建自己的深度学习网络模型。更详细的介绍可以登录高通SNPE相关网页了解:https://developer.qualcomm.com/software/snapdragon-neural-processing-engine 。介绍到此为止。        

         最近在用 SNPE 做一个项目:将 TensorFlow 的 pb 模型转换为 SNPE 的 dlc 模型,并利用 SNPE 将模型运行在骁龙 GPU 上。不得不说,SNPE 目前还不够完善,有些需要定制的神经网络层在原生 SNPE 中并未提供。好在高通提供了用户定义层(UDL)功能:通过回调函数自定义算子,再重新编译 C++ 代码,把自定义实现编译进可执行文件。如果项目本身就用 C++ 开发,实现用户定义层比较容易;但如果运行在 Android 上就麻烦一些:上层 Java 代码需要通过 JNI 调用 SNPE 原生 C++ 编译好的 .so 文件,而用户定义层的代码不可能预先编译进 SNPE 原生的 .so 中,因此仅靠 SNPE 提供的 Java API 无法使用用户定义层。为此,必须自己重新开发 SNPE 的 JNI 层。        

       出于涉密考虑,用户定义层的代码我不会展示,只展示JNI的骨架代码,关于用户定义层(UDL)的实现可以参阅SNPE的相关文档。        

Java层的代码: 

/**
 * Java front end for the custom "SnpeController" JNI library, which wraps SNPE
 * together with user-defined layers.
 *
 * <p>Usage: construct, call {@link #initModel()}, then repeatedly call
 * {@link #input(float[], int)}, {@link #forward()} and {@link #output()}, and finally
 * call {@link #release()} to free the native objects. No synchronization is performed,
 * so confine an instance to a single thread.
 */
public class SnpeController {
 
    /** Processor SNPE should run on; passed to native code as 0 (CPU), 1 (GPU), 2 (DSP). */
    public enum Runtime {
        CPU,
        GPU,
        DSP
    }
 
    private int height = 0;  
    private int width = 0;  
    private int channel = 0;  
    private Runtime runtimeMode = Runtime.GPU;  
    // NOTE(review): stored but never handed to native code. As written, the JNI initSnpe
    // sets its own fixed LogFileDirectory.
    private String loggerDir = "";  
    private String modelFilePath = "";  
    /** Native IDlContainer* handle; 0 when nothing is allocated. */
    private long containerPointer = 0; 
    /** Native SNPE* handle; 0 when nothing is allocated. */
    private long snpePointer = 0; 
    private float[] inputData = null;  
    private int inputLength = 0;  
    private float[] outputData = null;  
 
    /** Creates a CPU-backed controller that uses the default log directory. */
    public SnpeController(int imageHeight,
                          int imageWidth,
                          int imageChannel,
                          String modelPath) {
        this(imageHeight, imageWidth, imageChannel, modelPath, Runtime.CPU, "/home/mi/snpe_log_dir/");
    }
 
    /**
     * @param imageHeight  input height; height*width*channel must equal the model input size
     * @param imageWidth   input width
     * @param imageChannel input channel count
     * @param modelPath    path to the .dlc model file
     * @param runtimeMode  processor to run on
     * @param loggerDir    desired SNPE log directory (see the NOTE on {@code loggerDir})
     */
    public SnpeController(int imageHeight,
                          int imageWidth,
                          int imageChannel,
                          String modelPath,
                          Runtime runtimeMode,
                          String loggerDir) {
        setImageShape(imageHeight, imageWidth, imageChannel);
        setModelPath(modelPath);
        setRuntime(runtimeMode);
        setLoggerDir(loggerDir);
    }
 
    public void setImageShape(int height,
                              int width,
                              int channel){
        this.height = height;
        this.width = width;
        this.channel = channel;
    }
 
    public void setModelPath(String modelPath){
        modelFilePath = modelPath;
    }
 
    public void setRuntime(Runtime runtime){
        runtimeMode = runtime;
    }
 
    public void setLoggerDir(String path){
        loggerDir = path;
    }
 
    /**
     * Loads the model and builds the SNPE network. Any previously loaded model is
     * freed first, so calling this twice does not leak native memory.
     *
     * @return true on success; false if the image shape is invalid or a native step
     *         returned a 0 handle
     */
    public boolean initModel(){
        if(!checkStatus())
            return false;
        freeNativeResources();
        long modelPointer = initContainer(modelFilePath);
        if (modelPointer == 0)
            return false;
        long newSnpePointer = initSnpe(modelPointer, toNativeRuntime(runtimeMode));
        if (newSnpePointer == 0) {
            // Do not keep a half-initialised model around.
            releaseSource(modelPointer, 0);
            return false;
        }
        containerPointer = modelPointer;
        snpePointer = newSnpePointer;
        return true;
    }
 
    /**
     * Stages one input image for the next {@link #forward()} call.
     *
     * @param imageData input values; must hold at least {@code length} elements
     * @param length    number of values; must equal height*width*channel
     * @return false, leaving any previously staged input untouched, if the arguments
     *         are inconsistent
     */
    public boolean input(float[] imageData,
                         int length){
        // Native exec reads `length` elements, so a shorter array would be over-read.
        if (imageData == null || imageData.length < length)
            return false;
        if (height * width * channel != length)
            return false;
        inputData = imageData;
        inputLength = length;
        return true;
    }
 
    /**
     * Runs inference on the staged input and consumes it.
     *
     * @return false if the model is not initialised, no input is staged, or native
     *         code returned no output array
     */
    public boolean forward(){
        if (snpePointer == 0 || inputData == null || inputLength <= 0)
            return false;
        float[] result = exec(snpePointer, inputData, inputLength);
        inputData = null;
        inputLength = 0;
        outputData = result;
        return result != null;
    }
 
    /** @return the output of the last {@link #forward()}, or null if none */
    public float[] output(){
        return outputData;
    }
 
    /** Frees native objects and drops buffered data. Safe to call more than once. */
    public void release(){
        freeNativeResources();
        inputData = null;
        inputLength = 0;
        outputData = null;
    }
 
    /** Deletes the native container/network (if any) and forgets their handles. */
    private void freeNativeResources(){
        if (containerPointer != 0 || snpePointer != 0)
            releaseSource(containerPointer, snpePointer);
        // Zero the handles so a later release()/initModel() cannot double-free them.
        containerPointer = 0;
        snpePointer = 0;
    }
 
    /** Maps the enum to the integer runtime code expected by native initSnpe. */
    private static int toNativeRuntime(Runtime runtime){
        switch(runtime){
            case GPU:
                return 1;
            case DSP:
                return 2;
            case CPU:
            default:
                return 0;
        }
    }
 
    private boolean checkStatus(){
        if (height <= 0 || width <= 0 || channel <= 0) {
            System.out.println("image shape value is illegal!");
            return false;
        }
        return true;
    }
 
 
   
    private native long initContainer(String dlcFilePath);
    private native long initSnpe(long containerPointer, int runtimeMode);
    private native float[] exec(long snpePointer, float[] data, int length);
    private native void releaseSource(long containerPointer, long snpePointer);
 
    static {
        try {
            System.loadLibrary("SnpeController");
        } catch (UnsatisfiedLinkError e) {
            e.printStackTrace();
        }
    }
}

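先用 javac 编译,再用 javah 命令即可生成 C++ 的 .h 头文件(JDK 10 起 javah 已被移除,可改用 javac -h)。下面是生成的两个头文件,其中 SnpeController_Runtime.h 对应内部枚举 Runtime,不含任何函数声明: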

/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class SnpeController_Runtime */
/*
 * Generated for the nested enum SnpeController.Runtime. The enum declares no native
 * methods, so the extern "C" section below is empty. Hand edits are lost whenever the
 * header is regenerated.
 */
 
#ifndef _Included_SnpeController_Runtime
#define _Included_SnpeController_Runtime
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __cplusplus
}
#endif
#endif
 
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class SnpeController */
/*
 * Declarations of the four native methods of SnpeController. The symbol names
 * (Java_SnpeController_*) contain no package segment, so they only bind if the Java
 * class lives in the default package. If the class is moved into a package, regenerate
 * this header and rename the implementations. In every function the jobject argument
 * is the SnpeController instance, because the Java methods are non-static.
 */
 
#ifndef _Included_SnpeController
#define _Included_SnpeController
#ifdef __cplusplus
extern "C" {
#endif
/*
 * Class:     SnpeController
 * Method:    initContainer
 * Signature: (Ljava/lang/String;)J
 */
JNIEXPORT jlong JNICALL Java_SnpeController_initContainer
  (JNIEnv *, jobject, jstring);
 
/*
 * Class:     SnpeController
 * Method:    initSnpe
 * Signature: (JI)J
 */
JNIEXPORT jlong JNICALL Java_SnpeController_initSnpe
  (JNIEnv *, jobject, jlong, jint);
 
/*
 * Class:     SnpeController
 * Method:    exec
 * Signature: (J[FI)[F
 */
JNIEXPORT jfloatArray JNICALL Java_SnpeController_exec
  (JNIEnv *, jobject, jlong, jfloatArray, jint);
 
/*
 * Class:     SnpeController
 * Method:    releaseSource
 * Signature: (JJ)V
 */
JNIEXPORT void JNICALL Java_SnpeController_releaseSource
  (JNIEnv *, jobject, jlong, jlong);
 
#ifdef __cplusplus
}
#endif
#endif

根据头文件,对相关函数进行实现: 

#include "ABCDE.h"
 
using namespace std;
 
namespace udlexample
{
// User-defined-layer factory. Java_SnpeController_initSnpe stores it in
// zdl::DlSystem::UDLBundle::func and hands the bundle to SNPEBuilder::setUdlBundle.
// Presumably SNPE calls it once per user-defined layer in the model; confirm against
// the SNPE UDL documentation.
//
// cookie : the opaque UDLBundle::cookie value (initSnpe sets 0xdeadbeaf); it is not
//          used in the visible part of this function.
// c      : layer context, forwarded by value (*c) to the layer constructors.
// returns: a newly allocated IUDL implementation for the known layer types.
zdl::DlSystem::IUDL * MyUDLFactory(void * cookie, const zdl::DlSystem::UDLContext * c) {
    // Elided by the original author: presumably derives the layer-type string `type`
    // from *c. TODO: fill in before compiling.
    ...
    if (type == "XXX") {
        return new udls::XXXLayer(*c);
    } else if (type == "YYY") {
        return new udls::YYYLayer(*c);
    } else {
        _PrintErrorStringAndExit("Unknown layer type: " + type);
    }
    // Reached only if _PrintErrorStringAndExit returns instead of terminating.
    return nullptr;
}
} // namespace udlexample
 
/**
 * JNI entry point for SnpeController.initContainer(String).
 *
 * Opens the DLC model at `dlcFilePath` and passes ownership of the resulting
 * zdl::DlContainer::IDlContainer to Java as an opaque jlong handle. The handle must
 * later be deleted through Java_SnpeController_releaseSource.
 *
 * @param env         JNI environment of the calling thread.
 * @param arg         the SnpeController instance (unused).
 * @param dlcFilePath path of the .dlc file, converted with _JstringToChar.
 * @return the raw container pointer as a jlong.
 *
 * If the container cannot be opened, PrintErrorStringAndExit() is called.
 */
JNIEXPORT jlong JNICALL Java_SnpeController_initContainer(JNIEnv * env,
                                                          jobject arg,
                                                          jstring dlcFilePath)
{
    std::string modelPath(_JstringToChar(env, dlcFilePath));
#ifdef DebugInfo
    std::cout << "#### container path is :" << modelPath << std::endl;
#endif
    std::unique_ptr<zdl::DlContainer::IDlContainer> dlc =
        zdl::DlContainer::IDlContainer::open(modelPath);
    if (dlc == nullptr) {
        PrintErrorStringAndExit();
    }
#ifdef DebugInfo
    std::cout << "#### succeed to load container!" << std::endl;
    std::cout << "#### container pointer addr is " << reinterpret_cast<long>(dlc.get()) << std::endl;
#endif
    // Java now owns the object through the returned handle.
    zdl::DlContainer::IDlContainer * handle = dlc.release();
#ifdef DebugInfo
    std::cout << "#### release container unique_ptr!" << std::endl;
    std::cout << "#### container common pointer addr is " << reinterpret_cast<long>(handle) << std::endl;
#endif
    return reinterpret_cast<long>(handle);
}
 
 
/**
 * JNI entry point for SnpeController.initSnpe(long, int).
 *
 * Builds a zdl::SNPE::SNPE network from the container handle produced by initContainer.
 * It registers udlexample::MyUDLFactory as the user-defined-layer factory, configures
 * and starts the SNPE diagnostic logger, and returns the network to Java as an opaque
 * jlong handle (delete it with Java_SnpeController_releaseSource).
 *
 * @param env       JNI environment of the calling thread.
 * @param arg       the SnpeController instance (unused).
 * @param container handle from initContainer (a raw IDlContainer*). Ownership stays
 *                  with the caller; releaseSource deletes it separately.
 * @param runtime   runtime code translated by _getRuntime (the Java side sends
 *                  0 = CPU, 1 = GPU, 2 = DSP).
 * @return the raw SNPE pointer as a jlong.
 *
 * Every failure (runtime unavailable, build failure, logger setup) is reported through
 * _PrintErrorStringAndExit / PrintErrorStringAndExit.
 */
JNIEXPORT jlong JNICALL Java_SnpeController_initSnpe(JNIEnv * env,
                                                     jobject arg,
                                                     jlong container,
                                                     jint runtime)
{
    auto * dlContainer = reinterpret_cast<zdl::DlContainer::IDlContainer *>(container);
#ifdef DebugInfo
    std::cout << "#### function init_snpe get container pointer: " << reinterpret_cast<long>(dlContainer) << std::endl;
#endif
    const bool runtimeOk = zdl::SNPE::SNPEFactory::isRuntimeAvailable(_getRuntime(runtime));
    if (!runtimeOk) {
        _PrintErrorStringAndExit("runtime is inavailable!!");
    }

    // Hook the user-defined-layer factory into the build.
    zdl::DlSystem::UDLBundle bundle;
    bundle.cookie = reinterpret_cast<void *>(0xdeadbeaf);
    bundle.func = udlexample::MyUDLFactory;

    zdl::SNPE::SNPEBuilder builder(dlContainer);
    builder.setOutputLayers({});
    builder.setRuntimeProcessor(_getRuntime(runtime));
    builder.setUdlBundle(bundle);
    std::unique_ptr<zdl::SNPE::SNPE> network = builder.build();
    if (network == nullptr) {
        PrintErrorStringAndExit();
    }
#ifdef DebugInfo
    std::cout << "####  succeed to init snpe instance!" << std::endl;
#endif

    // Configure and start the diagnostic logger.
    auto diagLoggerOpt = network->getDiagLogInterface();
    if (!diagLoggerOpt) {
        _PrintErrorStringAndExit("SNPE failed to obtain logging interface");
    }
    auto diagLogger = *diagLoggerOpt;
    auto logOptions = diagLogger->getOptions();
    // NOTE(review): fixed path; the loggerDir configured on the Java side never reaches
    // this function (the native signature has no parameter for it).
    logOptions.LogFileDirectory = "/home/mi/snpe_log_dir/";
    if (!diagLogger->setOptions(logOptions)) {
        _PrintErrorStringAndExit("Failed to set logger options!");
    }
#ifdef DebugInfo
    std::cout << "#### succeed to set logger options!!" << std::endl;
#endif
    if (!diagLogger->start()) {
        _PrintErrorStringAndExit("Failed to start snpe logger!");
    }
#ifdef DebugInfo
    std::cout << "#### succeed to start snpe logger!" << std::endl;
#endif

    const std::string versionText = zdl::SNPE::SNPEFactory::getLibraryVersion().toString();
#ifdef DebugInfo
    std::cout << "#### snpe version: " << versionText << std::endl;
#endif

    // Java now owns the network through the returned handle.
    zdl::SNPE::SNPE * handle = network.release();
#ifdef DebugInfo
    std::cout << "#### release snpe unique_ptr!" << std::endl;
    std::cout << "#### snpe common pointer addr is " << reinterpret_cast<long>(handle) << std::endl;
#endif
    return reinterpret_cast<long>(handle);
}
 
 
/**
 * JNI entry point for SnpeController.exec(long, float[], int).
 *
 * Copies `length` floats from the Java array into a tensor shaped like the network's
 * first input, runs inference, and returns the first output tensor as a new float[].
 *
 * @param env    JNI environment of the calling thread.
 * @param arg    the SnpeController instance (unused).
 * @param snpe   handle returned by initSnpe (a raw zdl::SNPE::SNPE*).
 * @param data   input values; must hold at least `length` elements.
 * @param length number of input values; must equal the model's input tensor size.
 * @return a new float[] holding the first output tensor, or nullptr if the JVM could
 *         not allocate the result array.
 *
 * Fatal conditions (bad arguments, size mismatch, SNPE failures) are reported through
 * _PrintErrorStringAndExit / PrintErrorStringAndExit, as in the other entry points.
 */
JNIEXPORT jfloatArray JNICALL Java_SnpeController_exec(JNIEnv * env,
                                                       jobject arg,
                                                       jlong snpe,
                                                       jfloatArray data,
                                                       jint length)
{
    if (data == nullptr || length <= 0)
        _PrintErrorStringAndExit("exec function: data illegal!");
    // Reading `length` elements from a shorter array would be undefined behaviour.
    if (env->GetArrayLength(data) < length)
        _PrintErrorStringAndExit("exec function: data array is shorter than length!");
    auto * tmp_snpe = reinterpret_cast<zdl::SNPE::SNPE *>(snpe);
#ifdef DebugInfo
    std::cout << "#### function exec get snpe pointer: " << reinterpret_cast<long>(tmp_snpe) << std::endl;
#endif
    if (tmp_snpe == nullptr)
        _PrintErrorStringAndExit("exec function: snpe pointer null!");

    // ---- Build the input tensor -------------------------------------------------
    const auto &strList_opt = tmp_snpe->getInputTensorNames();
    if (!strList_opt)
        _PrintErrorStringAndExit("Error obtaining Input tensor names!");
    const auto &strList = *strList_opt;
    if (strList.size() == 0)
        _PrintErrorStringAndExit("exec function: network has no input tensor!");
#ifdef DebugInfo
    std::cout << "#### have " << strList.size() << " input tensor(s) to net." << std::endl;
    for (size_t index = 0; index < strList.size(); ++index)
        std::cout << "#### input tensor " << index << "'s name is " << strList.at(index) << std::endl;
#endif
    // Only the first input tensor is fed; multi-input networks are not handled here.
    const auto &inputDims_opt = tmp_snpe->getInputDimensions(strList.at(0));
    if (!inputDims_opt)
        _PrintErrorStringAndExit("fail to obtain input dimensions!");
    const auto &inputShape = *inputDims_opt;
    // Seed with a size_t: std::accumulate's accumulator takes the type of its init
    // argument, so a plain `1` would funnel the element count through an int.
    const size_t inputSize = std::accumulate(inputShape.getDimensions(),
                                             inputShape.getDimensions() + inputShape.rank(),
                                             static_cast<size_t>(1),
                                             std::multiplies<size_t>());
#ifdef DebugInfo
    std::cout << "#### input tensor size of model is: " << inputSize << std::endl;
#endif
    if (inputSize != static_cast<size_t>(length))
        _PrintErrorStringAndExit("input tensor size of model doesn't match image size!");

    // Copy the region straight out of the Java array. This avoids pinning, needs no
    // matching Release call, and never copies the unmodified data back.
    std::vector<jfloat> inputVec(inputSize);
    env->GetFloatArrayRegion(data, 0, length, inputVec.data());
    std::unique_ptr<zdl::DlSystem::ITensor> input =
        zdl::SNPE::SNPEFactory::getTensorFactory().createTensor(inputShape);
    std::copy(inputVec.begin(), inputVec.end(), input->begin());

    // ---- Run inference ----------------------------------------------------------
    zdl::DlSystem::TensorMap outputTensorMap;
    if (!tmp_snpe->execute(input.get(), outputTensorMap))
        PrintErrorStringAndExit();

    // ---- Return the first output tensor to Java ---------------------------------
    zdl::DlSystem::StringList tensorNames = outputTensorMap.getTensorNames();
#ifdef DebugInfo
    std::cout << "#### have " << tensorNames.size() << " output tensor(s) from net." << std::endl;
#endif
    if (tensorNames.size() == 0)
        _PrintErrorStringAndExit("exec function: net produced no output tensor!");
    zdl::DlSystem::ITensor * tensorPtr = outputTensorMap.getTensor(*tensorNames.begin());
    if (tensorPtr == nullptr)
        _PrintErrorStringAndExit("exec function: output tensor missing!");
#ifdef DebugInfo
    std::cout << "#### output tensor size is " << tensorPtr->getSize() << std::endl;
#endif
    // std::vector owns the staging buffer, so it is freed on every return path.
    std::vector<jfloat> outputVec;
    outputVec.reserve(tensorPtr->getSize());
    for (auto it = tensorPtr->cbegin(); it != tensorPtr->cend(); ++it)
        outputVec.push_back(*it);

    const auto outputLen = static_cast<jsize>(outputVec.size());
    jfloatArray outputResult = env->NewFloatArray(outputLen);
    if (outputResult == nullptr)
        return nullptr;  // allocation failed; the JVM normally has an OutOfMemoryError pending
    env->SetFloatArrayRegion(outputResult, 0, outputLen, outputVec.data());
#ifdef DebugInfo
    std::cout << "#### succeed to exec!!!!" << std::endl;
#endif
    return outputResult;
}
 
 
/**
 * JNI entry point for SnpeController.releaseSource(long, long).
 *
 * Deletes the native objects behind the two handles and skips any handle that is 0.
 *
 * @param env       JNI environment (unused).
 * @param arg       the SnpeController instance (unused).
 * @param container handle from initContainer (raw IDlContainer*), or 0.
 * @param snpe      handle from initSnpe (raw zdl::SNPE::SNPE*), or 0.
 */
JNIEXPORT void JNICALL Java_SnpeController_releaseSource(JNIEnv * env,
                                                         jobject arg,
                                                         jlong container,
                                                         jlong snpe)
{
    // Tear down in reverse order of construction: the network built in initSnpe goes
    // first, then the container it was built from.
    if (snpe != 0) {
        auto * network = reinterpret_cast<zdl::SNPE::SNPE *>(snpe);
        delete network;
    }
    if (container != 0) {
        auto * dlContainer = reinterpret_cast<zdl::DlContainer::IDlContainer *>(container);
        delete dlContainer;
    }
}


猜你喜欢

转载自blog.csdn.net/Song_Esther/article/details/82830913
今日推荐