tensorflow之数据读取 -- 用tf.data通过tfrecord读取数据或者直接读取数据

对于数据量很大的数据集, 直接读入内存可能会放不下, 建议的做法是把全部数据转换成tfrecord的格式, 方便神经网络读取数据, 并且从tfrecord中读取数据的话tensorflow专门做过优化, 能加快读取速度.

参考资料: 官方tfrecord读写教程

1. 生成tfrecord

方法1: 直接以二进制bytes读取图片, 然后放进tfrecord中, 但是这样对bytes没法做修改, 比如有时候label需要进行map, 这时候就要用方法2.

import tensorflow as tf

# 把一个byte数据转换成一个bytes_list
def _bytes_list_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 把一对features和label转换成一个tfexample
def image_seg_to_tfexample(image_data, seg_data):
  return tf.train.Example(features=tf.train.Features(feature={
      'image': _bytes_list_feature(image_data),
      'label': _bytes_list_feature(seg_data),
  }))

# 解析并读取image和label成二进制byte类型, image_data = open(image_filename, 'rb').read()有相同的效果
image_data = tf.gfile.GFile(image_filename, 'rb').read() # type(image_data)为bytes
seg_data = tf.gfile.GFile(seg_filename, 'rb').read() 
# image_data = tf.read_file(image_filename) 也行, type(image_data)也是bytes

with tf.python_io.TFRecordWriter(output_filename) as writer:
  example = image_seg_to_tfexample(image_data,seg_data)
  # 把tfexample写入tfrecord中
  writer.write(example.SerializeToString())

方法2: 不直接把图片读取成bytes, 而是转换成ndarray, 这样可以对ndarray进行修改, 再写入tfrecord中.

from PIL import Image
import numpy as np
import tensorflow as tf

  # 读取已经保存好的字典, 后面用于map
with open('/home/steven/deeplab_v3+_project/deeplab_v3+_tensorflow_from_rishizek/map_dictionary.pickle', 'rb') as f:
  map_dict = pickle.load(f)

# 读取image成ndarray,注意读取的时候dtype设置为np.uint8, 因为像素值在0-255之间
image_data = np.array(Image.open(image_filename)).astype(np.uint8)
# 将image从ndarray变成bytes, 方便写入tfrecord
image_data = image_data.tostring()
  
# 读取label成ndarray,先不转换np.uint8, 因为map可能改变dtype
seg_data = np.array(Image.open(seg_filename))
# 对ndarray做map
seg_data_mapped = np.vectorize(map_dict.get)(seg_data)
# 将seg_data_mapped从ndarray变成bytes, 方便写入tfrecord, 注意先把数据也转换成np.uint8再变成bytes
seg_data = seg_data_mapped.astype(np.uint8).tostring()
  
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  # image_seg_to_tfexample()函数见方法1
  example = image_seg_to_tfexample(image_data,seg_data)
  tfrecord_writer.write(example.SerializeToString())

2. 读取tfrecord: tf.data.TFRecordDataset()

# 返回一个list, 包含所有要输入的tfrecord文件
def get_filenames(is_training, data_dir):
    if is_training:
        return [os.path.join(data_dir, 'nonzeros_train-00000-of-00004.tfrecord'),
             os.path.join(data_dir, 'nonzeros_train-00001-of-00004.tfrecord'),
             os.path.join(data_dir, 'nonzeros_train-00002-of-00004.tfrecord'),
             os.path.join(data_dir, 'nonzeros_train-00003-of-00004.tfrecord')]
    else:
        return [os.path.join(data_dir, 'nonzeros_valid-00000-of-00004.tfrecord'),
             os.path.join(data_dir, 'nonzeros_valid-00001-of-00004.tfrecord'),
             os.path.join(data_dir, 'nonzeros_valid-00002-of-00004.tfrecord'),
             os.path.join(data_dir, 'nonzeros_valid-00003-of-00004.tfrecord')]

# 读取所有tfrecord文件得到dataset
dataset = tf.data.TFRecordDataset(get_filenames(is_training,data_dir))

# 解析dataset的函数, 直接把bytes转换回image, 对应方法1
def parse_record(raw_record):
	# 按什么格式写入的, 就要以同样的格式输出
	keys_to_features = {
      'image': tf.FixedLenFeature((), tf.string),
      'label': tf.FixedLenFeature((), tf.string),
    }
	# 按照keys_to_features解析二进制的
    parsed = tf.parse_single_example(raw_record, keys_to_features)
    
    image = tf.image.decode_image(tf.reshape(parsed['image'], shape=[]), 1)
    image = tf.to_float(tf.image.convert_image_dtype(image, dtype=tf.uint8))
    image.set_shape([None, None, 1])
    label = tf.image.decode_image(tf.reshape(parsed['label'], shape=[]), 1)
    label = tf.to_int32(tf.image.convert_image_dtype(label, dtype=tf.uint8))
    label.set_shape([None, None, 1])
    
    return image, label

# 直接把bytes类型的ndarray解析回来, 用decode_raw(),对应方法2
def parse_record(raw_record):
    keys_to_features = {
      'image': tf.FixedLenFeature((), tf.string),
      'label': tf.FixedLenFeature((), tf.string),
    }
    parsed = tf.parse_single_example(raw_record, keys_to_features)

    image = tf.decode_raw(parsed['image'], tf.uint8)
    image = tf.to_float(image)
    image = tf.reshape(image, [256,256,1])
    label = tf.decode_raw(parsed['label'], tf.uint8)
    label = tf.to_int32(label)
    label = tf.reshape(label, [256,256,1])

    return image, label

# 对dataset中的每条数据, 应用parse_record函数, 得到解析后的新的dataset
dataset = dataset.map(parse_record)
# 对dataset中的每条数据, 应用lambda函数, 输入image, label, 用preprocess_image()函数(省略没写)处理,得到新的dataset
dataset = dataset.map(lambda image, label: preprocess_image(image, label, is_training))
# dataset还可以做repeat(), shuffle(), batch()等处理
dataset = dataset.shuffle(buffer_size).repeat(num_epochs).batch(batch_size)
# 每次sess.run(images, labels)得到一个batch_size的images和labels
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()

3. 通过tf.data.Dataset直接读取数据

def eval_input_fn(image_filenames, label_filenames=None, batch_size=1):
  # Reads an image from a file, decodes it into a dense tensor
  def _parse_function(filename, is_label):
    if not is_label:
      image_filename, label_filename = filename, None
    else:
      image_filename, label_filename = filename
    
    # 这里与上面1和2中读取与解析图片的过程类似, 区别在于tf.readfile()得到的bytes文件没有放入tfrecord中, 而是通过tf.image.decode_image()直接解析成tensor
    image_string = tf.read_file(image_filename)
    image = tf.image.decode_image(image_string)
    image = tf.to_float(tf.image.convert_image_dtype(image, dtype=tf.uint8))
    image.set_shape([None, None, 3])

    if not is_label:
      return image
    else:
      # 读取与解析label
      label_string = tf.read_file(label_filename)
      label = tf.image.decode_image(label_string)
      label = tf.to_int32(tf.image.convert_image_dtype(label, dtype=tf.uint8))
      label.set_shape([None, None, 1])

      return image, label
      
  if label_filenames is None:
    input_filenames = image_filenames
  else:
    input_filenames = (image_filenames, label_filenames)
  
  # input_filenames是一个文件名组成的list或者一个由两个list组成的元组, 这里通过tf.data.Dataset.from_tensor_slices()直接获得文件名组成的dataset
  dataset = tf.data.Dataset.from_tensor_slices(input_filenames)
  # 通过map函数, 解析dataset中的文件名形成一个新的dataset
  if label_filenames is None:
    dataset = dataset.map(lambda x: _parse_function(x, False))
  else:
    dataset = dataset.map(lambda x, y: _parse_function((x, y), True))
  dataset = dataset.prefetch(batch_size)
  dataset = dataset.batch(batch_size)
  iterator = dataset.make_one_shot_iterator()

  if label_filenames is None:
    images = iterator.get_next()
    labels = None
  else:
    images, labels = iterator.get_next()

  return images, labels

猜你喜欢

转载自blog.csdn.net/weixin_42561002/article/details/88100573