CIFAR-10与ImageNet图像识别

2.1.2 下载CIFAR-10 数据

python cifar10_download.py
# 引入当前目录中的已经编写好的cifar10模块
import cifar10
import tensorflow as tf

# tf.app.flags.FLAGS是TensorFlow内部的一个全局变量存储器,同时可以用于命令行参数的处理
FLAGS = tf.app.flags.FLAGS

# 在cifar10模块中预先定义了f.app.flags.FLAGS.data_dir为CIFAR-10的数据路径,我们把这个路径改为cifar10_data
FLAGS.data_dir = 'cifar10_data/'

# 如果不存在数据文件,就会执行下载
cifar10.maybe_download_and_extract()

2.1.3 TensorFlow 的数据读取机制

实验脚本:

python test.py
import tensorflow as tf 
import os
if not os.path.exists('read'):
    os.makedirs('read/')

# 新建一个Session
with tf.Session() as sess:
    # 我们要读三幅图片A.jpg, B.jpg, C.jpg
    filename = ['A.jpg', 'B.jpg', 'C.jpg']
    # string_input_producer会产生一个文件名队列
    filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)
    # reader从文件名队列中读数据。对应的方法是reader.read
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    # tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化
    tf.local_variables_initializer().run()
    # 使用start_queue_runners之后,才会开始填充队列
    threads = tf.train.start_queue_runners(sess=sess)
    i = 0
    while True:
        i += 1
        # 获取图片数据并保存
        image_data = sess.run(value)
        with open('read/test_%d.jpg' % i, 'wb') as f:
            f.write(image_data)
# 程序最后会抛出一个OutOfRangeError,这是epoch跑完,队列关闭的标志

2.1.4 实验:将CIFAR-10 数据集保存为图片形式

python cifar10_extract.py
# 导入当前目录的cifar10_input,这个模块负责读入cifar10数据
import cifar10_input
# 导入TensorFlow和其他一些可能用到的模块。
import tensorflow as tf
import os
import scipy.misc


def inputs_origin(data_dir):
    # filenames一共5个,从data_batch_1.bin到data_batch_5.bin
    # 读入的都是训练图像
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in range(1, 6)]
    # 判断文件是否存在
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    # 将文件名的list包装成TensorFlow中queue的形式
    filename_queue = tf.train.string_input_producer(filenames)
    # cifar10_input.read_cifar10是事先写好的从queue中读取文件的函数
    # 返回的结果read_input的属性uint8image就是图像的Tensor
    read_input = cifar10_input.read_cifar10(filename_queue)
    # 将图片转换为实数形式
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    # 返回的reshaped_image是一张图片的tensor
    # 我们应当这样理解reshaped_image:每次使用sess.run(reshaped_image),就会取出一张图片
    return reshaped_image


if __name__ == '__main__':
    # 创建一个会话sess
    with tf.Session() as sess:
        # 调用inputs_origin。cifar10_data/cifar-10-batches-bin是我们下载的数据的文件夹位置
        reshaped_image = inputs_origin('cifar10_data/cifar-10-batches-bin')
        # 这一步start_queue_runner很重要。
        # 我们之前有filename_queue = tf.train.string_input_producer(filenames)
        # 这个queue必须通过start_queue_runners才能启动
        # 缺少start_queue_runners程序将不能执行
        threads = tf.train.start_queue_runners(sess=sess)
        # 变量初始化
        sess.run(tf.global_variables_initializer())
        # 创建文件夹cifar10_data/raw/
        if not os.path.exists('cifar10_data/raw/'):
            os.makedirs('cifar10_data/raw/')
        # 保存30张图片
        for i in range(30):
            # 每次sess.run(reshaped_image),都会取出一张图片
            image_array = sess.run(reshaped_image)
            # 将图片保存
            scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)

2.2.3 训练模型

python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import time

import tensorflow as tf

import cifar10

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', "Directory where to write event logs and checkpoint.")
tf.app.flags.DEFINE_integer('max_steps', 1000000, "Number of batches to run.")
tf.app.flags.DEFINE_boolean('log_device_placement', False, "Whether to log device placement.")
tf.app.flags.DEFINE_integer('log_frequency', 10, "How often to log results to the console.")


def train():
    """
    Train CIFAR-10 for a number of steps.
    :return: 
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()


if __name__ == '__main__':
    tf.app.run()

2.2.4 在TensorFlow 中查看训练进度

tensorboard --logdir cifar10_train/

2.2.5 测试模型效果

python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/

使用TensorBoard查看性能验证情况:

tensorboard --logdir cifar10_eval/ --port 6007

拓展阅读

  • 关于CIFAR-10 数据集, 读者可以访问它的官方网站https://www.cs.toronto.edu/~kriz/cifar.html 了解更多细节。此外, 网站 http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html#43494641522d3130 中收集了在CIFAR-10 数据集上表 现最好的若干模型,包括这些模型对应的论文。
  • ImageNet 数据集上的表现较好的几个著名的模型是深度学习的基石, 值得仔细研读。建议先阅读下面几篇论文:ImageNet Classification with Deep Convolutional Neural Networks(AlexNet 的提出)、Very Deep Convolutional Networks for Large-Scale Image Recognition (VGGNet)、Going Deeper with Convolutions(GoogLeNet)、Deep Residual Learning for Image Recognition(ResNet)
  • 在第2.1.3 节中,简要介绍了TensorFlow的一种数据读入机制。事实上,目前在TensorFlow 中读入数据大致有三种方法:(1)用占位符(即placeholder)读入,这种方法比较简单;(2)用队列的形式建立文件到Tensor的映射;(3)用Dataset API 读入数据,Dataset API 是TensorFlow 1.3 版本新引入的一种读取数据的机制,可以参考这 篇中文教程:https://zhuanlan.zhihu.com/p/30751039

猜你喜欢

转载自www.cnblogs.com/chenxiangzhen/p/10498703.html