DL学习笔记-cifar10_input源码解读

在cifar10_input里面用到了多线程处理和图像的预处理,刚好和第七章的内容符合,所以我就把这个函数从前到后理了一遍。
1、先获取数据文件的列表,源码利用的是for循环构建一个文件列表,也可以用TF提供的函数,

2、判断文件是否存在,如果不存在,直接抛出错误
3、创建一个文件队列,然后从队列中读取文件内容

4、读取文件就用到了读取文件的函数,在函数里面我们把数据处理好,直接输出结果
5、对得到的图像进行翻转、色彩调整等操作,这一步有改进的空间,然后进行归一化

6、对像素矩阵调整shape,定义好参数,一个batch一个batch的输出
7、用tf.train.batch函数,可以多线程的输出一个batch的图像数据

import tensorflow as tf
import os


#定义超参
IMAGE_SIZE = 24
NUM_CLASS = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

#从文件列表中的文件中读取样例
def read_image(filename_queue):
    #读取并解析图片
    #创建一个伪类 并初始化
    class CIFAR10Record(object):
        pass
    result = CIFAR10Record()

    label_bytes = 1
    result.height = 32
    result.width = 32
    result.depth = 3
    image_bytes = result.width * result.height * result.depth
    #每一个记录都是图像+标签
    record_bytes = label_bytes + image_bytes
    #定义一个reader,按照长度大小读取
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    result.key, value = reader.read(filename_queue)

    #解码 ,
    record_bytes = tf.decode_raw(value, tf.uint8)
    # record_bytes 的第一个字节表示标签,
    result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

    #把标签去掉,剩下的就是image_bytes 把[weight * height * depth] 变成 [depth, height, weight]
    depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes+image_bytes]),
                             [result.depth, result.height, result.width])

    #把[depth, height, width] 转成 [height, width, depth]
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])

    return result

#构建一个batch的数据, 多线程处理
def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size, shuffle):
    '''

    :param image: 3D tensor of [height, width, depth] of type.float32
    :param label:  1D tensor of ty.int32
    :param min_queue_examples:  队列中最少保留样本数,防止shuffle操作无效
    :param batch_size: 一个batch有多少样本
    :param shuffle: 是否打乱
    :return: image和label [batch_size , height,width, 3] and [batch_size ]
    '''
    num_preprocess_threads = 16
    if shuffle:
        images, label_batch = tf.train.shuffle_batch([image, label],
                                                     batch_size=batch_size,
                                                     num_threads=num_preprocess_threads,
                                                     capacity=min_queue_examples + 3 * batch_size,
                                                     min_after_dequeue= min_queue_examples)
    else:
        image, label_batch = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            num_threads=num_preprocess_threads,
                                            capacity=min_queue_examples + 3 * batch_size)

    #将image添加到tensorboard
    tf.summary.image('images', images)

    #返回处理后的结果 images[batch_size, height,width,3]
    return images, tf.reshape(label_batch, [batch_size])

#数据预处理加读取
def distorted_inputs(data_dir, batch_size):
    '''

    :param data_dir:  数据的路径
    :param batch_size:  每个batch里图片数量
    :return:  image: 4D tensor of [batch_size, image_size, image_size , 3]
                label: 1D tensor of [batch_size]
    '''
    #读取cifar10的数据文件
    filenames =[os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
    # files = tf.train.match_filenames_once(os.path.join(data_dir, 'data_batch_*.bin'))
    # filename_queue = tf.train.string_input_producer(files)

    #判断文件是否都存在,如果有不存在的 直接抛出错误
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: '+ f)

    #创建一个文件队列
    filename_queue = tf.train.string_input_producer(filenames)

    with tf.name_scope('data_augmentation'):
        #从文件队列的文件中读取样本, 结果是[height, width, depth]
        read_input = read_image(filename_queue)
        #转成float32
        reshaped_image = tf.cast(read_input.uint8image, tf.float32)

        height = IMAGE_SIZE
        width = IMAGE_SIZE

        distored_image = tf.random_crop(reshaped_image, [height, width, 3])
        #随机翻转
        distored_image = tf.image.random_flip_left_right(distored_image)
        #随机调整亮度
        distored_image = tf.image.random_brightness(distored_image,max_delta= 63)
        #随机调整对比度
        distored_image = tf.image.random_contrast(distored_image, lower=0.2, upper= 1.8)

        #像素归一化
        float_image = tf.image.per_image_standardization(distored_image)

        #调整一下shape
        float_image.set_shape([height, width, 3])
        read_input.label.set_shape([1])

        #定义要保留的比例
        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)

        print('Filling queue with %d CIFAR images before starting to train . This will take a few minutes' % min_fraction_of_examples_in_queue)

        return _generate_image_and_label_batch(float_image, read_input.label, min_queue_examples, batch_size, shuffle = True)


#单纯的读取数据,不打乱也不预处理
def inputs(eval_data, data_dir, batch_size):
    #读取训练数据
    if not eval_data:
        filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
    else:
        filenames = [os.path.join(data_dir, 'test_batch.bin')]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(filenames)

        read_input = read_image(filename_queue)
        reshaped_image = tf.cast(read_input.uint8image, tf.float32)

        height = IMAGE_SIZE
        width = IMAGE_SIZE

        #将图像统一大小
        resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, height, width)
        #图像归一化
        float_iamge = tf.image.per_image_standardization(resized_image)

        float_iamge.set_shape([height, width, 3])
        read_input.label.set_shape([1])

        min_fraction_of_examples_in_queue = 0.4
        min_queue_example = int(num_examples_per_epoch * min_fraction_of_examples_in_queue)

        return _generate_image_and_label_batch(float_iamge, read_input.label, min_queue_example, batch_size, shuffle=False)

猜你喜欢

转载自blog.csdn.net/qq_36387683/article/details/80665794
今日推荐