在mxnet中需要对conv进行修改,所以遇到了一些问题,选择难理解的问题记下来。
1. 修改完conv层函数之后,出现输出结果是null的问题
按照以下的方式来就好了
class new_conv(nn.Conv2D):
def __init__(self, channels, kernel_size, **kwargs):
# if isinstance(kernel_size, base.numeric_types):
# kernel_size = (kernel_size,)*2
# assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
super(new_conv, self).__init__(channels, kernel_size, **kwargs)
def forward(self, x, *args):
self.ctx = x.context
self.set_params()
return super(new_conv, self).forward(x, *args)
2. 网络模型参数获取
for key, val in self.params.items():
# key = self.params.keys()
if 'weight' in key:
data = val.data()
mask = netMask[key]
3. 网络参数的赋值
网络参数的普通的赋值可以使用set_data(),但是网络参数的处理一般都会在构建的网络图之中。所以,最好的采取一种不会再autograd.record()出现错误的处理方式。
def assign_params(self, key, multiplier):
"""Sets this parameter's value on all contexts."""
# self.shape = multiplier.shape
out = []
for arr in self.params[key].list_data():
out.append((nd.multiply(arr, multiplier)).as_in_context(self.ctx))
if len(out) == 1:
self.params[key]._data = out #这里需要格外的注意,必须是list形式。我在这儿捯饬了好大一会
else:
self.params[key]._data = [nd.stack(*out, axis=1)]
print('o')
4. 对网络参数处理的说明
对网络参数的处理,可能还会遇到其他各种各样的问题,所以渔具是啥呢?
在/usr/local/lib/python2.7/dist-packages/mxnet/gluon/parameter.py中,可以找到很多对网络参数处理的环节。比如读取都会指向这么一个函数 def _check_and_get(self, arr_list, ctx),上面提到的赋值时需要的list格式,就是无论如何都是以list的形式出现的。这就有data和ctxlist配合的问题。
def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
if ctx is list:
return arr_list
if ctx is None:
if len(arr_list) == 1:
return arr_list[0]
else:
ctx = context.current_context()
ctx_list = self._ctx_map[ctx.device_typeid&1]
if ctx.device_id < len(ctx_list):
idx = ctx_list[ctx.device_id]
if idx is not None:
return arr_list[idx]##重点看这儿,arr_list是数据,所以外面要加[],变成list形式,不然就被取多维数据第一维度的第一个index。