mxnet中增加新层(python)

mxnet中增加新层的途径有两个:
* 利用CustomOp,基于front-end语言(比如python)实现.如果实现中使用的都是mx.nd接口,则新层可以自由在CPU和GPU上运行,否则(比如用了mx.nd.asnumpy()复制数据到CPU)则只能在CPU上运行
* 利用c++/mshadow(CUDA)实现.可以获得最大性能提升,但是要求对mshadow/cuda/mxnet比较熟悉

下面主要介绍第一个途径: CustomOp

从mxnet.operator.CustomOp派生处一个子类,重载一些成员函数:

import os
import mxnet as mx
import numpy as np

class Softmax(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        x = in_data[0].asnumpy()
        y = np.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))
        y /= y.sum(axis=1).reshape((x.shape[0], 1))
        self.assign(out_data[0], req[0], mx.nd.array(y))

例子中重载了前向函数,输入输出是一些NDArray list. 例子中为了简单,调用了asnumpy()把数据复制到CPU上,这将损失性能.保持使用mx.nd接口可以获得最好的性能.
最后利用CustomOp.assign()复制结果数组y到out_data[0],支持req指定的各种赋值操作,包括write/add/null

同样重载后向函数.

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
    l = in_data[1].asnumpy().ravel().astype(np.int)
    y = out_data[0].asnumpy()
    y[np.arange(l.shape[0]), l] -= 1.0
    self.assign(in_grad[0], req[0], mx.nd.array(y))

Softmax类实现了自定义的层,现在需要使用mx.operator.CustomOpProp定义其输入/输出.首先注册新层,命名为softmax

@mx.operator.register("softmax")
class SoftmaxProp(mx.operator.CustomOpProp):

然后调用基类,设置参数need_top_grad=False. 因为softmax时损失层,不需要输入梯度信息.

def __init__(self):
    super(SoftmaxProp, self).__init__(need_top_grad=False)

声明输入/输出

def list_arguments(self):
    return ['data', 'label']

def list_outputs(self):
    return ['output']

list_argments()既声明了输入也声明了参数,推荐的先后次序是[input1,input2,…,weight1,weight2]
然后提供infer_shape()接口,声明输出/权重的尺寸,并检查输入尺寸
输入输出第一个尺度表示batch中不同样本.label是整数,和样本对应,输出和输入尺寸一致.infer_shape()函数的返回三个list,依次是: inputs, outputs and auxiliary states(本例子中没有出现),如果某一项不需要,设置成空list.
如果需要,还可以定义infer_type()接口声明输入/输出的类型.支持的类型是np.float32,np.float64,np.float16,np.uint8和np.int32.

def infer_type(self, in_type):
    dtype = in_type[0]
    return [dtype, dtype], [dtype], []

最后定义create_operator()函数,该函数被后端调用生成一个softmax实例

def create_operator(self, ctx, shapes, dtypes):
    return Softmax()

用如下方式调用新的层

mlp = mx.symbol.Custom(data=fc3, name='softmax', op_type='softmax')

完整的代码

猜你喜欢

转载自blog.csdn.net/z0n1l2/article/details/80853049