深度学习编译中间件之NNVM(十三)NNVM源代码阅读2

参考文档

  1. 深度学习编译中间件之NNVM(十二)NNVM源代码阅读1

本系列文档涉及NNVM源代码阅读理解,本篇主要介绍一些NNVM的基础数据结构。

使用的C++命令空间为nnvm

相关代码位于
1. include/nnvm
2. src/core

class Op

代码位于

  • include/nnvm/op.h
  • include/nnvm/op_attr_types.h
  • src/core/op.cc

Op类主要用于记录操作符的一些信息

// 代码只是节选
class NNVM_DLL Op {
public:
    std::string name; // 操作符名称
    std::string description; // 操作符详细解释,可用于文档生成
    std::vector<ParamFieldInfo> arguments; // 带文字描述的参数数组
    uint32_t num_inputs = 1; // 操作符的输入数据个数
    uint32_t num_outputs = 1; // 操作符的输出数据个数
    uint32_t support_level = 10; // 支持优先级,数字越小越优先

    std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
    std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;

    std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;


    inline Op& describe(const std::string& descr);
    inline Op& add_argument(const std::string &name,
                            const std::string &type,
                            const std::string &description);
    inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);

    inline Op& set_num_inputs(uint32_t n); 
    inline Op& set_support_level(uint32_t level);
    inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn);  
    inline Op& set_num_outputs(uint32_t n);
    inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn);
    inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn);

    template<typename ValueType>
    inline Op& set_attr(const std::string& attr_name, 
                        const ValueType& value,
                        int plevel = 10);
    Op& add_alias(const std::string& alias); 
    Op& include(const std::string& group_name);
    static const Op* Get(const std::string& op_name);

    template<typename ValueType>
    static const OpMap<ValueType>& GetAttr(const std::string& attr_name);

private:
    template<typename ValueType>
    friend class OpMap;
    friend class OpGroup;
    friend class dmlc::Registry<Op>;

    uint32_t index_{0}; // 唯一操作符索引,用于OpManager区分Op
};

另外include/nnvm/op_attr_types.h中提供了操作符支持的属性类型定义

// 代码节选
using FInferShape = FInferNodeEntryAttr<TShape>;
using FInferType = FInferNodeEntryAttr<int>;

// 得到操作符节点的梯度节点,这个函数用于生成反向传播计算图
using FGradient = std::function<std::vector<NodeEntry>(
    const NodePtr& nodeptr,
    const std::vector<NodeEntry>& out_grads)>;
...

class Node

代码位于

  • include/nnvm/node.h
  • src/core/node.cc

Node类用于在一个计算图中表示一个操作

class NNVM_DLL Node {
public:
    NodeAttrs attrs; // 节点属性
    std::vector<NodeEntry> inputs; // 节点输入向量 
    std::vector<NodePtr> control_deps; // 依赖节点,用于控制流依赖
    any info; // 节点额外信息

    inline const Op* op() const; // 返回节点包含的操作
    inline bool is_variable() const; // 判断节点是否是占位变量(即节点内不含操作,节点的作用只是占用)
    inline uint32_t num_outputs() const; 
    inline uint32_t num_inputs() const;
    static NodePtr Create(); // 创建一个空节点(静态方法)
};

class Graph

代码位于

  • include/nnvm/graph.h
  • src/core/graph.cc

Graph类用于表示一个计算图,它是一个为了进行优化Pass的中间表示。

class Graph {
public:
    std::vector<NodeEntry> outputs; // 计算图的输出节点
    std::unordered_map<std::string, std::shared_ptr<any> > attrs; // 计算图属性集合

    template<typename T>
    inline const T& GetAttr(const std::string& attr_name) const;
    inline bool HasAttr(const std::string& attr_name) const;
    template<typename T>
    inline T MoveCopyAttr(const std::string& attr_name);
    const IndexedGraph& indexed_graph() const; // 获取当前计算图的索引图,如果不存在就按需创建
private:
    mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
}

// 下面介绍Graph的辅助类IndexedGraph

/*!
 * IndexedGraph用于提供索引一个计算图的辅助数据结构
 * 它将图内部的节点们映射到一个连续整型变量node_id,而且将输出节点映射到一个连续整型变量entry_id。
 * 这样的方式允许将计算图的内部节点和输出节点存储在一个紧凑的向量结构中,并且可以做到快速存取
 * 节点的node_id和entry_id是和保存的JSON文件的顺序是一致的
 */
class IndexedGraph {
public:
    /* 表示计算图中的一个数据 */
    struct NodeEntry {
        uint32_t node_id;
        uint32_t index;
        uint32_t version;
    };
    /* 表示计算图中的一个节点 */
    struct Node {
        const nnvm::Node* source; // 指向源节点的指针
        array_view<NodeEntry> inputs; // 节点的输入数据
        array_view<uint32_t> control_deps;
        std::weak_ptr<nnvm::Node> weak_ref; // 指向节点的弱引用
    };

    inline uint32_t entry_id(uint32_t node_id, uint32_t index); // 获取一个唯一的entry_id
    inline uint32_t entry_id(const NodeEntry& e); 

    inline const std::vector<uint32_t>& input_nodes(); // 返回argument节点列表
private:
    friend class Graph;

}

class PassFunctionReg

代码位于

  • include/nnvm/pass.h
  • src/core/pass.cc

PassFunctionReg类为DataIterator工厂函数提供注册入口

// PassFunctionReg继承自dmlc::FunctionRegEntryBase,这个类主要用于函数注册。
// PassFunctionReg在FunctionRegEntryBase类注册普通函数的基础上增加和Pass相关的属性和函数

struct PassFunctionReg
    : public dmlc::FunctionRegEntryBase<PassFunctionReg,
                                        PassFunction> {
    bool change_graph{false}; // 标记pass是否会改变计算图的结构
    std::vector<std::string> graph_attr_dependency; // 记录pass在被应用之前哪些计算图属性必须处于可用
    std::vector<std::string> graph_attr_targets; // 记录pass在被应用之后将生成哪些计算图属性
}

// 下面介绍一些辅助数据结构和函数

/*!
 * \brief 一个PassFunction表示一个针对计算图所做的操作
 * 这个函数处理一个源计算图,返回一个目标计算图,这两个计算图可能一致也可能不一致
 * 一个PassFunction可能会改变图结构,也可能会增加图属性
 */
typedef std::function<Graph (Graph src)> PassFunction;

// 针对输入计算图应用一系列pass
Graph ApplyPasses(Graph src, const std::vector<std::string>& passes);

class Symbol

代码位于

  • include/nnvm/symbolic.h
  • src/core/symbolic.cc

Symbol类是一个帮助类,用于表示计算图中的操作节点。

Symbol类拥有一个利用Group/Functor/Variable这些组件来创建计算图的接口,Symbol类也会被导出到NNVM的Python前端,用于方便进行快速测试和部署。后面将有专门的文档讲解NNVM的Python接口的部分。

// 代码节选
class NNVM_DLL Symbol {
public:
    std::vector<NodeEntry> outputs;

    Symbol Copy() const;
    void Print(std::ostream &os) const;

    std::vector<NodePtr> ListInputs(ListInputOption option) const;
    std::vector<std::string> ListInputNames(ListInputOption option) const;
    std::vector<std::string> ListOutputNames() const;

    // 创建Symbol/Variable/Group Symbol
    static Symbol CreateFunctor(const Op* op,
                                std::unordered_map<std::string, std::string> attrs);
    static Symbol CreateFunctor(const NodeAttrs& attrs);
    static Symbol CreateVariable(const std::string& name);
    static Symbol CreateGroup(const std::vector<Symbol>& symbols);
}

class Layout

代码位于

  • include/nnvm/layout.h

Layout类用于处理Layout表达式

layout由大写字母、小写字母和数字组成,其中大写字母表示一个维度,大写字母对应的小写字母表示一个split之后的子维度,小写字母之前的数字则表示split块的数量。

例如:NCHW16c

表示:[batch_size, channel, height, width, channel_block], channel_block=16

至此NNVM的基础数据结构就介绍完了,接下来的文档将会具体分析NNVM的重要组件

猜你喜欢

转载自blog.csdn.net/sanallen/article/details/80181249