TensorFlow Distributed Training (how data is read and distributed during training)

Purpose of this article

The official documentation on Estimator distributed training has become inconsistent with the actual interface because of version updates. Specifically, for Estimator distributed training with a dataset as input: in version 1.12, the training data is just the dataset's data, that is, all devices together run through the data once.

In version 2.0, the training data is the dataset's data multiplied by the number of distributed devices. In other words, each device runs through the complete dataset on its own.
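
To make the difference concrete, here is a small back-of-the-envelope sketch (the numbers are made up for illustration) of how many examples each device processes per step under the two interpretations:

# Illustration only: per-device examples per step under each interpretation.
global_batch_from_input_fn = 64   # batch size produced by input_fn (made-up number)
num_devices = 4

# Version 1.12: the batch from input_fn is divided among the devices,
# so together they consume the dataset exactly once.
per_device_batch_1_12 = global_batch_from_input_fn // num_devices   # 16
examples_per_step_1_12 = per_device_batch_1_12 * num_devices        # 64

# Version 2.0: the batch from input_fn is treated as a per-device batch,
# so the data processed per step is multiplied by the number of devices.
per_device_batch_2_0 = global_batch_from_input_fn                   # 64
examples_per_step_2_0 = per_device_batch_2_0 * num_devices          # 256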

Data reading in version 1.12

1. Creating the graph in the main thread

In the following code, the client calls the input function to obtain an iterator. This is part of the code that Estimator's distributed train path invokes:

with ops.Graph().as_default() as g:
  # We want to create the iterations variable outside the distribution scope
  # as that is just stored on the host and mainly used to drive the loop
  # and doesn't need to be a Mirrored/Device variable.
  if is_tpu_strategy:
    steps_per_run_variable = training.get_or_create_steps_per_run_variable()
  with self._train_distribution.scope():
    random_seed.set_random_seed(self._config.tf_random_seed)
    iterator, input_hooks = self._get_iterator_from_input_fn(
        input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)

• _get_iterator_from_input_fn: this function generates the iterator used to read the data for subsequent training.

  def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
    if distribution is not None:
      result = distribution.distribute_dataset(
          lambda: self._call_input_fn(input_fn, mode))
    else:
      result = self._call_input_fn(input_fn, mode)

    iterator = result.make_initializable_iterator()
    input_hooks = [estimator_util._DatasetInitializerHook(iterator)]  # pylint: disable=protected-access
    return iterator, input_hooks
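
For context, here is a minimal sketch of the user-side code that reaches this path (assuming the TF 1.12-era contrib API; the model and input functions are made-up placeholders): passing a distribution strategy through RunConfig's train_distribute makes Estimator.train go through _get_iterator_from_input_fn with that strategy.

import tensorflow as tf

def input_fn():
    # Toy dataset; in 1.12 this batch is what PerDeviceDataset later splits up.
    dataset = tf.data.Dataset.range(100).map(lambda x: ({"x": x}, x % 2))
    return dataset.batch(32)

def model_fn(features, labels, mode):
    # Placeholder model body; only the overall structure matters here.
    x = tf.cast(tf.reshape(features["x"], [-1, 1]), tf.float32)
    logits = tf.layers.dense(x, 2)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
estimator.train(input_fn, max_steps=10)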

Inside _get_iterator_from_input_fn, distribute_dataset is invoked to generate the dataset. Following that call, we can see that it creates a PerDeviceDataset:

class PerDeviceDataset(object):
  """Like `tf.data.Dataset` split devices, producing `PerDevice` data."""

  def __init__(self, dataset, devices, prefetch_on_device=None):
    self._devices = devices

    # Default to using prefetching in graph mode, unless specified.
    # TODO(priyag): Enable prefetching in eager mode.
    self._prefetch_on_device = prefetch_on_device
    if self._prefetch_on_device is None:
      self._prefetch_on_device = not context.executing_eagerly()
    assert not (self._prefetch_on_device and context.executing_eagerly()), (
        "Prefetching is only supported in graph mode currently")

    if self._prefetch_on_device:
      self._dataset = dataset.apply(
          prefetching_ops_v2.prefetch_to_devices(self._devices))
    else:
      # TODO(priyag): If dropping remainder is not appropriate, find another
      # approach to distributing the dataset when not possible to divide evenly.
      # Possibly not an issue when we start using PartitionedDataset.
      self._dataset = dataset.batch(len(devices), drop_remainder=True)

The last line of code shows that the original dataset is wrapped in another batch layer, with the batch size equal to the number of devices, so that the data is split across the devices.
The iterator created later is wrapped as a PerDeviceDataIterator, which forms a dictionary mapping each device to its data, slicing the batch by device index.
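
A toy illustration of this mechanism (not the library code itself): if input_fn already yields per-device batches of size 2 and there are two devices, the extra batch(len(devices)) layer groups consecutive batches, and device i then takes slice i of each group.

import tensorflow as tf

devices = ["/device:GPU:0", "/device:GPU:1"]

per_device_batches = tf.data.Dataset.range(8).batch(2)    # [0,1], [2,3], [4,5], [6,7]
grouped = per_device_batches.batch(len(devices), drop_remainder=True)

iterator = grouped.make_one_shot_iterator()
group = iterator.get_next()                                # shape [num_devices, 2]
per_device_split = {d: group[i] for i, d in enumerate(devices)}

with tf.Session() as sess:
    print(sess.run(per_device_split))
    # {'/device:GPU:0': array([0, 1]), '/device:GPU:1': array([2, 3])}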

Distributed training

Training in version 1.12 is relatively simple. For MirroredStrategy, a thread is created for each device. One drawback, noted in a TODO in the source, is that the threads are recreated on every run; this is expected to be optimized away later.

The following code shows how the client takes data from the iterator and passes it to each computing device via self._train_distribution.call_for_each_tower:

features, labels = estimator_util.parse_iterator_result(
    iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
    self._call_model_fn,
    features,
    labels,  # although this will be None it seems
    model_fn_lib.ModeKeys.TRAIN,
    self.config)
loss = self._train_distribution.unwrap(
    self._train_distribution.reduce(
        distribute_lib.get_loss_reduction(),
        grouped_estimator_spec.loss,
        destinations='/device:CPU:0'))[0]
distributed_train_op = grouped_estimator_spec.train_op
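
Conceptually, the reduce/unwrap pair above collapses the per-tower losses into a single value on the destination device. A simplified sketch (pure Python, assuming the loss reduction resolves to a mean; the values are made up) of what that amounts to:

# Each tower reports its own loss; the strategy aggregates them and places the
# result on destinations='/device:CPU:0'. unwrap(...)[0] then hands back the
# single aggregated value.
per_tower_loss = {
    "/device:GPU:0": 0.42,
    "/device:GPU:1": 0.38,
}
reduced_loss = sum(per_tower_loss.values()) / len(per_tower_loss)   # 0.40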

call_for_each_tower is the interface that runs the training function on each device:

def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device.
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

Here, select_device takes the device key and returns the corresponding per-device value, so each thread receives its own slice of the arguments. With this, distributed training is complete.
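
A simplified sketch (not the library implementation) of what values.select_device does for each thread's arguments: per-device containers are narrowed to the component belonging to that device, while ordinary values pass through unchanged.

def select_device_sketch(device, args):
    picked = []
    for value in args:
        if isinstance(value, dict) and device in value:   # stand-in for a PerDevice container
            picked.append(value[device])                   # take this device's component
        else:
            picked.append(value)                           # ordinary values pass through
    return tuple(picked)

features = {"/device:GPU:0": "batch_0", "/device:GPU:1": "batch_1"}
select_device_sketch("/device:GPU:1", (features, "some_shared_flag"))
# -> ('batch_1', 'some_shared_flag')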

Origin www.cnblogs.com/axder/p/11459103.html