MXNet 下随机生成图片数据的 DataIter

测试时有时需要随机生成图片数据(直接在内存中生成,而不是图片文件);由于这里只用于 predict,所以去掉了 label。
MXNet 的 Module 接口比较麻烦,输入必须通过 DataIter 提供,于是自己写了一个测试用的 DataIter:

class RandomDataIter(mx.io.DataIter):
    """DataIter that yields batches of random image-shaped data for testing.

    Each call to ``next()`` draws a fresh batch of values uniformly from
    [-1, 1) with shape ``(batch_size,) + data_shape``. By default no label is
    attached to the batches (the iterator is meant for prediction); pass
    ``num_classes`` to also attach random integer class ids.

    Args:
        batch_size: Number of samples per batch.
        max_iter: Number of batches produced before ``StopIteration`` is
            raised; ``reset()`` starts a new pass.
        dtype: dtype of the generated NDArrays (data and labels).
        num_classes: If not ``None``, every batch also carries a label array
            of random class ids in ``[0, num_classes)``.
        data_shape: Per-sample shape; defaults to ``(3, 224, 224)``, matching
            the fixed shape this iterator has always produced.
        seed: Optional seed for the random generator. ``None`` seeds from OS
            entropy. Note that ``reset()`` does not rewind the generator.
    """

    def __init__(self, batch_size, max_iter=1000, dtype=np.float32,
                 num_classes=None, data_shape=(3, 224, 224), seed=None):
        super(RandomDataIter, self).__init__(batch_size)
        self.batch_size = batch_size
        self.cur_iter = 0
        self.max_iter = max_iter
        self.dtype = dtype
        self.num_classes = num_classes
        self.data_shape = tuple(data_shape)
        # One generator for the iterator's lifetime instead of a new
        # RandomState per batch; also makes runs reproducible via `seed`.
        self._rng = np.random.RandomState(seed)
        # Page-locked host memory; presumably chosen so later host->GPU
        # copies are cheaper.
        self._ctx = mx.Context('cpu_pinned', 0)
        self.label = None
        # Generate one batch up front so `self.data` (and therefore
        # `provide_data`) is valid before iteration starts.
        self._pro_data()

    def __iter__(self):
        return self

    @property
    def provide_data(self):
        """Description of the data input: name 'data' and the batch shape."""
        return [mx.io.DataDesc('data', self.data.shape, self.dtype)]

    @property
    def provide_label(self):
        """Description of the label input.

        NOTE: this is reported even when ``num_classes`` is None and batches
        carry no label. That keeps existing callers that bind with it
        working.
        """
        return [mx.io.DataDesc('softmax_label', (self.batch_size,), self.dtype)]

    def _pro_data(self):
        """Draw a fresh random batch into ``self.data`` (and ``self.label`` if labelled)."""
        shape = (self.batch_size,) + self.data_shape
        values = self._rng.uniform(-1, 1, shape)
        self.data = mx.nd.array(values, dtype=self.dtype, ctx=self._ctx)
        if self.num_classes is not None:
            labels = self._rng.randint(0, self.num_classes, (self.batch_size,))
            self.label = mx.nd.array(labels, dtype=self.dtype, ctx=self._ctx)

    def next(self):
        """Return the next random DataBatch.

        Raises:
            StopIteration: once ``max_iter`` batches have been produced
                since construction or the last ``reset()``.
        """
        # Check the limit first so no batch is generated on the final call.
        if self.cur_iter >= self.max_iter:
            raise StopIteration
        self.cur_iter += 1
        self._pro_data()
        if self.label is None:
            label, provide_label = None, None
        else:
            label, provide_label = (self.label,), self.provide_label
        return mx.io.DataBatch(data=(self.data,),
                               label=label,
                               pad=0,
                               index=None,
                               provide_data=self.provide_data,
                               provide_label=provide_label)

    def __next__(self):
        # Python 3 iterator protocol; `next` serves Python 2 / MXNet callers.
        return self.next()

    def reset(self):
        """Start a new pass of ``max_iter`` batches."""
        self.cur_iter = 0

猜你喜欢

转载自blog.csdn.net/u011094454/article/details/80774593