mxnet问题整理(二)

在mxnet中需要对conv进行修改,所以遇到了一些问题,选择难理解的问题记下来。

1. 修改完conv层函数之后,出现输出结果是null的问题

按照以下的方式来就好了

class new_conv(nn.Conv2D):
    def __init__(self, channels, kernel_size, **kwargs):
        # if isinstance(kernel_size, base.numeric_types):
        #     kernel_size = (kernel_size,)*2
        # assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
        super(new_conv, self).__init__(channels, kernel_size, **kwargs)

    def forward(self, x, *args):
        self.ctx = x.context
        self.set_params()
        return super(new_conv, self).forward(x, *args)

2. 网络模型参数获取

for key, val in self.params.items():
    # key = self.params.keys()
    if 'weight' in key:
        data = val.data()
        mask = netMask[key]

3. 网络参数的赋值

网络参数的普通的赋值可以使用set_data(),但是网络参数的处理一般都会在构建的网络图之中。所以,最好的采取一种不会再autograd.record()出现错误的处理方式。

def assign_params(self, key, multiplier):
    """Sets this parameter's value on all contexts."""
    # self.shape = multiplier.shape
    out = []
    for arr in self.params[key].list_data():
        out.append((nd.multiply(arr, multiplier)).as_in_context(self.ctx))
    if len(out) == 1:
        self.params[key]._data = out #这里需要格外的注意,必须是list形式。我在这儿捯饬了好大一会
    else:
        self.params[key]._data = [nd.stack(*out, axis=1)]
    print('o')

4. 对网络参数处理的说明

对网络参数的处理,可能还会遇到其他各种各样的问题,所以渔具是啥呢?

在/usr/local/lib/python2.7/dist-packages/mxnet/gluon/parameter.py中,可以找到很多对网络参数处理的环节。比如读取都会指向这么一个函数 def _check_and_get(self, arr_list, ctx),上面提到的赋值时需要的list格式,就是无论如何都是以list的形式出现的。这就有data和ctxlist配合的问题。

def _check_and_get(self, arr_list, ctx):
    if arr_list is not None:
        if ctx is list:
            return arr_list
        if ctx is None:
            if len(arr_list) == 1:
                return arr_list[0]
            else:
                ctx = context.current_context()
        ctx_list = self._ctx_map[ctx.device_typeid&1]
        if ctx.device_id < len(ctx_list):
            idx = ctx_list[ctx.device_id]
            if idx is not None:
                return arr_list[idx]##重点看这儿,arr_list是数据,所以外面要加[],变成list形式,不然就被取多维数据第一维度的第一个index。














猜你喜欢

转载自blog.csdn.net/daniaokuye/article/details/80893667