测试时常常需要随机生成图片数据;由于这里只用于自己做 predict,所以去掉了 label。
MXNet 的 Module 接口比较麻烦,必须通过 DataIter 提供数据,于是自己写了一个测试用的 DataIter:
class RandomDataIter(mx.io.DataIter):
    """DataIter that yields batches of random image-like data without labels.

    Intended for testing / benchmarking prediction: every batch holds values
    drawn uniformly from [-1, 1) and carries ``label=None``.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    max_iter : int
        Number of batches produced before ``StopIteration`` is raised
        (the counter is restarted by ``reset``).
    dtype : numpy dtype
        Dtype of the generated data and of the advertised descriptors.
    num_classes : optional
        Unused; accepted only so existing callers that pass it keep working.
    data_shape : tuple of int
        Per-sample shape; defaults to ``(3, 224, 224)``.
    seed : int, optional
        Seed for the random generator; ``None`` seeds it from OS entropy.
    """

    def __init__(self, batch_size, max_iter=1000, dtype=np.float32,
                 num_classes=None, data_shape=(3, 224, 224), seed=None):
        # Let the MXNet base class initialise its own state (it stores
        # batch_size as well).
        super(RandomDataIter, self).__init__(batch_size)
        self.batch_size = batch_size
        self.cur_iter = 0
        self.max_iter = max_iter
        self.dtype = dtype
        self.data_shape = tuple(data_shape)
        # One generator for the iterator's lifetime instead of a fresh,
        # unseedable RandomState per batch.
        self._rng = np.random.RandomState(seed)
        # Labels are intentionally omitted: batches are only used for predict.
        self.label = None
        # Populate self.data eagerly so the attribute exists right after
        # construction, as in the original implementation.
        self._pro_data()

    def __iter__(self):
        return self

    @property
    def provide_data(self):
        """Descriptor of the single ``'data'`` input: (batch_size, *data_shape)."""
        return [mx.io.DataDesc('data', (self.batch_size,) + self.data_shape,
                               self.dtype)]

    @property
    def provide_label(self):
        """Descriptor for ``'softmax_label'``.

        Kept for callers that bind a symbol declaring this input; note that
        the batches returned by ``next`` themselves carry no label.
        """
        return [mx.io.DataDesc('softmax_label', (self.batch_size,), self.dtype)]

    def _pro_data(self):
        """Replace ``self.data`` with a fresh batch of uniform noise in [-1, 1)."""
        shape = (self.batch_size,) + self.data_shape
        values = self._rng.uniform(-1, 1, shape)
        self.data = mx.nd.array(values, dtype=self.dtype,
                                ctx=mx.Context('cpu_pinned', 0))

    def next(self):
        """Return the next label-less ``DataBatch``; raise ``StopIteration`` after ``max_iter``."""
        self.cur_iter += 1
        # Check the bound before generating, so no batch is built and discarded.
        if self.cur_iter > self.max_iter:
            raise StopIteration
        self._pro_data()
        return mx.io.DataBatch(data=[self.data],
                               label=None,
                               pad=0,
                               index=None,
                               provide_data=self.provide_data,
                               provide_label=None)

    def __next__(self):
        return self.next()

    def reset(self):
        """Restart the batch counter so the iterator can be consumed again."""
        self.cur_iter = 0