Understanding the MXNet code: c_predict_api


c_predict_api

  _CreateExecutor connects a predictor handle to the execution engine by binding its executor:

inline void _CreateExecutor(PredictorHandle pred_hnd) 
{
  MXAPIPredictor *pred = static_cast<MXAPIPredictor*>(pred_hnd);
  if (pred->exec == nullptr) 
  {
    auto sym = pred->sym;
    auto ctx = pred->ctx;
    auto key2arg = pred->key2arg;
    auto arg_arrays = pred->arg_arrays;
    auto aux_arrays = pred->aux_arrays;
    // remaining members: out_arrays, out_shapes, out_shapes_buffer
    std::map<std::string, Context> ctx_map;
    std::vector<NDArray> grad_store(arg_arrays.size());
    std::vector<OpReqType> grad_req(arg_arrays.size(), kNullOp);
    pred->exec.reset(Executor::Bind(sym, ctx, ctx_map, arg_arrays, grad_store, grad_req, aux_arrays));  // drop any previous executor and take ownership of the Bind result
    pred->out_arrays = pred->exec->outputs();
  }
}

  A note on static_cast
  static_cast is C++'s compile-time cast operator; here it converts the void* handle pred_hnd back into an MXAPIPredictor pointer.
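
  As a minimal illustration of that round trip (the struct below is a stand-in for the real definition):

#include <cassert>

struct MXAPIPredictor { int dummy; };  // stand-in for the real predictor state
typedef void* PredictorHandle;         // the opaque handle exposed by the C API

int main() {
  MXAPIPredictor pred;
  PredictorHandle h = &pred;                               // implicit T* -> void*
  MXAPIPredictor* back = static_cast<MXAPIPredictor*>(h);  // explicit void* -> T*
  assert(back == &pred);                                   // same object recovered
  return 0;
}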
  _CreatePartialOut creates the predictor handle(s):

int _CreatePartialOut(const char* symbol_json_str,
                      const void* param_bytes,
                      int param_size,
                      int dev_type, int dev_id,
                      mx_uint num_input_nodes,
                      const char** input_keys,
                      const mx_uint* input_shape_indptr,
                      const mx_uint* input_shape_data,
                      mx_uint num_output_nodes,
                      const char** output_keys,
                      // This is used for parallel inference.
                      int num_threads,
                      bool lazy,
                      PredictorHandle* out) 
{
  using nnvm::Symbol;
  API_BEGIN();  // begin exception and status handling
  Symbol sym;  
  {
    mx_uint outSize;
    const char **outArray;
    MXListAllOpNames(&outSize, &outArray);
  }  // forces all MXNet operators to be registered into nnvm
  
  {
    nnvm::Graph g;
    g.attrs["json"] = std::make_shared<nnvm::any>(std::string(symbol_json_str));
    sym.outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
  }  // load the JSON string, which contains the network definition
  
  if (num_output_nodes != 0) {
    Symbol internal = sym.GetInternals();
    std::vector<std::string> all_out = internal.ListOutputNames();
    std::vector<Symbol> out_syms(num_output_nodes);
    for (mx_uint i = 0; i < num_output_nodes; ++i) {
      std::string out_key(output_keys[i]);
      out_key += "_output";
      for (size_t j = 0; j < all_out.size(); ++j) {
        if (all_out[j] == out_key) {
          out_syms[i] = internal[j];
          break;
        }
        CHECK_NE(j, all_out.size() - 1) << "didn't find node name: " << out_key;
      }
    }
    sym = nnvm::Symbol::CreateGroup(out_syms);
  }  // pick the requested outputs, checking every node name exists

  std::unordered_map<std::string, NDArray> arg_params, aux_params;
  {
    std::unordered_set<std::string> arg_names, aux_names;
    std::vector<std::string> arg_names_vec = sym.ListInputNames(Symbol::kReadOnlyArgs);
    std::vector<std::string> aux_names_vec = sym.ListInputNames(Symbol::kAuxiliaryStates);
    for (size_t i = 0; i < arg_names_vec.size(); ++i) {
      arg_names.insert(arg_names_vec[i]);
    }
    for (size_t i = 0; i < aux_names_vec.size(); ++i) {
      aux_names.insert(aux_names_vec[i]);
    }
    std::vector<NDArray> data;
    std::vector<std::string> names;
    dmlc::MemoryFixedSizeStream fi((void*)param_bytes, param_size);  // NOLINT(*)
    NDArray::Load(&fi, &data, &names);
    CHECK_EQ(names.size(), data.size())
        << "Invalid param file format";
    for (size_t i = 0; i < names.size(); ++i) {
      if (!strncmp(names[i].c_str(), "aux:", 4)) {
        std::string name(names[i].c_str() + 4);
        if (aux_names.count(name) != 0) {
          aux_params[name] = data[i];
        }
      }
      if (!strncmp(names[i].c_str(), "arg:", 4)) {
        std::string name(names[i].c_str() + 4);
        if (arg_names.count(name) != 0) {
          arg_params[name] = data[i];
        }
      }
    }
  }  // match the names stored in the param blob (prefixed "arg:"/"aux:") against the symbol's inputs and load them

  std::unordered_map<std::string, TShape> known_shape;
  for (mx_uint i = 0; i < num_input_nodes; ++i) {
    known_shape[std::string(input_keys[i])] =
        TShape(input_shape_data + input_shape_indptr[i],
               input_shape_data + input_shape_indptr[i + 1]);
  }  // record the caller-supplied shape of each input node
  std::vector<std::string> arg_names = sym.ListInputNames(Symbol::kReadOnlyArgs);
  std::vector<std::string> aux_names = sym.ListInputNames(Symbol::kAuxiliaryStates);
  std::vector<TShape> out_shapes(sym.ListOutputNames().size());
  std::vector<TShape> aux_shapes(aux_names.size());
  std::vector<TShape> arg_shapes;
  std::unordered_map<std::string, size_t> key2arg;
  for (size_t i = 0; i < arg_names.size(); ++i) {
    std::string key = arg_names[i];
    key2arg[key] = i;
  }  // map each argument name to its index

  try {
    std::vector<TShape> in_shapes;
    for (std::string key : sym.ListInputNames(Symbol::kAll)) {
      if (known_shape.count(key) != 0) {
        in_shapes.push_back(known_shape[key]);
      } else {
        in_shapes.emplace_back();
      }
    }
    nnvm::Graph g; g.outputs = sym.outputs;
    g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
    bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
    CHECK(infer_complete) << "The shape information is not enough to get the shapes";
    CopyAttr(g.indexed_graph(),
             g.GetAttr<nnvm::ShapeVector>("shape"),
             &arg_shapes, &out_shapes, &aux_shapes);
  } catch (const mxnet::op::InferShapeError &err) {
    throw dmlc::Error(err.msg);
  }  // shape inference must leave no node with an unknown shape

  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);  // create the device context

  std::vector<NDArray> arg_arrays, aux_arrays;
  for (size_t i = 0; i < arg_shapes.size(); ++i) {
    NDArray nd = NDArray(arg_shapes[i], ctx);
    if (arg_params.count(arg_names[i]) != 0) {
      CopyFromTo(arg_params[arg_names[i]], &nd);
    }
    arg_arrays.push_back(nd);
  }
  for (size_t i = 0; i < aux_shapes.size(); ++i) {
    NDArray nd = NDArray(aux_shapes[i], ctx);
    if (aux_params.count(aux_names[i]) != 0) {
      CopyFromTo(aux_params[aux_names[i]], &nd);
    }
    aux_arrays.push_back(nd);
  }  // copy the parameters read from the param blob into NDArrays on the target device
  
  for (int i = 0; i < num_threads; i++) {
    std::unique_ptr<MXAPIPredictor> ret(new MXAPIPredictor());
    ret->sym = sym;
    ret->ctx = ctx;
    ret->key2arg = key2arg;
    ret->arg_arrays = arg_arrays;
    ret->aux_arrays = aux_arrays;
    ret->out_shapes = out_shapes;
    if (!lazy) {
      std::map<std::string, Context> ctx_map;
      std::vector<NDArray> grad_store(arg_arrays.size());
      std::vector<OpReqType> grad_req(arg_arrays.size(), kNullOp);
      ret->exec.reset(Executor::Bind(sym, ctx, ctx_map,
                                     arg_arrays,
                                     grad_store, grad_req,
                                     aux_arrays));  // bind an executor and give ret->exec ownership of it
      ret->out_arrays = ret->exec->outputs();
    }  // bind eagerly unless lazy mode was requested
    out[i] = ret.release();  // transfer ownership to out[i]
  }
  API_END_HANDLE_ERROR();  // end exception handling
}

  std::move avoids copying the shape vectors when they are handed to InferShape, and unique_ptr's release and reset manage the predictor's and executor's memory.
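
  A minimal sketch of those two unique_ptr operations (the Widget type is hypothetical):

#include <memory>

struct Widget { int v; };

int main() {
  std::unique_ptr<Widget> p(new Widget{1});
  p.reset(new Widget{2});    // deletes the old Widget, then owns the new one
  Widget* raw = p.release(); // gives up ownership without deleting; p is now null
  delete raw;                // the caller is responsible for the object from here on
  return 0;
}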

  MXPredCreatePartialOut creates a predictor handle for which MXNet itself manages any multithreading; the executor is bound immediately:

int MXPredCreatePartialOut(const char* symbol_json_str,
                           const void* param_bytes,
                           int param_size,
                           int dev_type, int dev_id,
                           mx_uint num_input_nodes,
                           const char** input_keys,
                           const mx_uint* input_shape_indptr,
                           const mx_uint* input_shape_data,
                           mx_uint num_output_nodes,
                           const char** output_keys,
                           PredictorHandle* out) {
  return _CreatePartialOut(
      symbol_json_str,
      param_bytes,
      param_size,
      dev_type, dev_id,
      num_input_nodes,
      input_keys,
      input_shape_indptr,
      input_shape_data,
      num_output_nodes,
      output_keys,
      1,
      false,  // bind the executor at creation time
      out);
}
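
  The input_shape_indptr / input_shape_data pair is a CSR-style encoding: the dims of all inputs are flattened into shape_data, and indptr[i]..indptr[i+1] delimits input i. A hedged usage sketch for a single NCHW input named "data" of shape (1, 3, 224, 224); loading the symbol JSON and param bytes is assumed to be done by the caller:

#include <mxnet/c_predict_api.h>

// json_str / param_bytes / param_len are assumed to hold the exported model
// (e.g. read from the *-symbol.json and *-0000.params files).
PredictorHandle create_predictor(const char* json_str,
                                 const void* param_bytes, int param_len) {
  const char* input_key[1] = {"data"};
  const mx_uint indptr[2] = {0, 4};                // input 0 occupies dims [0, 4)
  const mx_uint shape_data[4] = {1, 3, 224, 224};  // one NCHW image
  PredictorHandle pred = nullptr;
  MXPredCreatePartialOut(json_str, param_bytes, param_len,
                         1 /*cpu*/, 0 /*dev_id*/,
                         1, input_key, indptr, shape_data,
                         0, nullptr,               // 0 output keys: use all outputs
                         &pred);
  return pred;
}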

  Its counterpart creates handles that MXNet uses in single-threaded mode, so they can be run in several of the caller's own threads.

int MXPredCreateMultiThread(const char* symbol_json_str,
                            const void* param_bytes,
                            int param_size,
                            int dev_type, int dev_id,
                            mx_uint num_input_nodes,
                            const char** input_keys,
                            const mx_uint* input_shape_indptr,
                            const mx_uint* input_shape_data,
                            mx_uint num_output_nodes,
                            const char** output_keys,
                            // This is used for parallel inference.
                            int num_threads,
                            PredictorHandle* out) {
  return _CreatePartialOut(
      symbol_json_str,
      param_bytes,
      param_size,
      dev_type,
      dev_id,
      num_input_nodes,
      input_keys,
      input_shape_indptr,
      input_shape_data,
      num_output_nodes,
      output_keys,
      num_threads,
      true,  // do not bind at creation; bind on the first forward call
      out);
}
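
  A sketch of how the multi-threaded variant might be driven, one handle per caller thread; because lazy is true, each thread's first MXPredForward triggers _CreateExecutor on its own handle (input setup is elided):

#include <mxnet/c_predict_api.h>
#include <thread>
#include <vector>

void run_threads(const char* json_str, const void* params, int param_len,
                 const char** input_keys, const mx_uint* indptr,
                 const mx_uint* shapes, int num_threads) {
  std::vector<PredictorHandle> handles(num_threads);
  MXPredCreateMultiThread(json_str, params, param_len, 1 /*cpu*/, 0,
                          1, input_keys, indptr, shapes,
                          0, nullptr, num_threads, handles.data());
  std::vector<std::thread> workers;
  for (int i = 0; i < num_threads; ++i) {
    workers.emplace_back([h = handles[i]]() {
      // ... MXPredSetInput(h, "data", ...) with this thread's input ...
      MXPredForward(h);  // first call binds the executor (lazy == true)
      // ... MXPredGetOutput(h, 0, ...) ...
      MXPredFree(h);
    });
  }
  for (auto& w : workers) w.join();
}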

  MXPredReshape adjusts the input shapes of an existing handle:

int MXPredReshape(mx_uint num_input_nodes,
                  const char** input_keys,
                  const mx_uint* input_shape_indptr,
                  const mx_uint* input_shape_data,
                  PredictorHandle handle,
                  PredictorHandle* out) {
  _CreateExecutor(handle);  // bind first, so the handle's executor is valid and can be used and compared below
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  std::unique_ptr<MXAPIPredictor> ret(new MXAPIPredictor());

  API_BEGIN();
  // shape inference
  std::unordered_map<std::string, TShape> new_shape;
  for (mx_uint i = 0; i < num_input_nodes; ++i) {
    new_shape[std::string(input_keys[i])] =
        TShape(input_shape_data + input_shape_indptr[i],
            input_shape_data + input_shape_indptr[i + 1]);
  }  // collect the new input shapes
  ret->sym = p->sym;  // the grouped symbol built earlier via nnvm::Symbol::CreateGroup(out_syms)
  std::vector<std::string> arg_names = ret->sym.ListInputNames(Symbol::kReadOnlyArgs);
  std::vector<std::string> aux_names = ret->sym.ListInputNames(Symbol::kAuxiliaryStates);
  std::vector<TShape> out_shapes(ret->sym.ListOutputNames().size());
  std::vector<TShape> aux_shapes(aux_names.size());
  std::vector<TShape> arg_shapes;
  ret->key2arg = p->key2arg;  // reuse the original model's argument mapping

  try {
    std::vector<TShape> in_shapes;
    in_shapes.reserve(arg_names.size());
    for (std::string key : ret->sym.ListInputNames(Symbol::kAll)) {
      if (new_shape.count(key) != 0) {
        in_shapes.push_back(new_shape[key]);
      } else {
        in_shapes.emplace_back();
      }
    }
    nnvm::Graph g; g.outputs = ret->sym.outputs;
    g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
    bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
    CHECK(infer_complete) << "The shape information is not enough to get the shapes";
    CopyAttr(g.indexed_graph(),
             g.GetAttr<nnvm::ShapeVector>("shape"),
             &arg_shapes, &out_shapes, &aux_shapes);
  } catch (const mxnet::op::InferShapeError &err) {
    throw dmlc::Error(err.msg);
  }

  ret->arg_arrays = p->arg_arrays;
  ret->ctx = p->ctx;
  for (size_t i=0; i < arg_names.size(); ++i) {
    TShape newShape = arg_shapes[i];
    NDArray &arr = p->arg_arrays[i];
    if (new_shape.count(arg_names[i]) != 0) {
      ret->arg_arrays[i].ReshapeAndAlloc(newShape);
    } else {
       CHECK_EQ(newShape.Size(), arr.shape().Size())
        << "arg " << arg_names[i]
        << " shape has been changed, only allow to change the shape of input data.";
    }
  }

  for (size_t i=0; i < aux_names.size(); ++i) {
    TShape newShape = aux_shapes[i];
    NDArray &arr = p->aux_arrays[i];
    CHECK_EQ(newShape.Size(), arr.shape().Size())
      << "aux " << aux_names[i]
      << " shape has been changed, only allow to change the shape of input data.";
  }
  ret->aux_arrays = p->aux_arrays;

  // bind; the new executor shares state with p->exec, so Reshape must run in the worker thread that owns the handle
  {
    std::map<std::string, Context> ctx_map;
    std::vector<NDArray> grad_store;
    grad_store.reserve(ret->arg_arrays.size());
    std::vector<OpReqType> grad_req(ret->arg_arrays.size(), kNullOp);

    ret->exec.reset(Executor::Bind(ret->sym, ret->ctx, ctx_map,
                                   ret->arg_arrays,
                                   grad_store, grad_req,
                                   ret->aux_arrays,
                                   p->exec.get()));
    ret->out_shapes = out_shapes;
    ret->out_arrays = ret->exec->outputs();
  }
  *out = ret.release();
  API_END();
}
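
  A hedged sketch of reshaping an existing handle, e.g. growing the batch dimension; note from the Bind call above that the new executor shares state with p->exec, so the original handle must stay alive while the resized one is in use:

#include <mxnet/c_predict_api.h>

// pred is an existing, valid handle whose input "data" was (1, 3, 224, 224).
PredictorHandle reshape_to_batch8(PredictorHandle pred) {
  const char* input_key[1] = {"data"};
  const mx_uint indptr[2] = {0, 4};
  const mx_uint new_shape[4] = {8, 3, 224, 224};  // only input shapes may change
  PredictorHandle resized = nullptr;
  MXPredReshape(1, input_key, indptr, new_shape, pred, &resized);
  // keep `pred` alive while `resized` is in use; free both when done
  return resized;
}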

  MXPredGetOutputShape reports the shape of an output:

int MXPredGetOutputShape(PredictorHandle handle,
                         mx_uint out_index,
                         mx_uint** shape_data,
                         mx_uint* shape_ndim) {
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  CHECK_LT(out_index, p->out_arrays.size()) << "Index exceed number of outputs";
  const TShape& s = p->out_shapes[out_index];  // copy the dims from this TShape into a buffer the caller can read
  p->out_shapes_buffer.resize(s.ndim());
  nnvm::ShapeTypeCast(s.begin(), s.end(), p->out_shapes_buffer.data());
  *shape_data = p->out_shapes_buffer.data();
  *shape_ndim = p->out_shapes[out_index].ndim();
  API_END();
}

  MXPredSetInput loads the input data:

int MXPredSetInput(PredictorHandle handle,
                   const char* key,
                   const mx_float* data,
                   mx_uint size) {
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  auto it = p->key2arg.find(key);
  if (it == p->key2arg.end()) {
    LOG(FATAL) << "cannot find input key " << key;
  }
  NDArray& nd = p->arg_arrays[it->second];
  nd.SyncCopyFromCPU(data, size);  // SyncCopyFromCPU calls WaitToWrite; to be analyzed in detail later
  API_END();
}

  MXPredForward runs the full forward pass, while MXPredPartialForward runs it a given number of steps at a time; MXPredForward is the one normally used:

int MXPredPartialForward(PredictorHandle handle, int step, int* step_left) {
  _CreateExecutor(handle);
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  p->exec->PartialForward(false, step, step_left);  // executes RunOps in graph_executor; to be analyzed in detail later
  API_END();
}
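
  A sketch of driving the stepped variant; assuming step_left reports how many forward nodes remain after running up to step (as the parameter name suggests), the loop advances one node at a time:

// pred is a valid, bound PredictorHandle.
void forward_stepwise(PredictorHandle pred) {
  int step_left = 1;
  for (int step = 1; step_left > 0; ++step) {
    MXPredPartialForward(pred, step, &step_left);  // run up to node `step`
    // intermediate results could be inspected here via MXPredGetOutput
  }
}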

  MXPredGetOutput retrieves an output:

int MXPredGetOutput(PredictorHandle handle,
                    mx_uint index,
                    mx_float* data,
                    mx_uint size) {
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  CHECK_LT(index, p->out_arrays.size()) << "Output index out of range";
  const NDArray& nd = p->out_arrays[index];
  nd.SyncCopyToCPU(data, size);  // SyncCopyToCPU calls WaitToRead; to be analyzed in detail later
  API_END();
}
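
  Putting the pieces together, a minimal end-to-end inference sketch (the predictor is assumed to have been created as above, with a CHW float buffer as input):

#include <mxnet/c_predict_api.h>
#include <vector>

std::vector<float> predict(PredictorHandle pred, const std::vector<float>& image) {
  MXPredSetInput(pred, "data", image.data(),
                 static_cast<mx_uint>(image.size()));
  MXPredForward(pred);  // full forward pass

  mx_uint* shape = nullptr;
  mx_uint ndim = 0;
  MXPredGetOutputShape(pred, 0, &shape, &ndim);  // query output 0's shape
  mx_uint size = 1;
  for (mx_uint i = 0; i < ndim; ++i) size *= shape[i];

  std::vector<float> out(size);
  MXPredGetOutput(pred, 0, out.data(), size);  // blocks until the result is ready
  return out;
}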

  MXPredFree frees the predictor's memory once the model is no longer needed:

int MXPredFree(PredictorHandle handle) {
  API_BEGIN();
  delete static_cast<MXAPIPredictor*>(handle);
  API_END();
}