TensorFlow TFRecord encapsulates variable-length sequence data (text)

In a laboratory setting, the data is usually loaded into memory all at once and then split with a hand-written mini-batch function. This approach does not work for massive datasets: 1) memory is too small to hold all of the data at once; 2) data splitting and model training are not asynchronous, so training is easily blocked by time-consuming mini-batch preparation; 3) it cannot be deployed in a distributed environment.

The following code snippet uses the TFRecord file format and supports variable-length sequences with dynamic padding, which covers the basic needs of sequence tasks such as NLP.

import tensorflow as tf


def generate_tfrecords(tfrecord_filename):
    sequences = [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],
                 [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]
    labels = [1, 2, 3, 4, 5, 1, 2, 3, 4]

    with tf.python_io.TFRecordWriter(tfrecord_filename) as f:
        for feature, label in zip(sequences, labels):
            frame_feature = list(map(lambda id: tf.train.Feature(int64_list=tf.train.Int64List(value=[id])), feature))

            example = tf.train.SequenceExample(
                context=tf.train.Features(feature={
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}),
                feature_lists=tf.train.FeatureLists(feature_list={
                    'sequence': tf.train.FeatureList(feature=frame_feature)
                })
            )
            f.write(example.SerializeToString())
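
As a quick sanity check (not part of the pipeline itself), the serialized SequenceExamples can be read back and decoded eagerly. A minimal sketch using tf.python_io.tf_record_iterator:

def inspect_tfrecords(tfrecord_filename):
    # Iterate over the raw serialized records and decode each one back
    # into a SequenceExample proto for inspection.
    for serialized in tf.python_io.tf_record_iterator(tfrecord_filename):
        example = tf.train.SequenceExample.FromString(serialized)
        label = example.context.feature['label'].int64_list.value[0]
        sequence = [f.int64_list.value[0]
                    for f in example.feature_lists.feature_list['sequence'].feature]
        print(label, sequence)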


def parser(tfrecord_filename, num_epochs=None):
    _, serialized_example = tf.TFRecordReader().read(
        tf.train.string_input_producer(tfrecord_filename, num_epochs=num_epochs))

    context_features = {
        "label": tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "sequence": tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }

    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    # `sequences` is a variable-length 1-D int64 tensor; `labels` is a scalar.
    labels = context_parsed['label']
    sequences = sequence_parsed['sequence']
    return [sequences, labels]


def _shuffle_inputs(input_tensors, capacity, min_after_dequeue, num_threads):
    """Shuffles tensors in `input_tensors`, maintaining grouping."""
    shuffle_queue = tf.RandomShuffleQueue(
        capacity, min_after_dequeue, dtypes=[t.dtype for t in input_tensors])
    enqueue_op = shuffle_queue.enqueue(input_tensors)
    runner = tf.train.QueueRunner(shuffle_queue, [enqueue_op] * num_threads)
    tf.train.add_queue_runner(runner)

    output_tensors = shuffle_queue.dequeue()

    for i in range(len(input_tensors)):
        output_tensors[i].set_shape(input_tensors[i].shape)

    return output_tensors


def count_records(file_list, stop_at=None):
    """Counts number of records in files from `file_list` up to `stop_at`.
    Args:
      file_list: List of TFRecord files to count records in.
      stop_at: Optional number of records to stop counting at.
    Returns:
      Integer number of records in files from `file_list` up to `stop_at`.
    """
    num_records = 0
    for tfrecord_file in file_list:
        tf.logging.info('Counting records in %s.', tfrecord_file)
        for _ in tf.python_io.tf_record_iterator(tfrecord_file):
            num_records += 1
            if stop_at and num_records >= stop_at:
                tf.logging.info('Number of records is at least %d.', num_records)
                return num_records
    tf.logging.info('Total records: %d', num_records)
    return num_records


def flatten_maybe_padded_sequences(maybe_padded_sequences, lengths=None):
    """Flattens the batch of sequences, removing padding (if applicable).

    Args:
      maybe_padded_sequences: A tensor of possibly padded sequences to flatten,
          sized `[N, M, ...]` where M = max(lengths).
      lengths: Optional length of each sequence, sized `[N]`. If None, assumes no
          padding.

    Returns:
       flatten_maybe_padded_sequences: The flattened sequence tensor, sized
           `[sum(lengths), ...]`.
    """

    def flatten_unpadded_sequences():
        # The sequences are equal length, so we should just flatten over the first
        # two dimensions.
        return tf.reshape(maybe_padded_sequences,
                          [-1] + maybe_padded_sequences.shape.as_list()[2:])

    if lengths is None:
        return flatten_unpadded_sequences()

    def flatten_padded_sequences():
        indices = tf.where(tf.sequence_mask(lengths))
        return tf.gather_nd(maybe_padded_sequences, indices)

    return tf.cond(
        tf.equal(tf.reduce_min(lengths), tf.shape(maybe_padded_sequences)[1]),
        flatten_unpadded_sequences,
        flatten_padded_sequences)
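

As a standalone illustration of what this helper does (the values below are made up for the example), flattening a zero-padded batch with lengths [2, 1] keeps only the unpadded entries:

def _demo_flatten():
    # Zero-padded batch of two sequences with true lengths [2, 1].
    padded = tf.constant([[1, 2, 0],
                          [3, 0, 0]], dtype=tf.int64)
    lengths = tf.constant([2, 1])
    flat = flatten_maybe_padded_sequences(padded, lengths)
    with tf.Session() as sess:
        print(sess.run(flat))  # -> [1 2 3]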


def batched_data(tfrecord_filename, parser, batch_size, num_epochs, num_enqueuing_threads=4, shuffle=True,
                 QUEUE_CAPACITY=100):
    """Reads records with `parser`, optionally shuffles them, and returns
    dynamically padded batches of size `batch_size`."""
    SHUFFLE_MIN_AFTER_DEQUEUE = QUEUE_CAPACITY // 5

    if not isinstance(tfrecord_filename, (list, tuple)):
        tfrecord_filename = [tfrecord_filename]
    input_tensors = parser(tfrecord_filename, num_epochs)
    if shuffle:
        assert num_enqueuing_threads >= 2, '`num_enqueuing_threads` must be at least 2 when shuffling.'
        shuffle_threads = num_enqueuing_threads // 2
        # Since there may be fewer records than SHUFFLE_MIN_AFTER_DEQUEUE, take the
        #  minimum of that number and the number of records.
        min_after_dequeue = count_records(
            tfrecord_filename, stop_at=SHUFFLE_MIN_AFTER_DEQUEUE)

        input_tensors = _shuffle_inputs(
            input_tensors, capacity=QUEUE_CAPACITY,
            min_after_dequeue=min_after_dequeue,
            num_threads=shuffle_threads)
        num_enqueuing_threads -= shuffle_threads
    return tf.train.batch(
        input_tensors,
        batch_size=batch_size,
        capacity=QUEUE_CAPACITY,
        num_threads=num_enqueuing_threads,
        dynamic_pad=True,
        allow_smaller_final_batch=False)
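

The pipeline above uses the classic TF 1.x queue runners. As a rough alternative sketch, assuming TensorFlow 1.4 or later, the same variable-length batching can be expressed with the tf.data API, where padded_batch takes the role of dynamic_pad:

def batched_data_with_tf_data(tfrecord_filename, batch_size, num_epochs, shuffle=True):
    # Sketch of an equivalent input pipeline built on tf.data instead of queues.
    def _parse(serialized):
        context, sequence = tf.parse_single_sequence_example(
            serialized,
            context_features={'label': tf.FixedLenFeature([], dtype=tf.int64)},
            sequence_features={'sequence': tf.FixedLenSequenceFeature([], dtype=tf.int64)})
        return sequence['sequence'], context['label']

    dataset = tf.data.TFRecordDataset(tfrecord_filename)
    dataset = dataset.map(_parse)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=100)
    # Pad each batch to the length of its longest sequence; labels stay scalars.
    dataset = dataset.repeat(num_epochs).padded_batch(
        batch_size, padded_shapes=([None], []))
    return dataset.make_one_shot_iterator().get_next()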


if __name__ == "__main__":
    def model(features, labels):
        return labels


    tfrecord_filename = 'test.tfrecord'
    generate_tfrecords(tfrecord_filename)
    out = model(*batched_data(tfrecord_filename, parser, 2, 1))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                print(sess.run(out))

        except tf.errors.OutOfRangeError:
            print("done training")
        finally:
            coord.request_stop()
        coord.join(threads)
