一、TFR数据读取
在train_ssd_network.py中获取数据的操作如下,首先需要得到一个slim.dataset.Dataset对象:
# Select the dataset.
# - FLAGS.dataset_name:       registered dataset name (a key of datasets_map).
# - FLAGS.dataset_split_name: which split to read, e.g. 'train' or 'test'.
# - FLAGS.dataset_dir:        directory that holds the TFRecord files; for
#   VOC 2012 the files are expected to match 'voc_2012_%s_*.tfrecord', with
#   %s replaced by the split name.
dataset = dataset_factory.get_dataset(
    name=FLAGS.dataset_name,
    split_name=FLAGS.dataset_split_name,
    dataset_dir=FLAGS.dataset_dir)
获取过程会经过一系列臃肿的调用,下面按调用顺序(由上到下)列出中间被调用的各个函数:
def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
    """Look up the dataset registered as `name` and build the requested split.

    The work is delegated to the `get_split` function of the module stored
    under `name` in `datasets_map` (for example `pascalvoc_2012.get_split`).

    Returns:
        A `Dataset` class.

    Raises:
        ValueError: If the dataset `name` is unknown.
    """
    if name in datasets_map:
        dataset_module = datasets_map[name]
        return dataset_module.get_split(split_name, dataset_dir,
                                        file_pattern, reader)
    raise ValueError('Name of dataset unknown %s' % name)


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Build the Pascal VOC 2012 dataset description for `split_name`.

    A missing or empty `file_pattern` falls back to the module-level
    FILE_PATTERN. Per the original notes, that pattern is
    'voc_2012_%s_*.tfrecord', so the TFRecord files must follow that naming.
    The split sizes, item descriptions and class count of this dataset are
    then handed to `pascalvoc_common.get_split`.

    Returns:
        A `Dataset` namedtuple.

    Raises:
        ValueError: if `split_name` is not a valid train/test split.
    """
    pattern = file_pattern if file_pattern else FILE_PATTERN
    return pascalvoc_common.get_split(
        split_name,
        dataset_dir,
        pattern,
        reader,
        SPLITS_TO_SIZES,        # per the original notes: {'train': 17125}
        ITEMS_TO_DESCRIPTIONS,
        NUM_CLASSES)            # per the original notes: 20


# For reference, ITEMS_TO_DESCRIPTIONS as quoted in the original notes:
#   'image':        'A color image of varying height and width.'
#   'shape':        'Shape of the image'
#   'object/bbox':  'A list of bounding boxes, one per each object.'
#   'object/label': 'A list of labels, one per each object.'
最终调用的是下面这个函数,它返回slim.dataset.Dataset对象(其解析见『TensorFlow』从磁盘读取数据)。实际上,只要能传入满足slim.dataset.Dataset要求的参数即可,不必拘泥于上面这条调用链:
def get_split(split_name, dataset_dir, file_pattern, reader,
              split_to_sizes, items_to_descriptions, num_classes):
    """Build a `slim.dataset.Dataset` that describes one Pascal VOC split.

    Args:
        split_name: Split to load (e.g. 'train'); must be a key of
            `split_to_sizes`.
        dataset_dir: Directory that contains the TFRecord files.
        file_pattern: File-name pattern with a '%s' placeholder that is
            filled with `split_name` and joined onto `dataset_dir`.
        reader: TensorFlow reader class; `None` selects `tf.TFRecordReader`.
        split_to_sizes: Mapping from split name to number of samples.
        items_to_descriptions: Mapping from decoded item name to a
            human-readable description, stored on the returned dataset.
        num_classes: Number of object classes, stored on the returned dataset.

    Returns:
        A `slim.dataset.Dataset`.

    Raises:
        ValueError: if `split_name` is not a key of `split_to_sizes`.
    """
    if split_name not in split_to_sizes:
        raise ValueError('split name %s was not recognized.' % split_name)

    data_sources = os.path.join(dataset_dir, file_pattern % split_name)
    # None is allowed in the signature so that dataset_factory can use the
    # default reader.
    reader = tf.TFRecordReader if reader is None else reader

    # How each serialized tf.train.Example field is parsed. The image-level
    # fields have a fixed length; the per-object fields are variable-length.
    image_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
    }
    object_features = {
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }
    keys_to_features = dict(image_features)
    keys_to_features.update(object_features)

    # Items the decoder exposes, each built from the parsed features above.
    # Bounding-box coordinates are assembled in [ymin, xmin, ymax, xmax] order.
    tfexample = slim.tfexample_decoder
    items_to_handlers = {
        'image': tfexample.Image('image/encoded', 'image/format'),
        'shape': tfexample.Tensor('image/shape'),
        'object/bbox': tfexample.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': tfexample.Tensor('image/object/bbox/label'),
        'object/difficult': tfexample.Tensor('image/object/bbox/difficult'),
        'object/truncated': tfexample.Tensor('image/object/bbox/truncated'),
    }
    decoder = tfexample.TFExampleDecoder(keys_to_features, items_to_handlers)

    # Optional mapping used for human-readable class names. Presumably
    # {label id: class name} read from a labels file in dataset_dir (the
    # original notes mention labels.txt); verify against dataset_utils.
    labels_to_names = (dataset_utils.read_label_file(dataset_dir)
                       if dataset_utils.has_labels(dataset_dir) else None)

    return slim.dataset.Dataset(
        data_sources=data_sources,
        reader=reader,
        decoder=decoder,
        num_samples=split_to_sizes[split_name],
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
        labels_to_names=labels_to_names)


# Example `items_to_descriptions` value, as quoted in the original notes:
#   {'image': 'A color image of varying height and width.',
#    'shape': 'Shape of the image',
#    'object/bbox': 'A list of bounding boxes, one per each object.',
#    'object/label': 'A list of labels, one per each object.'}
# NOTE: the decoder above also yields 'object/difficult' and
# 'object/truncated', which that mapping does not describe.
下面从TFR中读取数据。注意:provider.get返回的是单个样本(一张图片及其标注)的Tensor,而不是一整个batch;组成batch需要在之后另行处理:
with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
    # DatasetDataProvider takes the slim.dataset.Dataset built earlier;
    # its queue sizes are expressed as multiples of the batch size.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=True,
        common_queue_min=FLAGS.batch_size * 10,
        common_queue_capacity=FLAGS.batch_size * 20,
        num_readers=FLAGS.num_readers)
    # Fetch what the SSD network needs: image, shape, labels and bboxes.
    # These are the tensors of a single decoded example, not a batch.
    image, shape, glabels, gbboxes = provider.get(
        ['image', 'shape', 'object/label', 'object/bbox'])
此时数据已经获取完毕,预处理之后即可加入运算。
注意,到目前为止我们仅对图片数据进行了解码,并没有扩充维度(即还没有batch维),其维度依然是3维。
二、数据处理
获取对应模型(如ssd_300_vgg)的预处理函数,并用它处理上一小节中获取的数据:
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    preprocessing_name, is_training=True)
# Pre-process the image together with its labels and bounding boxes.
image, glabels, gbboxes = image_preprocessing_fn(
    image,
    glabels,
    gbboxes,
    out_shape=ssd_shape,      # (300, 300) per the original note
    data_format=DATA_FORMAT)  # 'NCHW' per the original note
有时你会觉得这种层层调用非常繁琐……下面两步依旧是一条调用链:
def get_preprocessing(name, is_training=False):
    """Return the pre-processing callable for the network called `name`.

    Both SSD VGG variants ('ssd_300_vgg' and 'ssd_512_vgg') share
    `ssd_vgg_preprocessing`. The returned callable forwards to that
    module's `preprocess_image`, with `is_training` fixed at lookup time.

    Raises:
        ValueError: if `name` is not a known network.
    """
    preprocessing_fn_map = {
        'ssd_300_vgg': ssd_vgg_preprocessing,
        'ssd_512_vgg': ssd_vgg_preprocessing,
    }
    if name not in preprocessing_fn_map:
        raise ValueError('Preprocessing name [%s] was not recognized' % name)
    # The map is local and never mutated, so resolving the entry once here
    # is equivalent to resolving it on every call.
    preprocessing_module = preprocessing_fn_map[name]

    def preprocessing_fn(image, labels, bboxes, out_shape,
                         data_format='NHWC', **kwargs):
        return preprocessing_module.preprocess_image(
            image, labels, bboxes, out_shape,
            data_format=data_format, is_training=is_training, **kwargs)

    return preprocessing_fn


def preprocess_image(image, labels, bboxes, out_shape, data_format,
                     is_training=False, **kwargs):
    """Route to the training or the evaluation pre-processing pipeline.

    Args:
        image, labels, bboxes: Tensors for one example.
        out_shape: Target output shape, forwarded to either pipeline.
        data_format: Layout string, forwarded to either pipeline.
        is_training: Selects `preprocess_for_train` when true, otherwise
            `preprocess_for_eval`.
        **kwargs: Extra options forwarded ONLY to `preprocess_for_eval`;
            they are silently ignored when `is_training` is true.

    Returns:
        Whatever the selected pipeline returns.
    """
    if not is_training:
        return preprocess_for_eval(image, labels, bboxes,
                                   out_shape=out_shape,
                                   data_format=data_format,
                                   **kwargs)
    return preprocess_for_train(image, labels, bboxes,
                                out_shape=out_shape,
                                data_format=data_format)
之后就是具体的数据预处理函数: