【MXNet官方教程5】Iterators-加载数据

在这篇教程里,我们关注将数据放入训练或预测模型。大部分MXNet的训练和预测模型支持数据迭代器(Iterators),它简化了数据加载过程,尤其是读取大量数据的时候。这里我们介绍一下API规范和几个定义好的迭代器。

先决条件

我们需要:

  • MXNet
  • OpenCV Python library, Python Requests, Matplotlib 和 Jupyter Notebook.
$ pip install opencv-python requests matplotlib jupyter
  • 设置MXNET_HOME环境变量为MXNet源代码目录
$ git clone https://github.com/dmlc/mxnet ~/mxnet
$ export MXNET_HOME='~/mxnet'

MXNet数据迭代器

MXNet里的数据迭代器和Python对象迭代器差不多。在Python里,iter方法通过调用Python迭代器对象(比如list)的next()方法来顺序读取元素。迭代器给各种可迭代对象提供了抽象的接口,从而不必要暴露底层数据结构。

在MXNet里,每次调用数据迭代器的next方法会返回一批数据DataBatch。一个DataBatch通常包含n个训练样本和对应的标签,这个n称为迭代器的batch_size。当数据流末尾没有数据可读时,迭代器和Python iter一样抛出StopIteration异常。DataBatch的结构定义在这里

样本和标签的信息,比如name,shape,type和layout,由数据描述对象DataDesc提供,DataDesc对象可以由DataBatchprovide_dataprovide_label属性得到。DataDesc的结构定义在这里

MXNet的所有IO都集中在mx.io.DataIter和其子类里。在这篇教程里,我们将讨论MXNet提供的几个常用的数据迭代器。

在此之前,我们先导入一些必要的包

import mxnet as mx
%matplotlib inline
import os
import sys
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

从内存中读取数据

当数据已加载到内存后,不管是NDArray还是numpy的ndarray,我们可以用NDArrayIter读取:

import numpy as np
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
    print([batch.data, batch.label, batch.pad])
[[
[[ 0.6530531   0.46522644  0.51619464]
 [ 0.16499737  0.76653564  0.03481427]
 [ 0.74804795  0.52990937  0.26833427]
 [ 0.66638368  0.3022213   0.06241459]
 [ 0.91972476  0.55546296  0.66393465]
 [ 0.08645096  0.38091755  0.56270498]
 [ 0.64666849  0.21491572  0.35175693]
 [ 0.86468941  0.47517869  0.05744886]
 [ 0.67380011  0.05342721  0.4926973 ]
 [ 0.19055551  0.69065297  0.67532688]
 [ 0.02628353  0.6844967   0.19746889]
 [ 0.66295046  0.09987242  0.58984101]
 [ 0.9458763   0.81699216  0.1675925 ]
 [ 0.46482503  0.21350947  0.06471642]
 [ 0.89144808  0.94313467  0.68858165]
 [ 0.81668401  0.7621479   0.27384126]
 [ 0.63461167  0.65230727  0.97777712]
 [ 0.22063005  0.24458201  0.10742629]
 [ 0.92816764  0.13466544  0.04408605]
 [ 0.30227584  0.14775786  0.87613076]
 [ 0.63119137  0.81612813  0.92117757]
 [ 0.94886911  0.43556175  0.46657735]
 [ 0.27003673  0.76513189  0.23725513]
 [ 0.2746343   0.47627121  0.54125744]
 [ 0.25552508  0.01837774  0.39958724]
 [ 0.83126527  0.03519598  0.34842646]
 [ 0.2845566   0.64368129  0.46485582]
 [ 0.08218119  0.41793332  0.51502693]
 [ 0.09406272  0.91428256  0.31059062]
 [ 0.13820368  0.17033553  0.28657481]]
<NDArray 30x3 @cpu(0)>], [
[ 7.  0.  0.  2.  0.  7.  0.  5.  8.  6.  8.  2.  1.  4.  5.  1.  7.  0.
  9.  3.  6.  0.  4.  0.  5.  5.  4.  4.  1.  0.]
<NDArray 30 @cpu(0)>], 0]
[[
[[ 0.13700683  0.52412778  0.8671298 ]
 [ 0.23877266  0.33903974  0.27537507]
 [ 0.06243272  0.19531037  0.5281179 ]
 [ 0.82232028  0.56463391  0.17138779]
 [ 0.5569911   0.20909874  0.52542228]
 [ 0.60343611  0.23063758  0.81608468]
 [ 0.70935023  0.78453153  0.78045279]
 [ 0.11286663  0.79524058  0.09895906]
 [ 0.77800322  0.16236411  0.87968087]
 [ 0.01127686  0.69375426  0.08547841]
 [ 0.82750279  0.0399946   0.60687792]
 [ 0.53598893  0.78744203  0.96958113]
 [ 0.82449615  0.11746258  0.29264763]
 [ 0.4362683   0.64713514  0.10649233]
 [ 0.14894192  0.90637457  0.13595931]
 [ 0.08151129  0.23844923  0.09844355]
 [ 0.33128792  0.7256636   0.84794742]
 [ 0.85739374  0.19100513  0.895199  ]
 [ 0.21185175  0.80707216  0.97806442]
 [ 0.64928633  0.63623446  0.12098809]
 [ 0.39623061  0.78550547  0.52882141]
 [ 0.13212556  0.1327759   0.27480963]
 [ 0.9550342   0.47325855  0.08431709]
 [ 0.7431556   0.03889066  0.39910018]
 [ 0.52704382  0.44965392  0.76548541]
 [ 0.69834208  0.30493379  0.17661361]
 [ 0.60261613  0.54987943  0.25630507]
 [ 0.54871225  0.27588007  0.58100933]
 [ 0.44057187  0.44038841  0.16040143]
 [ 0.90090007  0.6116063   0.17815863]]
<NDArray 30x3 @cpu(0)>], [
[ 3.  2.  1.  6.  7.  0.  9.  0.  3.  2.  7.  9.  3.  9.  9.  7.  0.  0.
  8.  3.  1.  7.  0.  7.  2.  5.  7.  7.  2.  2.]
<NDArray 30 @cpu(0)>], 0]
[[
[[ 0.38992238  0.38013816  0.08797289]
 [ 0.52365917  0.90862477  0.08777413]
 [ 0.2257922   0.86751235  0.90748298]
 [ 0.33275157  0.43818691  0.6017009 ]
 [ 0.60666031  0.2776112   0.05172342]
 [ 0.6747095   0.57716769  0.7660436 ]
 [ 0.17980796  0.22112246  0.11028604]
 [ 0.01789156  0.70913017  0.19826613]
 [ 0.5153966   0.89273322  0.36250469]
 [ 0.28154254  0.76787293  0.96268386]
 [ 0.41763589  0.36206847  0.44974893]
 [ 0.15868166  0.08014139  0.8982119 ]
 [ 0.47615659  0.41381609  0.7943607 ]
 [ 0.17594658  0.37957036  0.61883914]
 [ 0.67709279  0.857777    0.1074606 ]
 [ 0.69237596  0.84633547  0.50271791]
 [ 0.16945797  0.68244201  0.19087805]
 [ 0.1299092   0.49758923  0.37441683]
 [ 0.19249067  0.42367426  0.04579045]
 [ 0.74409103  0.40689459  0.38356331]
 [ 0.00723282  0.24958441  0.46890974]
 [ 0.61725456  0.08689262  0.64768285]
 [ 0.0569613   0.89262682  0.67263258]
 [ 0.17189403  0.5448252   0.02259486]
 [ 0.49834073  0.36860585  0.8104018 ]
 [ 0.63564551  0.62717378  0.90756214]
 [ 0.3532913   0.82186127  0.07672632]
 [ 0.72964108  0.71190619  0.22283019]
 [ 0.77529597  0.09597207  0.45330995]
 [ 0.10836289  0.07343143  0.02535379]]
<NDArray 30x3 @cpu(0)>], [
[ 7.  3.  8.  3.  1.  9.  8.  9.  9.  1.  3.  4.  7.  6.  7.  5.  4.  7.
  2.  3.  3.  3.  7.  2.  5.  2.  7.  2.  6.  2.]
<NDArray 30 @cpu(0)>], 0]
[[
[[ 0.79986835  0.65680069  0.76467651]
 [ 0.81243736  0.09769222  0.27826148]
 [ 0.42728153  0.97143823  0.02860877]
 [ 0.01750882  0.9944576   0.80612904]
 [ 0.47085366  0.35999826  0.48983538]
 [ 0.24489172  0.37438354  0.81461328]
 [ 0.72018224  0.4823792   0.02590245]
 [ 0.97203141  0.78433287  0.1679011 ]
 [ 0.86534786  0.07474694  0.96176213]
 [ 0.04950644  0.82327616  0.80272979]
 [ 0.6530531   0.46522644  0.51619464]
 [ 0.16499737  0.76653564  0.03481427]
 [ 0.74804795  0.52990937  0.26833427]
 [ 0.66638368  0.3022213   0.06241459]
 [ 0.91972476  0.55546296  0.66393465]
 [ 0.08645096  0.38091755  0.56270498]
 [ 0.64666849  0.21491572  0.35175693]
 [ 0.86468941  0.47517869  0.05744886]
 [ 0.67380011  0.05342721  0.4926973 ]
 [ 0.19055551  0.69065297  0.67532688]
 [ 0.02628353  0.6844967   0.19746889]
 [ 0.66295046  0.09987242  0.58984101]
 [ 0.9458763   0.81699216  0.1675925 ]
 [ 0.46482503  0.21350947  0.06471642]
 [ 0.89144808  0.94313467  0.68858165]
 [ 0.81668401  0.7621479   0.27384126]
 [ 0.63461167  0.65230727  0.97777712]
 [ 0.22063005  0.24458201  0.10742629]
 [ 0.92816764  0.13466544  0.04408605]
 [ 0.30227584  0.14775786  0.87613076]]
<NDArray 30x3 @cpu(0)>], [
[ 6.  4.  2.  5.  4.  5.  2.  3.  6.  1.  7.  0.  0.  2.  0.  7.  0.  5.
  8.  6.  8.  2.  1.  4.  5.  1.  7.  0.  9.  3.]
<NDArray 30 @cpu(0)>], 20]

从CSV文件读取数据

MXNet提供CSVIter读取CSV文件:

#lets save `data` into a csv file first and try reading it back
np.savetxt('data.csv', data, delimiter=',')
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
for batch in data_iter:
    print([batch.data, batch.pad])

自定义迭代器

当内置的迭代器不能满足应用需求时,可以创建自己的迭代器。

一个MXNet的迭代器应该满足:

  1. 实现Python2的next()或者Python3的__next()__,返回一个DataBatch或者在数据流末尾抛出StopIteration
  2. 实现reset()方法,从头开始重新读数据。
  3. 提供一个provide_data属性,包含一个DataDesc列表,每个DataDesc包含数据的name, shape, type 和 layout information信息(详见这里)。
  4. 提供一个provide_label属性,包含一个DataDesc列表,每个DataDesc包含标签的name, shape, type 和 layout information信息。

创建新迭代器时,你可以从头开始创建或者利用已经存在的迭代器。例如,在图片字幕应用里,数据样本是图片而标签是文本。所以我们可以这样创建新迭代器:

  • 用提供多线程数据预处理和增强的ImageRecordIter创建一个image_iter
  • NDArrayIter或者rnn包里的bucketing迭代器创建一个caption_iter
  • next()返回image_iter.next()caption_iter.next()的组合。

下面的代码展示了如何创建一个简单的迭代器:

class SimpleIter(mx.io.DataIter):
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
            label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
            return mx.io.DataBatch(data, label)
        else:
            raise StopIteration

我们可以用上面的SimpleIter训练一个简单的MLP模型:

import mxnet as mx
num_classes = 10
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
print(net.list_arguments())
print(net.list_outputs())
['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'softmax_label']
['softmax_output']

这里,有四个需要学习的参数:全连接层fc1和fc2的weights和biases。有两个输入变量:训练样本data和标签softmax_label,最后还有一个输出softmax_output。

data是MXNet的Symbol API提供的变量,为了执行Symbol,它们需要先绑定数据。详见【MXNet官方教程3】Symbol -神经网络图和自动区分

我们通过MXNet的module API把迭代器的数据送入神经网络。详见【MXNet官方教程4】Module - 神经网络训练和预测

import logging
logging.basicConfig(level=logging.INFO)

n = 32
data_iter = SimpleIter(['data'], [(n, 100)],
                  [lambda s: np.random.uniform(-1, 1, s)],
                  ['softmax_label'], [(n,)],
                  [lambda s: np.random.randint(0, num_classes, s)])

mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)
INFO:root:Epoch[0] Train-accuracy=0.081250
INFO:root:Epoch[0] Time cost=0.006
INFO:root:Epoch[1] Train-accuracy=0.125000
INFO:root:Epoch[1] Time cost=0.005
INFO:root:Epoch[2] Train-accuracy=0.121875
INFO:root:Epoch[2] Time cost=0.006
INFO:root:Epoch[3] Train-accuracy=0.084375
INFO:root:Epoch[3] Time cost=0.004
INFO:root:Epoch[4] Train-accuracy=0.109375
INFO:root:Epoch[4] Time cost=0.005

使用Python3的注意事项:mxnet的许多方法用python2的字符串和python3的字节。为了保持教程的可读性,我们定义一个工具方法,将字符串转为python3的字节。

def str_or_bytes(str):
    """
    A utility function for this tutorial that helps us convert string 
    to bytes if we are using python3.

    Parameters
    ----------
    str : string

    Returns
    -------
    string (python2) or bytes (python3)
    """
    if sys.version_info[0] < 3:
        return str
    else:
        return bytes(str, 'utf-8')

Record IO

Record IO是MXNet 数据IO的一种文件格式。其对数据简洁的封装用于分布式文件系统比如 Hadoop HDFS 和 AWS S3的高效读写。更多内容参见这里

MXNet提供了MXRecordIOMXIndexedRecordIO用于顺序读取和随机读取数据。

MXRecordIO

首先,我们来看一下怎么用MXRecordIO顺序读取数据。文件以.rec为后缀。

record = mx.recordio.MXRecordIO('tmp.rec', 'w')
for i in range(5):
    record.write(str_or_bytes('record_%d'%i))

record.close()

我们可以用参数r读取文件:

record = mx.recordio.MXRecordIO('tmp.rec', 'r')
while True:
    item = record.read()
    if not item:
        break
    print (item)
record.close()
b'record_0'
b'record_1'
b'record_2'
b'record_3'
b'record_4'

MXIndexedRecordIO

MXIndexedRecordIO支持随机或者按索引读取数据。我们创建一个按索引记录文件和对应的索引文件:

record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
for i in range(5):
    record.write_idx(i, str_or_bytes('record_%d'%i))

record.close()

现在,我们可以用索引键读取对应的记录

record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
record.read_idx(3)
b'record_3'

你也可以把所有的索引键列出来:

record.keys
[0, 1, 2, 3, 4]

封装和解封装数据

.rec文件的每条记录都包含任意的二进制数据。然而,大多数深度学习任务需要数据以标签/样本的格式输入。mx.recordio包提供了一些工具方法,比如pack, unpack, pack_img, 和unpack_img

封装/解封装二进制数据

packunpack用于浮点数(或者一维浮点数向量)标签和二进制数据。数据和一个header包装在一起,header的结构定义在这里

# pack
data = 'data'
label1 = 1.0
header1 = mx.recordio.IRHeader(flag=0, label=label1, id=1, id2=0)
s1 = mx.recordio.pack(header1, str_or_bytes(data))

label2 = [1.0, 2.0, 3.0]
header2 = mx.recordio.IRHeader(flag=3, label=label2, id=2, id2=0)
s2 = mx.recordio.pack(header2, str_or_bytes(data))
# unpack
print(mx.recordio.unpack(s1))
print(mx.recordio.unpack(s2))
(HEADER(flag=0, label=1.0, id=1, id2=0), b'data')
(HEADER(flag=3, label=array([ 1., 2., 3.], dtype=float32), id=2, id2=0), b'data')

封装/解封装图片

MXNet提供pack_imgunpack_img来封装/解封装图片数据。pack_img封装的记录可以直接由mx.io.ImageRecordIter加载。

data = np.ones((3,3,1), dtype=np.uint8)
label = 1.0
header = mx.recordio.IRHeader(flag=0, label=label, id=0, id2=0)
s = mx.recordio.pack_img(header, data, quality=100, img_fmt='.jpg')
# unpack_img
print(mx.recordio.unpack_img(s))
(HEADER(flag=0, label=1.0, id=0, id2=0), array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]], dtype=uint8))

使用tools/im2rec.py

你可以用MXNet src/tools文件夹下的im2rec.py工具脚本直接将源图片转换为RecordIO格式。在下面的图片IO部分有一个使用此脚本的例子。

图片IO

这部分,我们将学习怎么预处理和加载图片。

有4种加载图片的方式:

  1. 使用mx.image.imdecode加载原图。
  2. 使用python实现的且易于定制的mx.img.ImageIter。它可以读取.rec文件和原图。
  3. 使用C++实现的mx.io.ImageRecordIter。它不那么容易定制,但是可以绑定各种语言使用。
  4. 自定义迭代器继承mx.io.DataIter

预处理图片

图片预处理有多种方式:

  • 使用mx.io.ImageRecordIter,很快但不那么灵活。适用与简单的任务比如图片识别,但在复杂的任务比如检测和分割行不通。
  • 使用mx.recordio.unpack_img (或者 cv2.imread, skimage等) + numpy,很灵活,但是比较慢(因为python全局解释器锁GIL)。
  • 使用MXNet提供的mx.image包。它将图片存储为NDArray并且启动MXNet的dependency engine自动并行处理,规避GIL。

下面,我们演示mx.image包里几种常见的预处理例子。

下载我们需要处理的图片

fname = mx.test_utils.download(url='http://data.mxnet.io/data/test_images.tar.gz', dirname='data', overwrite=False)
tar = tarfile.open(fname)
tar.extractall(path='./data')
tar.close()

加载原图

mx.image.imdecode加载图片,imdecode提供了类似于OpenCv的接口。
注意:为了使用mx.image.imdecode,你仍然需要安装OpenCV,而不是cv2 python库。

img = mx.image.imdecode(open('data/test_images/ILSVRC2012_val_00000001.JPEG', 'rb').read())
plt.imshow(img.asnumpy()); plt.show()


图片变换

# resize to w x h
tmp = mx.image.imresize(img, 100, 70)
plt.imshow(tmp.asnumpy()); plt.show()

# crop a random w x h region from image
tmp, coord = mx.image.random_crop(img, (150, 200))
print(coord)
plt.imshow(tmp.asnumpy()); plt.show()

使用图片迭代器加载数据

在使用两种内置的图片加载器之前,先获取一个包含101类对象的Caltech 101数据集并且转为RecordIO格式。下载并解压:

fname = mx.test_utils.download(url='http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz', dirname='data', overwrite=False)
tar = tarfile.open(fname)
tar.extractall(path='./data')
tar.close()

我们先看一眼数据,在根目录下(./data/101_ObjectCategories),每一个分类都有一个子文件夹(./data/101_ObjectCategories/yin_yang)。

现在我们使用im2rec.py脚本将图片转为RecordIO格式。首先,我们需要一个包含所有图片文件和分类的列表。

os.system('python %s/tools/im2rec.py --list=1 --recursive=1 --shuffle=1 --test-ratio=0.2 data/caltech data/101_ObjectCategories'%os.environ['MXNET_HOME'])

得到的列表文件(./data/caltech_train.lst)以index\t(one or more label)\tpath的格式。在这个例子里,每一个图片只有一个标签,但是你可以修改列表用于多标签训练(参见MXNet im2rec.py使用教程)。

7167    69.000000   okapi/image_0017.jpg
6153    52.000000   ibis/image_0073.jpg
7761    81.000000   scissors/image_0005.jpg
7792    81.000000   scissors/image_0036.jpg
1326    2.000000    Faces_easy/image_0425.jpg
...

然后我们可以使用这个列表来创建RecordIO文件。

os.system("python %s/tools/im2rec.py --num-thread=4 --pass-through=1 data/caltech data/101_ObjectCategories"%os.environ['MXNET_HOME'])

record io文件保存在这里(./data)。

使用ImageRecordIter

ImageRecordIter加载RecordIO格式的图片数据,只需要简单地创建一个加载实例:

data_iter = mx.io.ImageRecordIter(
    path_imgrec="./data/caltech.rec", # the target record file
    data_shape=(3, 227, 227), # output data shape. An 227x227 region will be cropped from the original image.
    batch_size=4, # number of samples per batch
    resize=256 # resize the shorter edge to 256 before cropping
    # ... you can add more augumentation options as defined in ImageRecordIter.
    )
data_iter.reset()
batch = data_iter.next()
data = batch.data[0]
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()

使用ImageIter

ImageIter是一个灵活的接口,支持从RecordIO和源文件加载图片。

data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
                              path_imgrec="./data/caltech.rec",
                              path_imgidx="./data/caltech.idx" )
data_iter.reset()
batch = data_iter.next()
data = batch.data[0]
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()

原文地址:Iterators - Loading data

猜你喜欢

转载自blog.csdn.net/xiang_freedom/article/details/79614571
今日推荐