深度学习编译中间件之NNVM(十五)NNVM源代码阅读4

参考文档

  1. 深度学习编译中间件之NNVM(十二)NNVM源代码阅读1
  2. 深度学习编译中间件之NNVM(十三)NNVM源代码阅读2
  3. 深度学习编译中间件之NNVM(十四)NNVM源代码阅读3

NNVM Frontend组件主要负责将多种深度学习框架训练出来的模型转换成如下内容:

  1. nnvm.Graph对象:用于存储模型网络描述
  2. tvm.nd.Array对象:用于存储模型权重参数

NNVM Frontend组件将不同深度学习框架的模型格式统一转换成nnvm.Graph和tvm.nd.array的组合。

本篇文档暂时先只关注nnvm.Graph对象和mxnet模型转换。

相关代码位于:

  • python/nnvm/frontend/common.py
  • python/nnvm/frontend/mxnet.py

mxnet模型加载与转换的接口函数为nnvm.frontend.from_mxnet

在介绍转换接口函数前先了解一下nnvm.Graph这个数据结构,nnvm.Graph定义于python/nnvm/graph.py:

nnvm.Graph用来表示一个graph对象,这个对象可以被用于应用优化pass。它包含了额外的一些计算图级别专用的属性。

class Graph(object):
    def json_attr(self, key) # 获取属性字符串
    def _set_json_attr(self, key, value, type_name=None) # 设置属性
    def json(self) # 获取计算图的json表示
    def _tvm_graph_json(self) # 获取TVM计算图的json表示
    def ir(self, join_entry_attrs=None, join_node_attrs=None) # 获取计算图IR的文本形式
    def apply(self, passes) # 针对某个graph应用pass

Graph对象比较重要的一个函数是apply,具体是通过调用NNGraphApplyPasses来实现。

接下来介绍一下mxnet模型的具体转换过程。

python/nnvm/frontend/mxnet.py

def _convert_symbol(op_name, inputs, attrs,
                    identity_list=None,
                    convert_map=None):
    identity_list = identity_list if identity_list else _identity_list
    convert_map = convert_map if convert_map else _convert_map
    if op_name in identity_list:
        op = _get_nnvm_op(op_name)
        sym = op(*inputs, **attrs)
    elif op_name in convert_map:
        sym = convert_map[op_name](inputs, attrs)
    else:
        _raise_not_supported('Operator: ' + op_name)
    return sym

针对单个运算符的转换过程主要由_convert_symbol函数完成,其中涉及到两个运算符列表

  • _identity_list:表示mxnet运算符名称和nnvm一致,并且运算符附带的参数名称也必须一致。
  • _convert_map:表示mxnet运算符名称或者参数名称和nnvm不一致,必须转换运算符名称或参数名称。

python/nnvm/frontend/mxnet.py

_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                  '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
                  '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
                  '__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
                  'broadcast_add', 'broadcast_div', 'broadcast_mul',
                  'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
                  'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
                  'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
                  'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']

# _convert_map列表较长,只列出部分运算符
_convert_map = {
    'Activation'    : _activations,
    'BatchNorm'     : _batch_norm,
    'BatchNorm_v1'  : _batch_norm,
    'Cast'          : _rename('cast'),
    'Concat'        : _concat,
    'Convolution'   : _conv2d,
    'Convolution_v1': _conv2d,
    'Deconvolution' : _conv2d_transpose,
    'Dropout'       : _dropout,
}

获取到运算符名称之后可以通过_get_nnvm_op函数来获取nnvm运算符

python/nnvm/frontend/mxnet.py

from .. import symbol as _sym

def _get_nnvm_op(op_name):
    op = getattr(_sym, op_name)
    if not op:
        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
    return op

_get_nnvm_op的主要功能是通过getattr内建函数来获取nnvm op对象,这个函数能获取到所有通过NNVM_REGISTER_OP注册的运算符

猜你喜欢

转载自blog.csdn.net/sanallen/article/details/80315925