『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理

一、TFR数据读取

在train_ssd_network.py中,获取数据的操作如下:首先需要得到一个slim.dataset.Dataset对象。

# Select the dataset.
# Example argument values: 'imagenet', 'train', and the directory where the
# TFRecord files are stored.
# TFRecord file naming pattern: 'voc_2012_%s_*.tfrecord', where %s is
# 'train' or 'test'.
dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

获取过程会经过一系列臃肿的调用,我把中间被调用的函数(们)写在了下面,由上到下依次调用:

def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
    """Look up a dataset module by name and delegate to its `get_split`.

    Args:
        name: Key into `datasets_map` selecting the dataset module.
        split_name: Forwarded unchanged to the module's `get_split`.
        dataset_dir: Forwarded unchanged to the module's `get_split`.
        file_pattern: Forwarded unchanged; defaults to None.
        reader: Forwarded unchanged; defaults to None.

    Returns:
        Whatever the selected module's `get_split` returns (a `Dataset`).

    Raises:
        ValueError: If `name` is not a key of `datasets_map`.
    """
    if name in datasets_map:
        # e.g. for Pascal VOC 2012 this ends up in pascalvoc_2012.get_split.
        dataset_module = datasets_map[name]
        return dataset_module.get_split(
            split_name, dataset_dir, file_pattern, reader)
    raise ValueError('Name of dataset unknown %s' % name)


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Return the requested split using this module's dataset constants.

    Thin wrapper around `pascalvoc_common.get_split`: it substitutes the
    module-level `FILE_PATTERN` when no pattern is given and appends
    `SPLITS_TO_SIZES`, `ITEMS_TO_DESCRIPTIONS` and `NUM_CLASSES`.

    Args:
        split_name: Forwarded unchanged to `pascalvoc_common.get_split`.
        dataset_dir: Forwarded unchanged.
        file_pattern: Optional file pattern; any falsy value (None, '')
            is replaced by `FILE_PATTERN`.
        reader: Forwarded unchanged.

    Returns:
        A `Dataset` namedtuple.

    Raises:
        ValueError: if `split_name` is not a valid train/test split
            (raised by the delegate, not by this wrapper).
    """
    # The default pattern is expected to follow 'voc_2012_%s_*.tfrecord'.
    pattern = file_pattern or FILE_PATTERN

    # For reference, ITEMS_TO_DESCRIPTIONS is expected to look like:
    #   {
    #       'image': 'A color image of varying height and width.',
    #       'shape': 'Shape of the image',
    #       'object/bbox': 'A list of bounding boxes, one per each object.',
    #       'object/label': 'A list of labels, one per each object.',
    #   }
    # (This used to be a bare string literal placed after the `return`,
    # i.e. an unreachable statement; it is now an ordinary comment.)
    return pascalvoc_common.get_split(
        split_name,
        dataset_dir,
        pattern,
        reader,
        SPLITS_TO_SIZES,        # e.g. {'train': 17125,}
        ITEMS_TO_DESCRIPTIONS,
        NUM_CLASSES,            # e.g. 20
    )

最终调用的是下面这个函数,由它构造并返回slim.dataset.Dataset(解析见『TensorFlow』从磁盘读取数据);实际上只要能够传入满足slim.dataset.Dataset要求的参数即可:

def get_split(split_name, dataset_dir, file_pattern, reader,
              split_to_sizes, items_to_descriptions, num_classes):
    """Assemble the `slim.dataset.Dataset` describing one Pascal VOC split.

    Args:
      split_name: A train/test split name; must be a key of `split_to_sizes`.
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        It is assumed that the pattern contains a '%s' string so that the split
        name can be inserted.
      reader: The TensorFlow reader type; `tf.TFRecordReader` is used when
        this is None.
      split_to_sizes: Mapping from split name to its number of samples.
      items_to_descriptions: Passed through to the `Dataset` unchanged.
      num_classes: Passed through to the `Dataset` unchanged.

    Returns:
      A `Dataset` namedtuple.

    Raises:
        ValueError: if `split_name` is not a valid train/test split.
    """
    # Validate first, e.g. 'train'.
    if split_name not in split_to_sizes:
        raise ValueError('split name %s was not recognized.' % split_name)

    # Insert the split name into the pattern, then anchor it in dataset_dir.
    data_sources = os.path.join(dataset_dir, file_pattern % split_name)

    # None is allowed in the signature so that dataset_factory can rely on
    # the default reader.
    record_reader = tf.TFRecordReader if reader is None else reader

    # How the raw features stored in each Pascal VOC TFRecord are parsed.
    feature_specs = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }

    # How parsed features are turned into the items a provider can request.
    handlers = slim.tfexample_decoder
    item_handlers = {
        'image': handlers.Image('image/encoded', 'image/format'),
        'shape': handlers.Tensor('image/shape'),
        # Box coordinates are requested in ymin, xmin, ymax, xmax order.
        'object/bbox': handlers.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': handlers.Tensor('image/object/bbox/label'),
        'object/difficult': handlers.Tensor('image/object/bbox/difficult'),
        'object/truncated': handlers.Tensor('image/object/bbox/truncated'),
    }

    # The decoder ties the two tables together.
    example_decoder = handlers.TFExampleDecoder(feature_specs, item_handlers)

    # Label names are optional: they are only loaded when the dataset
    # directory provides them (presumably a 'labels.txt' file - confirm
    # against dataset_utils.has_labels).
    if dataset_utils.has_labels(dataset_dir):
        label_names = dataset_utils.read_label_file(dataset_dir)
    else:
        label_names = None

    sample_count = split_to_sizes[split_name]
    return slim.dataset.Dataset(
        data_sources=data_sources,                    # TFRecord file pattern
        reader=record_reader,                         # reader class
        decoder=example_decoder,                      # example decoder
        num_samples=sample_count,                     # size of this split
        items_to_descriptions=items_to_descriptions,  # descriptions of decoder items
        num_classes=num_classes,                      # number of classes
        labels_to_names=label_names,                  # result of read_label_file, or None
    )

''' items_to_descriptions:
    {'image': 'A color image of varying height and width.',
     'shape': 'Shape of the image',
     'object/bbox': 'A list of bounding boxes, one per each object.',
     'object/label': 'A list of labels, one per each object.',}
'''

下面通过DatasetDataProvider从TFR中读取数据(此时得到的是单个样本的Tensor,尚未组成batch):

            # Build the data provider inside a name scope derived from the
            # dataset name.
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,  # DatasetDataProvider takes the slim.dataset.Dataset as its argument
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, shape, labels, bboxes.
            # The tensors are requested from the provider by item name.
            # NOTE(review): this appears to yield tensors for a single example
            # rather than a full batch - confirm against the batching code.
            [image, shape, glabels, gbboxes] = provider.get(['image', 'shape',
                                                             'object/label',
                                                             'object/bbox'])

 此时数据已经获取完毕,预处理之后即可加入运算。

注意,直到现在为止,我们仅对图片数据进行了解码,并没有扩充维度,也就是说其维度依然是3维

二、数据处理

获取对应数据集的预处理函数,并使用其处理上面小节中获取的batch数据,

# Fetch the preprocessing function registered under preprocessing_name,
# configured for training.
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

# Pre-processing image, labels and bboxes.
# The three inputs are rebound to their preprocessed versions.
image, glabels, gbboxes = \
    image_preprocessing_fn(image, glabels, gbboxes,
                           out_shape=ssd_shape,  # e.g. (300,300)
                           data_format=DATA_FORMAT)  # e.g. 'NCHW'

 有的时候你会觉得这种层层调用非常的sb……下面两步依旧是个调用链,

def get_preprocessing(name, is_training=False):
    """Return a preprocessing function bound to the named network.

    Args:
        name: Preprocessing name; 'ssd_300_vgg' and 'ssd_512_vgg' are
            recognized and both map to `ssd_vgg_preprocessing`.
        is_training: Forwarded as `is_training` on every call of the
            returned function.

    Returns:
        A function `preprocessing_fn(image, labels, bboxes, out_shape,
        data_format='NHWC', **kwargs)` that calls `preprocess_image` of the
        selected module and returns its result.

    Raises:
        ValueError: If `name` is not one of the recognized names.
    """
    modules_by_name = {
        'ssd_300_vgg': ssd_vgg_preprocessing,
        'ssd_512_vgg': ssd_vgg_preprocessing,
    }

    if name not in modules_by_name:
        raise ValueError('Preprocessing name [%s] was not recognized' % name)

    # Resolve the module once; the closure below only needs this reference.
    selected_module = modules_by_name[name]

    def preprocessing_fn(image, labels, bboxes,
                         out_shape, data_format='NHWC', **kwargs):
        return selected_module.preprocess_image(
            image, labels, bboxes, out_shape,
            data_format=data_format,
            is_training=is_training,
            **kwargs)

    return preprocessing_fn


def preprocess_image(image, labels, bboxes, out_shape, data_format,
                     is_training=False, **kwargs):
    """Dispatch to the training or evaluation preprocessing routine.

    Args:
        image: Forwarded positionally to the selected routine.
        labels: Forwarded positionally to the selected routine.
        bboxes: Forwarded positionally to the selected routine.
        out_shape: Forwarded as the `out_shape` keyword.
        data_format: Forwarded as the `data_format` keyword.
        is_training: Selects `preprocess_for_train` when true, otherwise
            `preprocess_for_eval`.
        **kwargs: Extra keyword arguments; passed on only to
            `preprocess_for_eval`.

    Returns:
        The result of the selected preprocessing routine.
    """
    if not is_training:
        return preprocess_for_eval(
            image, labels, bboxes,
            out_shape=out_shape,
            data_format=data_format,
            **kwargs)

    # Extra keyword arguments are intentionally not forwarded when training.
    return preprocess_for_train(
        image, labels, bboxes,
        out_shape=out_shape,
        data_format=data_format)

之后就是具体的数据预处理函数(preprocess_for_train与preprocess_for_eval)。

猜你喜欢

转载自www.cnblogs.com/hellcat/p/9341921.html