reader: 适用于原始数据数据形式的Tensorflow Reader
在库中parallel_reader.py是与reader相关的,它使用多个reader并行处理来提高速度,但其中定义的类是继承自基类,所以我们先看基类的功能。
class ParallelReader(io_ops.ReaderBase):
基类
基类是各种不同类型reader的基类,它将'work unit'转换为record,比较典型的’work unit'是文件名,records(键值对形式)就是从这些文件中提取的内容。我们想要每一步只产生一个record,但是显然一个'work unit'可能对应多个record,所以我们需要一个存储'work unit'的queue,只有reader读取完一个‘work unit'中所有record之后,它才能获取下一个'work unit'。
基类首先使用了一个装饰器,装饰器的主要功能是输出装饰的函数到TF的api,装饰的函数功能不受影响,所以我们可以不用关注它。
@tf_export("ReaderBase") class ReaderBase(object):
基类初始化函数检查是否支持eager execution,初始化类变量。
def __init__(self, reader_ref, supports_serialize=False): if context.executing_eagerly(): raise RuntimeError( "Readers are not supported when eager execution is enabled. " "Instead, please use tf.data to get data into your model.") self._reader_ref = reader_ref #实现reader的op self._supports_serialize = supports_serialize #reader是否支持序列化状态
read函数会返回一个record。
def read(self, queue, name=None): if isinstance(queue, ops.Tensor): queue_ref = queue else: queue_ref = queue.queue_ref if self._reader_ref.dtype == dtypes.resource: return gen_io_ops.reader_read_v2(self._reader_ref, queue_ref, name=name) else: # For compatibility with pre-resource queues, create a ref(string) tensor # which can be looked up as the same queue by a resource manager. old_queue_op = gen_data_flow_ops.fake_queue(queue_ref) return gen_io_ops.reader_read(self._reader_ref, old_queue_op, name=name)
read_up_to函数返回指定数量的records。
def read_up_to(self, queue, num_records, name=None): if isinstance(queue, ops.Tensor): queue_ref = queue else: queue_ref = queue.queue_ref if self._reader_ref.dtype == dtypes.resource: return gen_io_ops.reader_read_up_to_v2(self._reader_ref, queue_ref, num_records, name=name) else: # For compatibility with pre-resource queues, create a ref(string) tensor # which can be looked up as the same queue by a resource manager. old_queue_op = gen_data_flow_ops.fake_queue(queue_ref) return gen_io_ops.reader_read_up_to(self._reader_ref, old_queue_op, num_records, name=name)
num_records_produced函数返回已经产生的record数量。
def num_records_produced(self, name=None): if self._reader_ref.dtype == dtypes.resource: return gen_io_ops.reader_num_records_produced_v2(self._reader_ref, name=name) else: return gen_io_ops.reader_num_records_produced(self._reader_ref, name=name)
num_work_units_completed函数返回已经完成的work units数量。
def num_work_units_completed(self, name=None): if self._reader_ref.dtype == dtypes.resource: return gen_io_ops.reader_num_work_units_completed_v2(self._reader_ref, name=name) else: return gen_io_ops.reader_num_work_units_completed(self._reader_ref, name=name)
serialize_state函数产生编码reader状态的一个string tensor。
def serialize_state(self, name=None): if self._reader_ref.dtype == dtypes.resource: return gen_io_ops.reader_serialize_state_v2(self._reader_ref, name=name) else: return gen_io_ops.reader_serialize_state(self._reader_ref, name=name)
restore_state函数将reader状态加载为指定状态。
def restore_state(self, state, name=None): if self._reader_ref.dtype == dtypes.resource: return gen_io_ops.reader_restore_state_v2( self._reader_ref, state, name=name) else: return gen_io_ops.reader_restore_state(self._reader_ref, state, name=name)
reset_state函数将reader状态重置。
def reset(self, name=None): if self._reader_ref.dtype == dtypes.resource: return gen_io_ops.reader_reset_v2(self._reader_ref, name=name) else: return gen_io_ops.reader_reset(self._reader_ref, name=name)
还有supports_serialize函数和reader_ref函数则是返回相应类变量,至此基类中所有的函数都已经介绍完毕,可以看到,主要实现reader功能的是reader函数,其他函数则是查看reader状态。(由于gen_io_ops为c++代码实现,这里不予介绍。)
Parallel_reader
Parallel_reader是并行使用多个reader提高速度,它的构造函数利用参数构造多个reader,每个reader并行地读取不同的文件,异步地将item入队到队列(common_queue)中。对列中dtypes必须是[tf.string, tf.string]。
def __init__(self, reader_class, common_queue, num_readers=4, reader_kwargs=None): if len(common_queue.dtypes) != 2: raise TypeError('common_queue.dtypes must be [tf.string, tf.string]') for dtype in common_queue.dtypes: if not dtype.is_compatible_with(tf_dtypes.string): raise TypeError('common_queue.dtypes must be [tf.string, tf.string]') reader_kwargs = reader_kwargs or {} self._readers = [reader_class(**reader_kwargs) for _ in range(num_readers)] self._common_queue = common_queue
read函数用QueueRunner执行入队操作后,再从队列中返回一个record。
def read(self, queue, name=None): self._configure_readers_by(queue) return self._common_queue.dequeue(name=name)
read_up_to函数是返回多个record。
def read_up_to(self, queue, num_records, name=None): self._configure_readers_by(queue) return self._common_queue.dequeue_up_to(num_records, name) def _configure_readers_by(self, queue): enqueue_ops = [] for reader in self._readers: enqueue_ops.append(self._common_queue.enqueue(reader.read(queue))) queue_runner.add_queue_runner( queue_runner.QueueRunner(self._common_queue, enqueue_ops))
num_records_produced函数和num_work_units_completed函数,这两个函数功能和基类中同名函数一样。
def num_records_produced(self, name=None): num_records = [r.num_records_produced() for r in self._readers] return math_ops.add_n(num_records, name=name) def num_work_units_completed(self, name=None): num_work_units = [r.num_work_units_completed() for r in self._readers] return math_ops.add_n(num_work_units, name=name)
num_readers函数和common_queue函数则是返回相关属性,至此Parallel_reader函数中所有函数就全部介绍完了,可以看出它的主要作用只是对多个reader的工作使用多线程处理。
最新一次编辑在:21:36:28 2018-07-16