mx.metric.EvalMetric和mx.io.DataIter学习笔记

首先看下metric继承类class MAE_zz(mx.metric.EvalMetric)代码:

class MAE_zz(mx.metric.EvalMetric):

    def __init__(self, name = None):
        self.name = "mae"
        super(MAE_zz, self).__init__('mae') # 调用其父类,初始化父类的所有数据成员,并将父类名字命名为“mae”

    def reset(self): # 初始化总数(batch_size*num_task)和准确率
        """Resets the internal evaluation result to initial state."""
        self.num_inst = [0]*3 # [batch_size/num(ctx), b../n.., ...]
        self.sum_metric = [0.0]*3 # 每个任务(网络模型最后的输出层mx.symbol.Group)的准确率

    def update(self, labels, preds): # 每一个batch更新metric,便于mae的输出,不影响网络的训练
        """Updates the internal evaluation result."""

        for i in range(len(labels)): # i代表每个任务
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32') # 得到pred_label(label的预测值,shape=batch_size/ctx);preds[i]代表第i个任务的预测值(shape=batch_size/ctx乘以第i个任务的分类个数)。
            label = labels[i].asnumpy().astype('int32') # 得到label(真实label);labels为多个任务的真实label。
            mx.metric.check_label_shapes(label, pred_label) # 检查pred_label和label的shape是否一致,不一致则raise ValueError
            self.sum_metric[i] += (pred_label.flat == label.flat).sum() # 计算一个batch里,每i个任务下的预测正确数
            self.num_inst[i] += len(pred_label.flat) # batch_size * len(i)

    def get(self): # 得到每个batch下所有ctx的准确率
        """Gets the current evaluation result.
        """
        if self.num_inst == 0:
            return self.name, float('nan')
        else:
            return zip(*(('%s-task%d' % (self.name, i), float('nan')
            if self.num_inst[i] == 0
            else self.sum_metric[i] / self.num_inst[i])
                         for i in range(3))) # 返回每个任务的准确率,并zip为mae-task[i] 和 self.sum_metric[i] / self.num_inst[i]

    def get_name_value(self):
        """Returns zipped name and value pairs.
        """
        name, value = self.get() # name=mae-task[i]; value=self.sum_metric[i] / self.num_inst[i]
        return list(zip(name, value))

其次看下迭代器继承类class Multi_mnist_iterator(mx.io.DataIter)代码:

class Multi_mnist_iterator(mx.io.DataIter):
    '''multi label mnist iterator'''

    def __init__(self, data_iter):
        super(Multi_mnist_iterator, self).__init__()
        self.data_iter = data_iter # 初始化data_iter(DataIter迭代器的数据)
        self.batch_size = self.data_iter.batch_size # 得到DataIter的batch_size

    @property # 调用property装饰器的getter获取data_iter的provide_data(**见图1**)
    def provide_data(self):
        return self.data_iter.provide_data

    @property # 调用property装饰器的getter获取data_iter的provide_label(**见图1**)
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]
        # Different labels should be used here for actual application
        return [('softmax1_label',(self.batch_size,)),
                ('softmax2_label',(self.batch_size,)),
                ('softmax3_label',(self.batch_size,))]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next() # 得到data_iter迭代器实例(**见图2**)
        label = batch.label[0] # label.shape 为batch * num_task
        data = batch.data[0] # data.shape为batch*num_task*weigh*height
        label0, label1, label2 = label.T # 将label变为num_task * batch 输入给模型;以便模型分任务计算准确率。
        return mx.io.DataBatch(data=[data], label=[label0, label1, label2],
                               pad=batch.pad, index=batch.index)

图一:data_iter 迭代器
在这里插入图片描述

图二:batch迭代器实例
在这里插入图片描述

@property 的作用

    def __init__(self, fget=None, fset=None, fdel=None, doc=None): # known special case of property.__init__
        """
        property(fget=None, fset=None, fdel=None, doc=None) -> property attribute       
        class C(object):
            @property
            def x(self):
                "I am the 'x' property."
                return self._x
            @x.setter
            def x(self, value):
                self._x = value
            @x.deleter
            def x(self):
                del self._x
   
        # (copied from class doc)
        """
        pass
  • 看上面代码可知,把x方法(返回x的值)变为属性只需要加上@property装饰器即可,即:C.x() => C.x;
  • 此时@property本身又会创建另外一个装饰器@x.setter,负责把x方法变成给属性赋值,即:C.x(value) => C.x=value
  • 此时@property本身又会创建另外一个装饰器@x.deleter,负责删除C中的元素,即:del C.x。
  • 此3种属性对应于property的3种方法getter,setter 和 deleter 。

猜你喜欢

转载自blog.csdn.net/weixin_42486685/article/details/83904057