深度学习编译中间件之NNVM(十六)NNVM源代码阅读5

参考文档

  1. 深度学习编译中间件之NNVM(十二)NNVM源代码阅读1
  2. 深度学习编译中间件之NNVM(十三)NNVM源代码阅读2
  3. 深度学习编译中间件之NNVM(十四)NNVM源代码阅读3
  4. 深度学习编译中间件之NNVM(十五)NNVM源代码阅读4

NNVM Compiler组件是NNVM中与使用者距离最近(直接面向用户调用)的组件,本篇文档将详细阅读其相关代码。

NNVM Compiler组件中比较重要的函数是nnvm.compiler.build

可以将nnvm.compiler.build的执行过程总结为如下步骤:

  1. 校正Layout
  2. 初始化Pass(指定输入shape,推断各节点的shape与dtype)
  3. 初始化所有变量(_all_var_init)
  4. 应用优化
  5. 预计算裁剪
  6. 融合相邻运算并生成最终so
  7. 保存变量的初始化值到params参数文件中

参考文档1为了快速了解NNVM和TVM是如何交互的,只讲解了步骤6,本文档将介绍所有步骤。

1.校正Layout

python/nnvm/compiler/build_module.py build函数

    # Correct the layouts if needed.
    layout = layout if layout else {} 
    # Record the user-specified input layouts on the graph, then run the
    # "CorrectLayout" pass, which calls each operator's FCorrectLayout.
    graph = graph_attr.set_layout_inputs(graph, layout)
    graph = graph.apply("CorrectLayout")
    index = graph.index
    layouts = graph.json_attr("layout")
    # Map every graph input name to its (corrected) layout.
    layout = {x : layouts[index.entry_id(x)] for x in index.input_names}

graph.apply在之前的参考文档1中已经讲解了,对于"CorrectLayout"这个Pass而言会调用每个操作符的FCorrectLayout,操作符的FCorrectLayout函数是由参考文档3里面讲解的NNVM Top组件C++部分定义的。

下面从一个比较简单的操作符max_pool2d入手理解FCorrectLayout的功能。

src/top/nn/pooling.cc

// Excerpt: registers Pool2DCorrectLayout as max_pool2d's FCorrectLayout,
// which the "CorrectLayout" graph pass calls for this operator.
NNVM_REGISTER_OP(max_pool2d)
.set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout)

/*!
 * \brief FCorrectLayout function of the 2-D pooling operators (e.g. max_pool2d).
 *
 * - If the input layout is undefined, the operator's `layout` parameter is used.
 * - If it is defined, it must be convertible to that layout. It is kept as is
 *   when the 'W' and 'H' axes sit at the same positions as in the operator's
 *   layout and it contains no lower-case 'w' / 'h' (presumably split sub-axes);
 *   otherwise it is replaced by the operator's layout.
 * - The resulting layout is assigned to both the input and the output.
 *
 * \return Always true.
 */
inline bool Pool2DCorrectLayout(const NodeAttrs& attrs,
                                std::vector<Layout> *ilayouts,
                                const std::vector<Layout> *last_ilayouts,
                                std::vector<Layout> *olayouts) {
    const Pool2DParam &param = nnvm::get<Pool2DParam>(attrs.parsed);
    // size() returns size_t: compare with unsigned literals to avoid a
    // signed/unsigned comparison inside CHECK_EQ.
    CHECK_EQ(ilayouts->size(), 1U);
    CHECK_EQ(last_ilayouts->size(), 1U);
    CHECK_EQ(olayouts->size(), 1U);

    Layout input = (*ilayouts)[0];
    // Layout requested through the operator's `layout` parameter.
    const Layout layout(param.layout);

    if (input.defined()) {
        CHECK(input.convertible(layout)) << "Invalid input layout " << input;
        // As long as the indices of width and height don't change (and the input
        // layout has no 'w' / 'h' sub-axes), pool2d can keep the input layout.
        // Otherwise fall back to the operator's layout.
        if (input.indexof('W') != layout.indexof('W') ||
            input.indexof('H') != layout.indexof('H') ||
            input.contains('w') || input.contains('h')) {
            input = layout;
        }
    } else {
        // No layout known for the input yet: adopt the operator's layout.
        input = layout;
    }

    NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);
    NNVM_ASSIGN_LAYOUT(*olayouts, 0, input);

    return true;
}

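/*
 * Pool2DCorrectLayout主要功能为:
 * 1. 如果input layout没有指定,则设置成默认layout(即参数param.layout)
 * 2. 如果input layout已经指定,先检查它能否转换为默认layout;只有当W/H维度在两者中的位置不同,
 *    或input layout中含有小写的w/h子维度时,才校正成默认layout,否则保留原input layout
 * 3. 输出layout与最终确定的input layout保持一致
 */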

2.初始化Pass(指定输入shape,推断各节点的shape与dtype)

python/nnvm/compiler/build_module.py build函数

    # Infer the shape of every graph input from the user-provided shapes and
    # merge the results back into `shape`.
    ishape, _ = graph_util.infer_shape(graph, **shape)
    shape.update(zip(graph.index.input_names, ishape))
    # When dtype is given per input (a dict rather than a single str), infer
    # and record the dtype of every graph input the same way.
    if not isinstance(dtype, str):
        idtype, _ = graph_util.infer_dtype(graph, **dtype)
        dtype.update(zip(graph.index.input_names, idtype))

python/nnvm/compiler/graph_util.py

def infer_shape(graph, **shape):
    """Run the InferShape pass on ``graph`` starting from known input shapes.

    Parameters
    ----------
    graph : Graph
        The computation graph to analyze.
    **shape : keyword arguments
        Known shapes of graph inputs, keyed by input name.

    Returns
    -------
    input_shape : list
        Inferred shape of each graph input, in ``graph.index.input_names`` order.
    output_shape : list
        Inferred shape of each entry of ``graph.index.output_entries``.
    """
    annotated = graph_attr.set_shape_inputs(graph, shape)
    inferred = annotated.apply("InferShape")
    entry_shapes = inferred.json_attr("shape")
    idx = inferred.index

    def shape_of(entry):
        return entry_shapes[idx.entry_id(entry)]

    input_shape = [shape_of(name) for name in idx.input_names]
    output_shape = [shape_of(entry) for entry in idx.output_entries]
    return input_shape, output_shape

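下面从一个比较简单的操作符max_pool2d入手理解属性推断(InferShape/InferType)的机制。InferShape这个Pass通过InferAttr调用各操作符注册的FInferShape函数;需要注意的是,下文摘录的max_pool2d代码注册的是FInferType(dtype推断)函数ElemwiseType,它与FInferShape属于同一套属性推断机制。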

/src/pass/infer_shape_type.cc

// InferShape is implemented by the generic InferAttr<TShape> helper, which
// calls the "FInferShape" function registered by each operator.
NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body([](Graph ret) {
    return InferAttr<TShape>(
        std::move(ret), TShape(),
        "FInferShape", "shape_inputs", "shape_attr_key",
        "shape", "shape_num_unknown_nodes",
        // A shape counts as unknown when it has no dimensions or no elements.
        [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
        nullptr);
  })
// Presumably declares that the pass does not change the graph structure — confirm.
.set_change_graph(false)
// Provides the "shape" graph attribute (read back via graph.json_attr("shape")).
.provide_graph_attr("shape");

/src/top/nn/pooling.cc

// Excerpt: max_pool2d registers ElemwiseType<1, 1> as its FInferType, i.e. it
// requires exactly one input and one output and gives both the same dtype.
NNVM_REGISTER_OP(max_pool2d)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)

/src/top/elemwise_op_common.h

/*!
 * \brief FInferType helper for element-wise operators: all inputs and outputs
 *        share one dtype.
 * \tparam n_in  Required number of inputs, or -1 to accept any number.
 * \tparam n_out Required number of outputs, or -1 to accept any number.
 * \return true once the shared dtype is known (see ElemwiseAttr).
 */
template<int n_in, int n_out>
inline bool ElemwiseType(const NodeAttrs& attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
    constexpr bool kFixedInputs = (n_in != -1);
    constexpr bool kFixedOutputs = (n_out != -1);
    if (kFixedInputs) {
        CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
    }
    if (kFixedOutputs) {
        CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
    }
    // -1 is passed as the `none` value, i.e. the initial "not yet known" dtype;
    // reverse inference (outputs -> inputs) is enabled.
    const int kUnknownType = -1;
    return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
        attrs, in_attrs, out_attrs, kUnknownType);
}

/*!
 * \brief Unify one attribute (e.g. dtype) across all inputs and outputs.
 *
 * Each input attribute, and each output attribute when reverse_infer is true,
 * is merged into a single value with `assign`. A conflict aborts via CHECK.
 * The merged value is then written back to every input and output.
 *
 * \tparam n_in  Number of inputs to consider, or -1 to use in_attrs->size().
 * \tparam n_out Number of outputs to consider, or -1 to use out_attrs->size().
 * \param none   Initial value of the merged attribute.
 * \return false if the merged attribute is still "none" (per is_none), else true.
 */
template<typename AttrType, bool (*is_none)(const AttrType&),
         bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
         std::string (*attr_string)(const AttrType&),
         int n_in = -1, int n_out = -1>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
                         std::vector<AttrType> *in_attrs,
                         std::vector<AttrType> *out_attrs,
                         const AttrType& none) {
    const size_t num_inputs =
        (n_in == -1) ? in_attrs->size() : static_cast<size_t>(n_in);
    const size_t num_outputs =
        (n_out == -1) ? out_attrs->size() : static_cast<size_t>(n_out);

    AttrType merged = none;

    // Fold side[0..count) into `merged`, failing on incompatible values.
    auto absorb = [&](const std::vector<AttrType>& side, size_t count, const char *role) {
        for (size_t k = 0; k < count; ++k) {
            CHECK(assign(&merged, side[k]))
            << "Incompatible attr in node " << attrs.name << " at " << k << "-th "
            << role << ": " << "expected " << attr_string(merged)
            << ", got " << attr_string(side[k]);
        }
    };
    // Write `merged` into side[0..count), failing on incompatible values.
    auto propagate = [&](std::vector<AttrType>& side, size_t count, const char *role) {
        for (size_t k = 0; k < count; ++k) {
            CHECK(assign(&side[k], merged))
            << "Incompatible attr in node " << attrs.name << " at " << k << "-th "
            << role << ": " << "expected " << attr_string(merged)
            << ", got " << attr_string(side[k]);
        }
    };

    absorb(*in_attrs, num_inputs, "input");
    if (reverse_infer) {
        absorb(*out_attrs, num_outputs, "output");
    }
    propagate(*in_attrs, num_inputs, "input");
    propagate(*out_attrs, num_outputs, "output");

    return !is_none(merged);
}

3. 初始化所有变量(_all_var_init)

在推断出各输入的shape(以及dtype)之后,就可以为变量申请空间并初始化变量。

python/nnvm/compiler/build_module.py build函数

    # Materialize the initial values of variables registered in _all_var_init
    # (uses the shapes/dtypes inferred in the previous step).
    init_var = {}
    if _all_var_init:
        init_var = initialize_variables(shape, dtype)

def initialize_variables(ishape, idtype):
    """Materialize the initial values registered in ``_all_var_init``.

    Entries whose value is a ``sym.Symbol`` are evaluated by building and
    running a small graph. All other values are wrapped with ``tvm.nd.array``.
    ``_all_var_init`` is cleared so every variable is initialized only once.

    Parameters
    ----------
    ishape : dict of str to tuple of int
        The input shapes of the graph. Updated in place: the shape of every
        initialized variable is written into it.

    idtype : str or dict of str to str
        The input types of the graph (one str for all inputs, or per input).

    Returns
    -------
    init_var : dict of str to tvm.ndarray
    """
    symbolic_inits = {}
    constant_inits = {}
    for var_name, init in _all_var_init.items():
        if not isinstance(init, sym.Symbol):
            constant_inits[var_name] = tvm.nd.array(init)
        else:
            symbolic_inits[var_name] = init
    # Make sure variables are initialized only once.
    _all_var_init.clear()

    init_var = {}
    if symbolic_inits:
        # Dummy buffers for every graph input, only needed so that the
        # initialization graph can be run.
        dummy_params = {
            input_name: tvm.nd.empty(
                input_shape,
                idtype if isinstance(idtype, str) else idtype[input_name],
                ctx=tvm.cpu())
            for input_name, input_shape in ishape.items()
        }
        init_graph = _graph.create(sym.Group(symbolic_inits.values()))
        with tvm.build_config(auto_unroll_max_step=0):
            computed = _run_graph(init_graph, dummy_params)
        init_var.update(zip(symbolic_inits.keys(), computed))
    init_var.update(constant_inits)

    # Report the shape of each initialized variable back through ishape.
    for var_name, value in init_var.items():
        ishape[var_name] = value.shape
    return init_var

4. 应用优化

# Apply graph-level optimization passes (e.g. SimplifyInference, FoldScaleAxis).
graph = optimize(graph, shape, dtype, layout)

optimize函数主要针对计算图应用图优化相关的pass:

  • SimplifyInference
  • FoldScaleAxis

5. 预计算裁剪

预计算裁剪(PrecomputePrune)的主要功能是:预先计算计算图中只依赖于参数(params)、而不依赖于输入数据的那部分子图,将计算结果作为新的参数保存,从而把这些节点从推理时实际执行的计算图中裁剪掉。

6. 融合相邻运算并生成最终so

参考文档1已经讲解了这个步骤的一部分知识,这里将继续介绍其中的两个pass:

  • GraphFusePartition
  • GraphFuseCompile

GraphFusePartition的功能主要是将可以fuse的节点放到一个segment中,以供后面编译使用。

GraphFuseCompile是进行lowering编译的过程,其中调用了nnvm.compiler.lower

nnvm.compiler.lower是在NNVM Python端注册的函数,其内部调用的是tvm.lower;tvm.lower的定义位于tvm/python/tvm/build_module.py

# Code excerpt
def lower(sch,
          args,
          name="default_function",
          binds=None,
          simple_mode=False):
    """Lower a schedule (excerpt: most of the body is elided).

    Parameters
    ----------
    sch : tvm.Schedule
        The schedule to be lowered (compiled).
    """
    ...
    # normalize schedule first
    sch = sch.normalize()
    # Phase 0
    # Infer the range of every loop variable, then build the statement tree
    # from the schedule and those bounds.
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)
    stmt = ir_pass.InjectPrefetch(stmt)
    ...

lower接口中的参数sch的类型为tvm.Schedule,这里sch是由nnvm Top组件和TOPI组件一起决定的。

lower的具体实现中值得注意的是schedule.ScheduleOps这个函数,它根据sch以及InferBound推断出的各循环变量范围(bounds),生成HalideIR::Internal::Stmt语句(即IR语句树)。

tvm/src/api/api_schedule.cc

// Register ScheduleOps under the global name "schedule.ScheduleOps" (called
// from tvm.lower). The third argument, debug_keep_trivial_loop, is optional:
// with exactly two arguments it is passed as false.
TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  const bool two_args = (args.size() == 2);
  if (two_args) {
    *ret = ScheduleOps(args[0], args[1], false);
    return;
  }
  *ret = ScheduleOps(args[0], args[1], args[2]);
});

tvm/src/schedule/schedule_ops.cc

/*!
 * \brief Build the statement tree for a (normalized) schedule.
 *
 * Stages are visited from the last one to the first one. Placeholder stages
 * are skipped. A group-root stage gets its pipeline built from the current
 * body (MakePipeline). Scan-init, scan-update and kScope stages are injected
 * into the existing body at their attachment point.
 *
 * \param sch The schedule; it must already be normalized (kInline stages are rejected).
 * \param dom_map_ The Range of each IterVar (tvm.lower passes the result of
 *        schedule.InferBound here).
 * \param debug_keep_trivial_loop Forwarded to MakePipeline / InjectScanStep /
 *        InjectAttach; presumably keeps trivial loops in the output — TODO confirm.
 * \return The statement after SchedulePostProc has been applied.
 */
Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
    Stmt body = Stmt();
    std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
    // scan init and scan updates
    // Map each op used as a scan's init tensor to the scan op that owns it;
    // an init op may belong to at most one scan.
    std::unordered_map<Operation, Operation> scan_init;
    for (Stage s : sch->stages) {
        const ScanOpNode* scan = s->op.as<ScanOpNode>();
        if (!scan) continue;
        for (Tensor t : scan->init) {
            if (scan_init.count(t->op)) {
                CHECK(scan_init.at(t->op).same_as(s->op))
                    << "Scan init tensor can only belong to one scan";
            } else {
                scan_init[t->op] = s->op;
            }
        }
    }
    // Validate the groups: a group stage has no op and no leaf iter vars.
    for (Stage g : sch->groups) {
        CHECK(!g->op.defined());
        CHECK_EQ(g->leaf_iter_vars.size(), 0U);
    }
    // reverse the post DFS order.
    for (size_t i = sch->stages.size(); i != 0; --i) {
        Stage s = sch->stages[i - 1];
        CHECK_NE(s->attach_type, kInline)
            << "call schedule.normalize before scheduleops";
        CHECK(s->op.defined());
        // no need to specify place holder op.
        if (s->op.as<PlaceholderOpNode>()) continue;
        // Remove grouping sugar, get the real attach spec.
        Stage attach_spec = s.GetAttachSpec();

        if (scan_init.count(s->op)) {
        // Stage is a scan's init: inject it into the existing body
        // (`true` selects the scan.init variant, cf. the error message).
        CHECK(body.defined());
        InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
        body = mu.Mutate(body);
        CHECK(mu.found_attach)
            << "did not find attachment point for scan.init";
        } else if (attach_spec->attach_type == kScanUpdate) {
        // Handle scan update
        CHECK(body.defined());
        InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
        body = mu.Mutate(body);
        CHECK(mu.found_attach)
            << "did not find attachment point for scan.update";
        } else if (attach_spec->attach_type == kInlinedAlready) {
        // do nothing
        } else if (attach_spec->attach_type == kGroupRoot) {
        // Group root: build this stage's pipeline from the current body.
        CHECK(!s->group.defined());
        body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
        } else {
        // kScope: inject the stage at attach_spec->attach_ivar inside
        // attach_spec->attach_stage.
        CHECK_EQ(attach_spec->attach_type, kScope);
        CHECK(body.defined());
        InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
        body = mutator.Mutate(body);
        CHECK(mutator.found_attach)
            << "did not find attachment point for " << s << " in "
            << attach_spec->attach_stage->op  << " x " << attach_spec->attach_ivar
            << ", body:\n"
            << body;
        }
    }
    // Post-process the assembled statement and return it.
    SchedulePostProc post_proc;
    post_proc.Init(sch);
    return post_proc.Mutate(body);
    }

猜你喜欢

转载自blog.csdn.net/sanallen/article/details/80315954