深度学习移动端在线训练 --- 基于MNN的端侧Finetune实现

  • 在决定使用MNN实现在线训练之前,也比较了TNN/NCNN,发现目前各大端侧推理引擎的训练框架都不成熟,半斤八两的状态,可能都把精力放在推理和op支持上,但是端侧训练的需求真的少么?fine-tune在端侧应用难道不是刚需?
  • 端侧推理的实现相对简单,MNN官方已有完善的文档参考。
  • 这篇主要介绍基于MNN的深度学习端侧在线训练,以预编译动态库的方式引用MNN训练框架,并实现在Android端基于预训练模型的在线finetune全流程。

MNN在线训练思路梳理

  • 在线训练的需求场景(why)

    • 基于预训练模型迁移学习fine-tune(本篇示例场景)
    • 使用MNN进行训练量化
    • 端侧在线学习(如移动端,嵌入式设备)
  • 使用MNN进行在线训练的两种方式(how)

    • 使用MNNConverter将模型转换为可训练模型,端侧加载模型进行训练(本篇示例方式)
    • 端侧利用MNN API从零搭建模型
      • 参考https://www.yuque.com/mnn/cn/kgd9hd
  • 基于MNN实现端侧Finetune所需步骤
    在这里插入图片描述

预训练模型转换成可训练的模型

  • 不同深度学习框架输出的预训练模型需要先转换为mnn格式
    • 参考我的另一篇https://blog.csdn.net/hechao3225/article/details/114820905
  • 将mnn格式的预训练模型转换为可训练模型
    • 开启–forTraining保留BatchNorm,Dropout等训练过程中会用到的算子
    • 如果你的模型中没有BN,Dropout等在转MNN推理模型时会被融合掉的算子,那么直接使用MNN推理模型也可以进行训练,不必重新进行转换
./MNNConvert --modelFile mobilenet_v2_1.0_224_frozen.pb  --MNNModel mobilenet_v2_tfpb_train.mnn --framework TF --bizCode AliNNTest --forTraining

Android平台编译库

  • 在线训练相关编译开关设置

    • 根目录下CMakeList.txt编译开关
      • 打开MNN_BUILD_TRAIN=ON
    • tools/train下CMakeList.txt编译开关
      • MNN_BUILD_TRAIN_MINI=ON : 对于移动端/嵌入式设备,建议设置 MNN_BUILD_TRAIN_MINI = ON,不编译内置的Dataset,Models
      • MNN_USE_OPENCV=OFF : 部分 PC 上的 demo 有用到,暂时用不到
  • 具体步骤

    1. 配置NDK环境变量,指定NDK版本:在 .bashrc 或者 .bash_profile 中设置 NDK 环境变量,

      export ANDROID_NDK=/home/goodix/code/fp_prebuilts/tools/android-ndk-r17b
      
    2. mnn源码根目录执行

      ./schema/generate.sh
      
    3. 进入android目录

      cd project/android
      
    4. 编译armv7动态库

      mkdir build_32 && cd build_32 && ../build_32.sh
      
    5. 编译armv8动态库:

      mkdir build_64 && cd build_64 && ../build_64.sh
      
  • Android平台在线训练依赖的三个库,如果仅作在线推理,只需要引用libMNN.so,如果需要在线训练,还需要libMNN_Express.so和libMNNTrain.so。

    路径如下:

    • project/android/build_64/libMNN.so
    • project/android/build_64/libMNN_Express.so
    • project/android/build_64/tools/train/libMNNTrain.so
  • 官方文档参考:https://www.yuque.com/mnn/cn/build_android

在线迁移学习Finetune Demo实现

  1. 实现微调模型MyAnnTransferModule

    • 对于finetune场景,我们不需要端侧从零搭建模型,只需要加载预训练模型,固定神经网络前面层的参数,仅对全连接层最后一层用于微调

    • 需要继承Module类,并重写构造函数和onForward

      class MyAnnTransferModule : public MNN::Express::Module {
      public:
          AlsAnnTransferModule(const char* fileName);
          virtual std::vector<MNN::Express::VARP> onForward(const std::vector<MNN::Express::VARP>& inputs) override;
          std::shared_ptr<Module> mFixedLayers;
          std::shared_ptr<Module> mFineTuneLayers; // add new layers for finetuning
      };
      
      class MyAnnTransfer{
      public:
          static void train(std::shared_ptr<MNN::Express::Module> model, 
                            std::vector<OneAlsData> trainData, 
                            std::vector<OneAlsData> testData);
      };
      
    • **[关键]**如何固定神经网络部分参数,仅对最后一层或最后几层微调?

      • 通过netron模型可视化工具(或MNNConvert工具输出的模型json文件)查看最后一层的input.name

        • netron.app在线地址:https://netron.app/

        • 示例,open module后点击最后一个Convolution,右边Input name即为最后一层的输入,我们需要通过整个name对模型分界

        在这里插入图片描述

      • 模型构造函数中loadMap加载模型,并使用input.name作为PipelineModule::extract的第二个参数,extract会保留除去最后一层的预训练模型

      • extract保留的部分为固定参数层,新初始化一个layer作为finetune层,仅注册finetune层用于训练,构造函数示例代码:

        MyAnnTransferModule::MyAnnTransferModule(const char* fileName) {
            auto srcModelMap  = Variable::loadMap(fileName);
            auto inputOutputs = Variable::getInputAndOutput(srcModelMap);
            auto input   = inputOutputs.first.begin()->second;
            auto fixedOut = srcModelMap["Reshape45"];
        
            // init a dense layer for finetuning
            // mFineTuneLayers.reset(NN::Linear(10, 5));
            // use a conv layer for a dense layer
            NN::ConvOption option;
            option.channel = {10, 5};
            mFineTuneLayers = std::shared_ptr<Module>(NN::Conv(option));
        
            // get fixed layers from src module, set trainFlag=false
            mFixedLayers.reset(PipelineModule::extract({input}, {fixedOut}, false));
        
            // only train finetuning layers
            registerModel({mFineTuneLayers});
        }
        
      • 然后重写onForward,固定层和微调层先后前向计算,需要注意微调层前向计算后需要进行_Convert和_Reshape操作,Reshape的维度信息仍然可以通过netron可视化工具查看,onForward函数示例代码:

        std::vector<VARP> MyAnnTransferModule::onForward(const std::vector<VARP>& inputs) {
            auto fixedResult = mFixedLayers->forward(inputs[0]);
            auto result = _Reshape(_Convert(mFineTuneLayers->forward(fixedResult), NCHW), {-1, 5});
            return {result};
        }
        
  2. 实现train和test需要的dataset,由dataset创建dataloader

    • 需要继承Dataset类,重写get()和size()两个虚函数,dataset与项目自己的数据格式有关,可参考官方MnistDataset

    • train函数中创建用于train和test的dataloader

      auto dataset = MnistDataset::create(trainData, AlsDataset::Mode::TRAIN);
      const size_t batchSize  = 1; 
      const size_t numWorkers = 0;
      bool shuffle            = true;
      auto dataLoader = std::shared_ptr<DataLoader>(dataset.createLoader(batchSize, true, shuffle, numWorkers));
      size_t iterations = dataLoader->iterNumber();
      
      auto testDataset = MnistDataset::create(testData, AlsDataset::Mode::TEST);
      const size_t testBatchSize  = 1;
      const size_t testNumWorkers = 0;
      shuffle = false;
      
    • 实现train过程(参考MnistUtils.cpp)

      • 示例代码使用sgd优化器,学习率衰减等训练策略,可以根据自己需求使用MNN不同 API测试效果
      • #ifdef DEBUG_GRAD宏包含了梯度校验的代码
          for (int epoch = 0; epoch < 10; ++epoch) {
              model->clearCache();
              exe->gc(Executor::FULL);
              exe->resetProfile();
              {
                  AUTOTIME;
                  dataLoader->reset();
                  model->setIsTraining(true);
                  Timer _100Time;
                  int lastIndex = 0;
                  int moveBatchSize = 0;
                  for (int i = 0; i < iterations; i++) {
                      auto trainData  = dataLoader->next();
                      auto example    = trainData[0];
                      moveBatchSize += example.first[0]->getInfo()->dim[0];
                      auto predict = model->forward(example.first[0]);
                      auto loss    = _MSE(predict, example.second[0]);
      //#define DEBUG_GRAD
      #ifdef DEBUG_GRAD
                      {
                          static bool init = false;
                          if (!init) {
                              init = true;
                              std::set<VARP> para;
                              example.first[0].fix(VARP::INPUT);
                              newTarget.fix(VARP::CONSTANT);
                              auto total = model->parameters();
                              for (auto p :total) {
                                  para.insert(p);
                              }
                              auto grad = OpGrad::grad(loss, para);
                              total.clear();
                              for (auto iter : grad) {
                                  total.emplace_back(iter.second);
                              }
                              Variable::save(total, ".temp.grad");
                          }
                      }
      #endif
                      float rate   = LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75);
                      sgd->setLearningRate(rate);
                      if (moveBatchSize % (10 * batchSize) == 0 || i == iterations - 1) 
                      {
                          std::cout << "epoch= " << (epoch) << std::endl;
                          std::cout << moveBatchSize << " / " << dataLoader->size() << std::endl;
                          std::cout << " lr= " << rate;
                          std::cout << " time= " << (float)_100Time.durationInUs() / 1000.0f << " ms / " << (i - lastIndex) <<  " iter"  << std::endl;
                          std::cout.flush();
                          _100Time.reset();
                          lastIndex = i;
                      }
                      sgd->step(loss);
                  }
              }
          }
      
    • 保存训练模型

      • 1和2的区别在于前者只保存参数,后者保存模型参数和结构

        // 1. only save model parameters
        Variable::save(model->parameters(), "alsann.snapshot.mnn");
        // 2. save model parameters and structure
        {
        	model->setIsTraining(false);
            auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
            forwardInput->setName("data");
            auto predict = model->forward(forwardInput);
            predict->setName("prob");
            Transformer::turnModelToInfer()->onExecute({predict});
            Variable::save({predict}, "alsann.mnn");
        }
        
      • 如果需要train前加载模型参数:

        // Load snapshot
        auto para = Variable::load("alsann.snapshot.mnn");
        model->loadParameters(para);
        
    • 注:MNN里已经实现了几个image相关的Dataset示例,如ImageDataset,MnistDataset

Finetue Demo编译配置

  • 在main.cpp中创建模型和调用train函数,使用NDK编译成一个独立的二进制程序作为测试demo

    • trainData和testData可以通过参数指定路径的方式传入
    std::shared_ptr<Module> model(new MyAnnTransferModule(argv[1]));
    MyAnnTransferModule::train(model, trainData, testData);
    
  • 我们将MNN编译的三个so以预编译共享库的方式加入Android.mk

    • 定义三个库,MNN,MNN_express和MNN_train

      LOCAL_PATH := $(call my-dir)
      
      include $(CLEAR_VARS)
      LOCAL_MODULE := MNN
      LOCAL_SRC_FILES := ${LOCAL_PATH}/../lib/lib64/libMNN.so
      include $(PREBUILT_SHARED_LIBRARY)
      
      include $(CLEAR_VARS)
      LOCAL_MODULE := MNN_express
      LOCAL_SRC_FILES := ${LOCAL_PATH}/../lib/lib64/libMNN_Express.so
      include $(PREBUILT_SHARED_LIBRARY)
      
      include $(CLEAR_VARS)
      LOCAL_MODULE := MNN_train
      LOCAL_SRC_FILES := ${LOCAL_PATH}/../lib/lib64/libMNNTrain.so
      include $(PREBUILT_SHARED_LIBRARY)
      
    • 引用三个共享库

      LOCAL_SHARED_LIBRARIES := MNN MNN_express MNN_train
      
  • Android.mk添加新增的头文件和源文件

    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/core
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/expr
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/plugin
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/express
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/express/module
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/data
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/grad
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/optimizer
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/parameters
    LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/transformer
    
    CPP_LIST := $(wildcard $(LOCAL_PATH)/*.cpp)
    LOCAL_SRC_FILES := $(CPP_LIST:$(LOCAL_PATH)/%=%)
    
  • 指定NDK版本编译

    /home/goodix/code/fp_prebuilts/tools/android-ndk-r17b/ndk-build   NDK_APPLICATION_MK=Application.mk  APP_BUILD_SCRIPT=Android.mk   APP_ABI=arm64-v8a  NDK_PROJECT_PATH=./   -B
    
  • 编译完成后将模型和demo二进制push到手机测试,push脚本示例

    set -e
    adb root
    adb remount
    adb shell setenforce 0
    # adb shell rm -rf /usr/bin/arm64-v8a
    adb push  obj/local/arm64-v8a/           /usr/bin/
    adb push  /home/goodix/code/als/Hamilton_DL/test_code/mnn_demo/model/my_ann.mnn               /usr/bin/arm64-v8a/my_ann.mnn
    adb shell chmod +x /usr/bin/arm64-v8a/ALS_ANN
    adb shell export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/bin/arm64-v8a:/vendor/lib64:/system/lib64:/system/lib64/vndk-29:/system/lib
    

    遇到的问题

  • 问题1:编译link阶段找不到智能指针相关的定义

    undefined reference to `std::__1::__shared_weak_count::__release_weak()'
    
    • Application.mk已指定C++11以上版本

      APP_CPPFLAGS := -frtti -fexceptions -std=c++14
      
    • 原因是STL的版本必须指定为gnustl_shared,不能是成c++_static,c++_shared, stlport_shared等,这些版本本身不支持shared_ptr和function相关特性。

      在gnustl版本中,shared_ptr定义在NDK根目录\sources\cxx-stl\gnu-libstdc++\4.8\include\memory文件中。

      Application.mk修改

      APP_STL := gnustl_shared
      
    • 但是,ndk-r19c等高版本的STL版本不支持gnustl_shared,还需要切换ndk-17b版本编译

  • 问题2:libMNN_Express.so链接不到函数定义

    • 解决方法:对齐mnn编译脚本和demo编译脚本的stl版本,ndk版本

猜你喜欢

转载自blog.csdn.net/hechao3225/article/details/115731833