MXNet Advanced - mx.io.NDArrayIter Source Code Analysis

Introduction

This post looks at how MXNet supplies data through an iterator during training and validation.

Two questions:

  • How is the iterator constructed?
  • When fetching data with data_batch = next(iterator), what exactly goes in and what comes out?

We walk through MXNet's built-in mx.io.NDArrayIter to see how an NDArray is turned into an iterator usable by module.fit().

The test code below trains an MLP on MNIST.

'''
Loading Data
'''
import mxnet as mx
from collections import OrderedDict
from mxnet.ndarray import array
mnist = mx.test_utils.get_mnist()  # dict with four entries:
# 'train_data'  ndarray, shape (60000, 1, 28, 28)
# 'train_label' ndarray, shape (60000,)
# 'test_data'   ndarray, shape (10000, 1, 28, 28)
# 'test_label'  ndarray, shape (10000,)
# Fix the seed
mx.random.seed(42)

# Set the compute context: GPU if available, otherwise CPU
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()


batch_size = 100
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

'''
Training
'''

'''
The name 'data' must not be changed here: it matches the default value of
mx.io.NDArrayIter's data_name parameter, which is 'data' (this will become
clear below). You can also change it to see the resulting error message.
'''
data = mx.sym.var('data')
# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)
data = mx.sym.flatten(data=data)


# The first fully-connected layer and the corresponding activation function
fc1  = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")

# The second fully-connected layer and the corresponding activation function
fc2  = mx.sym.FullyConnected(data=act1, num_hidden = 64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
# MNIST has 10 classes
fc3  = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

import logging
logging.getLogger().setLevel(logging.DEBUG)  # logging to stdout
# create a trainable module on compute context
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)
mlp_model.fit(train_iter,  # train data
              eval_data=val_iter,  # validation data
              optimizer='sgd',  # use SGD to train
              optimizer_params={'learning_rate':0.1},  # use fixed learning rate
              eval_metric='acc',  # report accuracy during training
              batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
              num_epoch=10)  # train for at most 10 dataset passes
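
Before reading the source, a quick sanity check (a minimal sketch; it assumes the script above has already run) shows what the iterator hands out:

train_iter.reset()                  # rewind: fit() above consumed the iterator
batch = next(train_iter)            # DataIter implements __next__
print(type(batch))                  # <class 'mxnet.io.DataBatch'>
print(batch.data[0].shape)          # (100, 1, 28, 28)
print(batch.label[0].shape)         # (100,)
print(batch.pad)                    # 0 except possibly in the last batch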

Now look at mx.io.NDArrayIter.__init__():

    def __init__(self, data, label=None, batch_size=1, shuffle=False,
                 last_batch_handle='pad', data_name='data',
                 label_name='softmax_label'):
        super(NDArrayIter, self).__init__(batch_size)
        '''Normalize the inputs into the form list(tuple(key, val), tuple(key, val), ...)'''
        '''Key point!!! These keys correspond to the input symbols of the executor.'''
        self.data = _init_data(data, allow_empty=False, default_name=data_name)
        self.label = _init_data(label, allow_empty=True, default_name=label_name)

        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
                (last_batch_handle != 'discard')):
            raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                      " with `last_batch_handle` set to `discard`.")
        '''
        self.idx is an arange-generated list of size self.data[0][1].shape[0].
        The use of shape[0] shows that the input ndarray must be laid out as
        num_samples x sample_shape.
        If shuffle is set, shuffle the data.
        '''
        if shuffle:
            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
            self.data = _shuffle(self.data, self.idx)
            self.label = _shuffle(self.label, self.idx)
        else:
            self.idx = np.arange(self.data[0][1].shape[0])
        '''With 'discard', trim the input data to a multiple of batch_size.'''
        # batching
        if last_batch_handle == 'discard':
            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
            self.idx = self.idx[:new_n]
        '''Join data and label into one list = [data_0_ndarray, data_1_ndarray, ..., label_0_ndarray, ...]'''
        self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
        '''Total number of input/output ndarrays (2 for MNIST).'''
        self.num_source = len(self.data_list)
        '''Number of samples (60000 for the MNIST training set).'''
        self.num_data = self.idx.shape[0]
        '''batch_size must not exceed the total number of samples.'''
        assert self.num_data >= batch_size, \
            "batch_size needs to be smaller than data size."
        '''Define a cursor into the data.'''
        self.cursor = -batch_size
        self.batch_size = batch_size
        '''How to handle the final incomplete batch: 'pad' or 'discard'.'''
        self.last_batch_handle = last_batch_handle

_init_data normalizes the input format, because the constructor accepts several input types, numpy.ndarray / mxnet.ndarray / h5py.Dataset, either as a single array of one of these types, as a list of them, or as a dict (a runnable sketch follows the source below).

  • Given a single mxnet.ndarray as input,
  • the output format is list[tuple(str{'data'}, mxnet.ndarray)], where the key is default_name itself;
  • given a list input [mxnet.ndarray, mxnet.ndarray],
  • the output format is list[tuple(str{'_0_data'}, mxnet.ndarray), tuple(str{'_1_data'}, mxnet.ndarray)].
def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form."""
    assert (data is not None) or allow_empty
    if data is None:
        data = []
    '''If the input is not a list, wrap data in a single-element list.'''
    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
                  if h5py else (np.ndarray, NDArray)):
        data = [data]
    # type(data) == list
    '''Next, convert the list into an OrderedDict.'''
    if isinstance(data, list):
        if not allow_empty:
            assert(len(data) > 0)
        if len(data) == 1:
            '''If the list has a single element, i.e. the input is one ndarray,
            the dict has one entry whose key is default_name and whose value
            is the input data.'''
            data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type
        else:
            '''With several ndarrays as input, the dict has several entries whose
            keys are named ('_%d_%s' % (i, default_name)),
            e.g. {('_0_data', ndarray), ('_1_data', ndarray), ...}.'''
            data = OrderedDict( # pylint: disable=redefined-variable-type
                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                "a list of them or dict with them as values")
    '''Convert any non-mxnet.ndarray values to mxnet.ndarray.'''
    for k, v in data.items():
        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
            try:
                data[k] = array(v)
            except:
                raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) + \
                                "should be NDArray, numpy.ndarray or h5py.Dataset")
    '''Turn the dict back into a list; each (key, val) in the dict becomes a tuple(key, val).'''
    return list(sorted(data.items()))
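
As promised above, a small sketch of _init_data's behavior on toy arrays (the names and arrays are hypothetical; note this is an internal helper that lives in mxnet/io.py in the version quoted here and moved to mxnet.io.utils in later releases):

import numpy as np
from mxnet.io import _init_data   # internal helper; its location varies by version

a = np.zeros((4, 2))
b = np.zeros((4, 3))

# Single array: the key is default_name itself.
print(_init_data(a, allow_empty=False, default_name='data'))
# [('data', <NDArray 4x2 @cpu(0)>)]

# List of arrays: keys follow '_%d_%s' % (i, default_name).
print(_init_data([a, b], allow_empty=False, default_name='data'))
# [('_0_data', <NDArray 4x2 @cpu(0)>), ('_1_data', <NDArray 4x3 @cpu(0)>)]

# Dict input keeps the caller's keys (sorted by name).
print(_init_data({'img': a, 'aux': b}, allow_empty=False, default_name='data'))
# [('aux', <NDArray 4x3 @cpu(0)>), ('img', <NDArray 4x2 @cpu(0)>)]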

After initialization, mod.fit() uses the following member APIs of the iterator (note that provide_data and provide_label are properties):

  • mx.io.NDArrayIter.provide_data # used to initialize the module or executor
  • mx.io.NDArrayIter.provide_label # used to initialize the module or executor, in step with the above
  • iter(mx.io.NDArrayIter) # obtains an iterator, used to fetch a data batch at every training step
  • mx.io.NDArrayIter.reset() # called once at the end of every epoch during training
    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator."""
        '''
        DataDesc is a namedtuple (look it up if you are unfamiliar with those).
        Each returned DataDesc object is built from two pieces of information.
        First, the structure of self.data:
            self.data = list[tuple1(name1, val1), tuple2(name2, val2), ...]
        where each name is the name of one of the symbol's input parameters,
        specified when the iterator was constructed above,
        and each val is that parameter's data matrix of shape n x v.
        The DataDesc object is built from str(name) and tuple(batch_size, v).
        '''
        return [
            DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
            for k, v in self.data
        ]
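
For the MNIST iterator built at the top, this property replaces the sample dimension (60000) with batch_size. A quick sketch:

for d in train_iter.provide_data:
    print(d.name, d.shape)          # data (100, 1, 28, 28)
for d in train_iter.provide_label:
    print(d.name, d.shape)          # softmax_label (100,)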

How Module / Executor initialization uses mx.io.NDArrayIter

mxnet.Module.fit() calls Module.bind(), which uses mx.io.NDArrayIter.provide_data and mx.io.NDArrayIter.provide_label:

self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
          for_training=True, force_rebind=force_rebind)

'''Parse the data descriptors. (Annotated below with the values each argument
takes in this example; the real call passes them positionally.)'''
    self._data_shapes, self._label_shapes = _parse_data_desc(
            self.data_names,    # ['data']
            self.label_names,   # ['softmax_label']
            data_shapes,        # train_data.provide_data
            label_shapes)       # train_data.provide_label
'''This parser does two things:
        convert the data attributes into DataDesc format;
        check that the names in the supplied data attributes match the names of the network's input symbols.
'''
def _parse_data_desc(data_names, label_names, data_shapes, label_shapes):
    """parse data_attrs into DataDesc format and check that names match"""
    data_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in data_shapes]
    _check_names_match(data_names, data_shapes, 'data', True)
    if label_shapes is not None:
        label_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in label_shapes]
        _check_names_match(label_names, label_shapes, 'label', False)
    else:
        _check_names_match(label_names, [], 'label', False)
    return data_shapes, label_shapes
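
To watch the name check fail, here is a hypothetical sketch in which the input variable is named 'x' instead of 'data', so it no longer matches the name carried by the iterator's provide_data:

x = mx.sym.var('x')                 # hypothetical mismatched input name
net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(mx.sym.flatten(x), num_hidden=10),
                           name='softmax')
mod = mx.mod.Module(symbol=net, context=ctx, data_names=['x'])
try:
    mod.bind(data_shapes=train_iter.provide_data,   # names: ['data']
             label_shapes=train_iter.provide_label)
except ValueError as e:
    print('name mismatch:', e)      # raised by _check_names_match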

'''The resulting data descriptors are then used to initialize the executor group class.'''
        self._exec_group = DataParallelExecutorGroup(self._symbol, self._context,
                                                     self._work_load_list, self._data_shapes,
                                                     self._label_shapes, self._param_names,
                                                     for_training, inputs_need_grad,
                                                     shared_group, logger=self.logger,
                                                     fixed_param_names=self._fixed_param_names,
                                                     grad_req=grad_req, group2ctxs=self._group2ctxs,
                                                     state_names=self._state_names)

    '''The last line of the initializer uses these data descriptors''' # shared_group = None
        self.bind_exec(data_shapes, label_shapes, shared_group)
    '''Keep following the call chain: the setup above exists for multi-GPU data parallelism. Here the batch slice each GPU is responsible for is handed to its own Executor, and those Executors are collected.'''
        self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,
                                                      shared_group))
    '''Continuing: this is the endpoint of Module.bind(), where simple_bind produces an Executor.'''
    def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
        """Internal utility function to bind the i-th executor.
        This function utilizes simple_bind python interface.
        """
        shared_exec = None if shared_group is None else shared_group.execs[i]
        context = self.contexts[i]
        shared_data_arrays = self.shared_data_arrays[i]

        input_shapes = dict(data_shapes)
        if label_shapes is not None:
            input_shapes.update(dict(label_shapes))
        '''Build dicts from the input data descriptors, used to initialize the executor.'''
        input_types = {x.name: x.dtype for x in data_shapes}
        if label_shapes is not None:
            input_types.update({x.name: x.dtype for x in label_shapes})

        group2ctx = self.group2ctxs[i]
        '''simple_bind
        will be studied in detail in a separate post.
        My current guess: it builds the network's static graph from the
        descriptors of the input data and allocates the corresponding
        memory on the given context.'''
        executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                           type_dict=input_types, shared_arg_names=self.param_names,
                                           shared_exec=shared_exec, group2ctx=group2ctx,
                                           shared_buffer=shared_data_arrays, **input_shapes)
        self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
        return executor
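
Outside of Module, the same endpoint can be reached by hand. A minimal sketch of simple_bind applied to the mlp symbol defined at the top, fed the same shape information that _bind_ith_exec derives from the DataDescs:

exe = mlp.simple_bind(ctx=ctx, grad_req='write',
                      data=(batch_size, 1, 28, 28),
                      softmax_label=(batch_size,))
print(type(exe))                            # <class 'mxnet.executor.Executor'>
print([a.shape for a in exe.arg_arrays])    # allocated inputs and parameters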
    

How Module / Executor training uses mx.io.NDArrayIter

Here we focus on how, within each epoch, fit() calls mx.io.NDArrayIter to supply the training data.

for epoch in range(begin_epoch, num_epoch):
    tic = time.time()
    eval_metric.reset()
    nbatch = 0
    '''iter() turns a non-iterator iterable, such as a list, into an iterator.
    In this post it does not change train_data at all,
    because the train_data passed in is already an iterator.
    '''
    data_iter = iter(train_data)
    '''
    print(type(data_iter ),type(train_iter))
    <class 'mxnet.io.NDArrayIter'> <class 'mxnet.io.NDArrayIter'>
    '''
    '''Initialize a flag used to detect whether the iterator has been exhausted.'''
    end_of_batch = False
    '''Fetch one batch of data, a <class 'mxnet.io.DataBatch'>.
    DataBatch has these two important member variables:
        data <class 'list<class mx.NDArray>'>
        label <class 'list<class mx.NDArray>'>
    '''
    next_data_batch = next(data_iter)
    '''
    def next(self):
        # Call iter_next() to check whether the iterator is exhausted.
        if self.iter_next():
            return DataBatch(data=self.getdata(), label=self.getlabel(), \
                    pad=self.getpad(), index=None)
        else:
            raise StopIteration

    # self.cursor is initialized to -self.batch_size so that, when data is
    # fetched, the cursor points at the data to fetch: on the first
    # next().getdata() the cursor is 0.
    # Every call to next() advances it by self.batch_size.
    # Returns True while the cursor is still within the data, False otherwise.
    def iter_next(self):
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    '''
    
    while not end_of_batch:
        data_batch = next_data_batch
        if monitor is not None:
            monitor.tic()
        self.forward_backward(data_batch)
        self.update()
        '''After the parameter update completes, fetch a new batch.'''
        try:
            # pre fetch next batch
            next_data_batch = next(data_iter)
            self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
        '''this except matches the raise StopIteration in next()'''
        except StopIteration:
            end_of_batch = True
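
The same protocol can be driven by hand. A condensed sketch mirroring the loop above, with the actual training calls omitted:

for epoch in range(2):                # hypothetical: just two epochs
    train_iter.reset()                # cursor goes back to -batch_size
    for data_batch in train_iter:     # DataIter implements __iter__/__next__
        pass                          # forward_backward() and update() go here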

Finally, look closely at iter.next().getdata(), which is implemented by the _getdata() helper:

    def _getdata(self, data_source):
        """Load data from underlying arrays, internal use only."""
        assert(self.cursor < self.num_data), "DataIter needs reset."
        '''Check whether the data remaining in the iterator fills a whole batch.'''
        if self.cursor + self.batch_size <= self.num_data:
            '''data_source = self.data, a <class 'list<tuple<str_name, ndarray_data>>'>:
            slice the [cursor : cursor + batch_size] range out of every member
            of the self.data list.'''
            return [
                # np.ndarray or NDArray case
                x[1][self.cursor:self.cursor + self.batch_size]
                if isinstance(x[1], (np.ndarray, NDArray)) else
                # h5py (only supports indices in increasing order)
                array(x[1][sorted(self.idx[
                    self.cursor:self.cursor + self.batch_size])][[
                        list(self.idx[self.cursor:
                                      self.cursor + self.batch_size]).index(i)
                        for i in sorted(self.idx[
                            self.cursor:self.cursor + self.batch_size])
                    ]]) for x in data_source
            ]
        else:
            '''Not enough left for a whole batch: wrap around and pad with samples from the front.'''
            pad = self.batch_size - self.num_data + self.cursor
            return [
                # np.ndarray or NDArray case
                concatenate([x[1][self.cursor:], x[1][:pad]])
                if isinstance(x[1], (np.ndarray, NDArray)) else
                # h5py (only supports indices in increasing order)
                concatenate([
                    array(x[1][sorted(self.idx[self.cursor:])][[
                        list(self.idx[self.cursor:]).index(i)
                        for i in sorted(self.idx[self.cursor:])
                    ]]),
                    array(x[1][sorted(self.idx[:pad])][[
                        list(self.idx[:pad]).index(i)
                        for i in sorted(self.idx[:pad])
                    ]])
                ]) for x in data_source
            ]
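
The padding branch is easy to observe with a toy iterator whose size is not a multiple of batch_size (a small sketch: 10 samples, batch_size 4):

import numpy as np
toy = mx.io.NDArrayIter(np.arange(10, dtype='float32').reshape(10, 1),
                        batch_size=4, last_batch_handle='pad')
for b in toy:
    print(b.data[0].asnumpy().ravel(), 'pad =', b.pad)
# [0. 1. 2. 3.] pad = 0
# [4. 5. 6. 7.] pad = 0
# [8. 9. 0. 1.] pad = 2   <- wrapped around to the front; getpad() reports 2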

Summary

  • The data names given when constructing a DataIter must match the names of the network's input symbols (see the closing sketch below).
  • The DataIter serves both network initialization and network training/evaluation.
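
As a closing sketch of the first bullet (all names here are hypothetical), a two-input network fed by a dict-based NDArrayIter, where the dict keys line up with the symbol names:

import numpy as np
x1, x2 = mx.sym.var('img'), mx.sym.var('aux')
net = mx.sym.concat(mx.sym.flatten(x1), x2)
net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(net, num_hidden=2), name='softmax')

it = mx.io.NDArrayIter(data={'img': np.zeros((8, 1, 4, 4)), 'aux': np.zeros((8, 3))},
                       label={'softmax_label': np.zeros((8,))},
                       batch_size=4)
mod = mx.mod.Module(net, data_names=['img', 'aux'], context=mx.cpu())
mod.bind(data_shapes=it.provide_data, label_shapes=it.provide_label)  # names match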

Reposted from blog.csdn.net/qq_25379821/article/details/84973405