参考文档
- 深度学习编译中间件之NNVM(十二)NNVM源代码阅读1
- 深度学习编译中间件之NNVM(十三)NNVM源代码阅读2
- 深度学习编译中间件之NNVM(十四)NNVM源代码阅读3
- 深度学习编译中间件之NNVM(十五)NNVM源代码阅读4
- 深度学习编译中间件之NNVM(十六)NNVM源代码阅读5
这篇文档将讲解和HalideIR相关的内容。
HalideIR是一个创建符号表达式和算术简化的基础模块。它从原始的Halide项目重构而来,用于TVM项目中。
工程结构
HalideIR的代码基于Halide(release_2017_05_03),由四个部分组成:
- tvm:TVM封装器代码,用于基础数据结构
- base:基础类型和工具
- ir:IR数据结构
- arithmetic:算术简化
base部分
base部分具体提供了哪些功能:
- 代码生成调试(可选)
- 编译时错误与异常处理
- float16实现(去除LLVM依赖)
- halide基础类型定义
- halide实用工具函数定义
接下来着重讲解halide基础类型和实用工具这两个内容。
base部分将一系列类型表示为C++函数签名,这种形式拥有两个优点:
- 可以为Halide函数提供正确的原型,提供更好的编译时类型校验
- C++命名编码能为Halide函数和外部调用函数提供链接时类型校验
base部分还提供了一些实用函数
- extract_namespaces
- add_would_overflow:加法数值溢出判断
- sub_would_overflow:减法数值溢出判断
- mul_would_overflow:乘法数值溢出判断
tvm部分
tvm部分比较重要的数据结构有:
- tvm::Node
- tvm::NodeRef
- tvm::ArrayNode(在DSL计算图中使用)
- tvm::MapNode(在DSL计算图中使用)
- tvm::IRFunctor
// 代码节选
class EXPORT Node : public std::enable_shared_from_this<Node> {
public:
virtual const char* type_key() const = 0;
virtual void VisitAttrs(AttrVisitor* visitor) {}
}
/*! NOdeRef是所有节点引用对象的基类 */
class NodeRef {
using ContainerType = Node;
inline bool operator==(const NodeRef& other) const;
inline bool same_as(const NodeRef& other) const;
inline bool operator<(const NodeRef& other) const;
inline bool operator!=(const NodeRef& other) const;
inline uint32_t type_index() const;
inline const Node* operator->() const;
template<typename T>
inline const T *as() const;
NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node> node) : node_(node) {}
std::shared_ptr<Node> node_;
}
ir部分
tvm/HalideIR/ir/Expr.h
/** 一个处理statement node的引用计数的handle */
struct Stmt : public IRHandle {
Stmt() : IRHandle() {}
Stmt(std::shared_ptr<IR::Node> n) : IRHandle(n) {}
/** Dispatch to the correct visitor method for this node. E.g. if
* this node is actually an Add node, then this will call
* IRVisitor::visit(const Add *) */
inline void accept(Internal::IRVisitor *v) const {
static_cast<const Internal::BaseStmtNode *>(node_.get())->accept(v, *this);
}
/*! \brief type indicate the container type */
using ContainerType = Internal::BaseStmtNode;
};
IR.h里面存放了深度学习需要的IR基础节点
- 固定值(IntImm/UIntImm/FloatImm/StringImm)
- 二进制算数运算(Add/Sub/Mul/Div/Mod/Min/Max)
- 比较运算符()
- 逻辑运算符(And/Or/Not)
- 选择运算符(Select) *
tvm/HalideIR/ir/IR.cpp