tensorflow queue

Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/u013608336/article/details/82466772

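Explanation from the official TensorFlow documentation (link):
A queue is implemented as a node in the computation graph. Like a variable, it is a stateful node.
enqueue: push elements into the queue, reading data from disk ahead of time
dequeue: pop elements from the queue, taking out a mini-batch and feeding it to the compute nodes
There are two main steps: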
1.Multiple threads prepare training examples and enqueue them.
2.A training thread executes a training op that dequeues mini-batches from the queue
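The two steps above can be mimicked without TensorFlow. The sketch below is only an analogy built on Python's standard `queue` and `threading` modules, not the TensorFlow API: four producer threads play the enqueue threads, and the main thread plays the training thread that dequeues mini-batches. The names `produce`, `dequeue_many`, `NUM_THREADS`, `ITEMS_PER_THREAD`, `BATCH_SIZE` and `CAPACITY` are hypothetical.

```python
import queue
import threading

NUM_THREADS = 4        # hypothetical: number of enqueue (producer) threads
ITEMS_PER_THREAD = 25  # hypothetical: examples prepared by each thread
BATCH_SIZE = 10
CAPACITY = 20          # bounded buffer, like the capacity of a TF queue

examples = queue.Queue(maxsize=CAPACITY)


def produce(thread_id):
    """Step 1: prepare training examples and enqueue them.
    put() blocks while the buffer is full, so producers cannot run far ahead."""
    for i in range(ITEMS_PER_THREAD):
        examples.put(thread_id * ITEMS_PER_THREAD + i)


def dequeue_many(n):
    """Step 2: the training thread takes n examples off the queue as one mini-batch."""
    return [examples.get() for _ in range(n)]


threads = [threading.Thread(target=produce, args=(tid,), daemon=True)
           for tid in range(NUM_THREADS)]
for t in threads:
    t.start()

total = NUM_THREADS * ITEMS_PER_THREAD
batches = [dequeue_many(BATCH_SIZE) for _ in range(total // BATCH_SIZE)]
for t in threads:
    t.join()
```

Because the buffer is bounded, the producers only stay a little ahead of the consumer, which is the point of reading data "ahead of time" without loading everything into memory.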

import tensorflow as tf  # TensorFlow 1.x graph-mode API

def simple_shuffle_batch(source, capacity, batch_size=10):
  # Create a random shuffle queue.
  queue = tf.RandomShuffleQueue(capacity=capacity,
                                min_after_dequeue=int(0.9*capacity),
                                shapes=source.shape, dtypes=source.dtype)

  # Create an op to enqueue one item.
  enqueue = queue.enqueue(source)

  # Create a queue runner that, when started, will launch 4 threads applying
  # that enqueue op.
  num_threads = 4
  qr = tf.train.QueueRunner(queue, [enqueue] * num_threads)

  # Register the queue runner so it can be found and started by
  # tf.train.start_queue_runners later (the threads are not launched yet).
  tf.train.add_queue_runner(qr)

  # Create an op to dequeue a batch
  return queue.dequeue_many(batch_size)

Once started by tf.train.start_queue_runners, or indirectly through tf.train.MonitoredSession, the QueueRunner will launch the threads in the background to fill the queue.
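Shutdown matters for the `while not sess.should_stop()` loop in the example that follows. As I understand the TF 1.x docs, when the one-shot iterator runs out, `get_next()` raises `OutOfRangeError` in the enqueue threads, the `QueueRunner` closes the queue once all of its threads have stopped, the next `dequeue_many` that cannot be filled raises `OutOfRangeError`, and `MonitoredSession` suppresses that error when used as a context manager. The class below is a rough stdlib model of that lifecycle; `MiniQueueRunner`, its methods and the `_CLOSED` sentinel are hypothetical stand-ins, `StopIteration` plays the role of `OutOfRangeError` on the input side and `EOFError` plays it on the dequeue side.

```python
import queue
import threading


class MiniQueueRunner:
    """Hypothetical stand-in for tf.train.QueueRunner (not the TF API).

    `num_threads` threads repeatedly run the same "enqueue" step on a shared
    source. When the source is exhausted, each thread exits, and the last one
    to exit closes the queue.
    """

    _CLOSED = object()  # sentinel that marks the queue as closed

    def __init__(self, source, num_threads, capacity):
        self._source = iter(source)
        self._source_lock = threading.Lock()
        self._queue = queue.Queue(maxsize=capacity)
        self._running = num_threads
        self._running_lock = threading.Lock()
        self._threads = [threading.Thread(target=self._run, daemon=True)
                         for _ in range(num_threads)]

    def _run(self):
        while True:
            with self._source_lock:        # like get_next() on a shared iterator
                try:
                    item = next(self._source)
                except StopIteration:      # analogue of OutOfRangeError
                    break
            self._queue.put(item)          # the enqueue op
        with self._running_lock:
            self._running -= 1
            if self._running == 0:         # last thread out closes the queue
                self._queue.put(self._CLOSED)

    def start(self):
        """Launch the background threads (cf. tf.train.start_queue_runners)."""
        for t in self._threads:
            t.start()

    def dequeue_many(self, n):
        """Return n items; raise EOFError once the queue is closed and fewer
        than n items are left. In this model the leftovers are dropped."""
        batch = []
        while len(batch) < n:
            item = self._queue.get()
            if item is self._CLOSED:
                self._queue.put(self._CLOSED)  # stay closed for later calls
                raise EOFError("queue is closed")
            batch.append(item)
        return batch


runner = MiniQueueRunner(range(100), num_threads=4, capacity=20)
runner.start()
batches = []
try:
    while True:                            # cf. `while not sess.should_stop()`
        batches.append(runner.dequeue_many(10))
except EOFError:                           # cf. OutOfRangeError ending the loop
    pass
```

Every thread finishes its own `put` calls before decrementing the counter, so the closing sentinel is always enqueued after the last real item and no data is lost.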

# create a dataset that counts from 0 to 99
input = tf.constant(list(range(100)))
input = tf.data.Dataset.from_tensor_slices(input)
input = input.make_one_shot_iterator().get_next()

# Create a slightly shuffled batch from the sorted elements
get_batch = simple_shuffle_batch(input, capacity=20)

# `MonitoredSession` will start and manage the `QueueRunner` threads.
with tf.train.MonitoredSession() as sess:
  # Since the `QueueRunners` have been started, data is available in the
  # queue, so the `sess.run(get_batch)` call will not hang.
  while not sess.should_stop():
    print(sess.run(get_batch))
Output (the exact values vary between runs because the queue shuffles):
[ 8 10  7  5  4 13 15 14 25  0]
[23 29 28 31 33 18 19 11 34 27]
[12 21 37 39 35 22 44 36 20 46]
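Why are the batches only "slightly" shuffled? A `RandomShuffleQueue` hands out a random element of what it currently holds, and its `min_after_dequeue` argument makes a dequeue wait unless that many elements would remain (or the queue is closed). With `capacity=20` and `min_after_dequeue=18`, every draw picks from a small window of nearby numbers. The function below is a hypothetical single-threaded model of that policy, assuming the enqueue threads always keep the buffer full; `shuffle_batches` and its `seed` parameter are illustrative, not TensorFlow APIs.

```python
import random


def shuffle_batches(source, capacity, min_after_dequeue, batch_size, seed=0):
    """Hypothetical single-threaded model of RandomShuffleQueue + dequeue_many.

    Assumes the enqueue threads always keep the buffer topped up to `capacity`.
    Each element of a batch is drawn uniformly at random from the buffer.
    """
    rng = random.Random(seed)
    it = iter(source)
    buffer, batch, batches = [], [], []
    closed = False
    while True:
        # Enqueue side: refill up to `capacity`; exhausting the source closes the queue.
        while not closed and len(buffer) < capacity:
            try:
                buffer.append(next(it))
            except StopIteration:
                closed = True
        # Dequeue side: a draw needs more than `min_after_dequeue` buffered items
        # unless the queue is closed (a real queue would block here; the model stops).
        if not buffer or (len(buffer) <= min_after_dequeue and not closed):
            break
        i = rng.randrange(len(buffer))
        buffer[i], buffer[-1] = buffer[-1], buffer[i]  # O(1) removal of a random item
        batch.append(buffer.pop())
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
    return batches  # a trailing partial batch is dropped


batches = shuffle_batches(range(100), capacity=20, min_after_dequeue=18, batch_size=10)
for b in batches[:3]:
    print(b)
```

Under the stay-full assumption, the k-th draw (counting from 0) can only see values up to 19 + k, so the first batch of 10 contains nothing above 28. That is consistent with the first printed batch above, whose largest value is 25.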

prefetch_queue

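This function packages the same steps as simple_shuffle_batch: it creates a queue, creates a queue runner that enqueues into it, and registers that runner in the TF QueueRunners collection. Unlike simple_shuffle_batch, it returns the queue itself rather than a dequeue op.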

def prefetch_queue(tensors,
                   capacity=8,
                   num_threads=1,
                   dynamic_pad=False,
                   shared_name=None,
                   name=None):
  """Creates a queue to prefetch tensors from `tensors`.
  A queue runner for enqueuing tensors into the prefetch_queue is automatically
  added to the TF QueueRunners collection.
  Example:
  This is for example useful to pre-assemble input batches read with
  `tf.train.batch()` and enqueue the pre-assembled batches.  Ops that dequeue
  from the pre-assembled queue will not pay the cost of assembling the batch.

  images, labels = tf.train.batch([image, label], batch_size=32, num_threads=4)
  batch_queue = prefetch_queue([images, labels])
  images, labels = batch_queue.dequeue()
  logits = Net(images)
  loss = Loss(logits, labels)
  Args:
    tensors: A list or dictionary of `Tensors` to enqueue in the buffer.
    capacity: An integer. The maximum number of elements in the queue.
    num_threads: An integer.  Number of threads running the enqueue op.
    dynamic_pad: Boolean.  Whether to allow variable dimensions in input shapes.
    shared_name: (optional). If set, this queue will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.
  Returns:
    A queue from which you can dequeue tensors with the same type and shape
    as `tensors`.
  """
  # (function body not shown in this excerpt)
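The idea behind prefetch_queue is that whole batches are assembled ahead of time and parked in a small queue (default `capacity=8`), so each dequeue returns a ready-made batch instead of paying the assembly cost. The stdlib sketch below illustrates only that idea, with a single background thread (matching the default `num_threads=1`); `prefetch_batches` and `assemble` are hypothetical helpers, and this is not the library's implementation.

```python
import queue
import threading


def prefetch_batches(batches, capacity=8):
    """Hypothetical stdlib analogue of prefetch_queue with num_threads=1.

    A background thread takes already-assembled batches from the `batches`
    iterable and enqueues them; put() blocks once `capacity` batches are waiting.
    A final None marks the end of the stream.
    """
    q = queue.Queue(maxsize=capacity)

    def worker():
        for batch in batches:
            q.put(batch)
        q.put(None)  # end-of-stream sentinel

    threading.Thread(target=worker, daemon=True).start()
    return q


def assemble(examples, batch_size):
    """Hypothetical stand-in for tf.train.batch: slice examples into batches."""
    for start in range(0, len(examples), batch_size):
        yield examples[start:start + batch_size]


batch_queue = prefetch_batches(assemble(list(range(96)), batch_size=32))
received = []
while True:
    batch = batch_queue.get()  # cf. batch_queue.dequeue(): one ready-made batch
    if batch is None:
        break
    received.append(batch)
```

The consumer never touches individual examples: each `get()` returns a complete list of 32 items, which mirrors the docstring's point that ops dequeuing from the pre-assembled queue do not pay the batching cost.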

tf.train.batch

cifar10.distorted_inputs()
