首先简单说明一下用到的函数,自定义函数“_parse_function”可以看做一个解析函数,将包含路径的文件名称映射到对应的图片数据上,返回图片数据、图片名称和对应的类别。
“tf.data.Dataset.map()”其用法如下,这个映射将map_func应用于此数据集的每个元素,并返回一个包含已转换元素的新数据集,其顺序与输入中出现的顺序相同。
map(
map_func,
num_parallel_calls=None
)
- map_func:一个函数,作用是将一种数据结构映射到另一种数据结构
- num_parallel_calls:一个tf.int32标量的tf.Tensor,表示要异步并行处理的元素个数。 如果未指定,则将按顺序处理元素。 如果使用值tf.data.experimental.AUTOTUNE,则根据可用CPU动态设置并行调用的数量。
简单的例子是
a = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
b = a.map(lambda x: x + 1) # ==> [ 2, 3, 4, 5, 6 ]
但是数据集a和b可以拥有不同的数据结构,比如a是由字符串(如包含图片路径的图片名称)Tensor构成的数据集,而b是[height,width,channels]这样可以表示一张图片的数据构成的数据集。通过这样的方法就可以将路径直接映射成图片。
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
def _parse_function(filename,label):
# map Tensor constructed with filenames to Tensor made of images
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
image = tf.cast(image_decoded, tf.float32)
image = tf.image.resize_images(image, [224, 224])
return image, filename, label
Dir = "C:\\Users\\Fj\\Desktop\\amazon\\images"
filenames =[]
label = []
for index,class_name in enumerate(os.listdir(Dir)):
class_path = Dir+"\\"+class_name+"\\"
for img_name in os.listdir(class_path):
img_path = class_path+img_name
filenames.append(img_path)
label.append(index)
batch_size = 32
# 文件名和label一定都要是Tensor!!!
images = tf.constant(filenames)
labels = tf.constant(label)
seed = 0 # seed相同,shuffle后对应关系不变
images = tf.random_shuffle(images, seed=seed)
labels = tf.random_shuffle(labels, seed=seed)
# 只需要将文件名和标签构成一个数据集
data = tf.data.Dataset.from_tensor_slices((images, labels))
# 关键点data.map将一个Tensor structure 映射成另一个Tensor structure!!!
data = data.map(_parse_function, num_parallel_calls=4)
data = data.prefetch(buffer_size=batch_size * 10) # 缓存区最大的样本个数
batched_data = data.batch(batch_size)
iterator = tf.data.Iterator.from_structure(batched_data.output_types,
batched_data.output_shapes)
init_op = iterator.make_initializer(batched_data)
images, filenames, labels = iterator.get_next()
with tf.Session() as sess:
sess.run(init_op)
for i in range(1):
try:
batch_image,batch_label = sess.run([images,labels])
print('batch_image\'shape is: {}'.format(batch_image.shape))
print('batch{}\' labels are -> {}'.format(i, sess.run(labels)))
except tf.errors.OutOfRangeError:
sess.run(init_op)
这种方法我感觉就很可以了,精华在于“data.map”将一个由Tensor构成的Dataset映射到另一个由Tensor构成的Dataset。这个映射将由文件名构成的Dataset映射到了由图片数据构成的Tensor上,这样我们在建立数据集的时候只需要建立一个文件名构成的数据集就可以了,其结果为:
该方法的优点:程序及其简单,没有什么坑,可以用于读取包含大量数据的数据集,因为其保存的不过是文件路径和类别标签,同时因为可以设置num_parallel_calls,读取速度也会相对较快。
好像TFRocrd还没有弄,先占坑