2.1.2 下载CIFAR-10 数据
python cifar10_download.py
# 引入当前目录中的已经编写好的cifar10模块
import cifar10
import tensorflow as tf
# tf.app.flags.FLAGS是TensorFlow内部的一个全局变量存储器,同时可以用于命令行参数的处理
FLAGS = tf.app.flags.FLAGS
# 在cifar10模块中预先定义了f.app.flags.FLAGS.data_dir为CIFAR-10的数据路径,我们把这个路径改为cifar10_data
FLAGS.data_dir = 'cifar10_data/'
# 如果不存在数据文件,就会执行下载
cifar10.maybe_download_and_extract()
2.1.3 TensorFlow 的数据读取机制
实验脚本:
python test.py
import tensorflow as tf
import os
if not os.path.exists('read'):
os.makedirs('read/')
# 新建一个Session
with tf.Session() as sess:
# 我们要读三幅图片A.jpg, B.jpg, C.jpg
filename = ['A.jpg', 'B.jpg', 'C.jpg']
# string_input_producer会产生一个文件名队列
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)
# reader从文件名队列中读数据。对应的方法是reader.read
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
# tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化
tf.local_variables_initializer().run()
# 使用start_queue_runners之后,才会开始填充队列
threads = tf.train.start_queue_runners(sess=sess)
i = 0
while True:
i += 1
# 获取图片数据并保存
image_data = sess.run(value)
with open('read/test_%d.jpg' % i, 'wb') as f:
f.write(image_data)
# 程序最后会抛出一个OutOfRangeError,这是epoch跑完,队列关闭的标志
2.1.4 实验:将CIFAR-10 数据集保存为图片形式
python cifar10_extract.py
# 导入当前目录的cifar10_input,这个模块负责读入cifar10数据
import cifar10_input
# 导入TensorFlow和其他一些可能用到的模块。
import tensorflow as tf
import os
import scipy.misc
def inputs_origin(data_dir):
# filenames一共5个,从data_batch_1.bin到data_batch_5.bin
# 读入的都是训练图像
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in range(1, 6)]
# 判断文件是否存在
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# 将文件名的list包装成TensorFlow中queue的形式
filename_queue = tf.train.string_input_producer(filenames)
# cifar10_input.read_cifar10是事先写好的从queue中读取文件的函数
# 返回的结果read_input的属性uint8image就是图像的Tensor
read_input = cifar10_input.read_cifar10(filename_queue)
# 将图片转换为实数形式
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
# 返回的reshaped_image是一张图片的tensor
# 我们应当这样理解reshaped_image:每次使用sess.run(reshaped_image),就会取出一张图片
return reshaped_image
if __name__ == '__main__':
# 创建一个会话sess
with tf.Session() as sess:
# 调用inputs_origin。cifar10_data/cifar-10-batches-bin是我们下载的数据的文件夹位置
reshaped_image = inputs_origin('cifar10_data/cifar-10-batches-bin')
# 这一步start_queue_runner很重要。
# 我们之前有filename_queue = tf.train.string_input_producer(filenames)
# 这个queue必须通过start_queue_runners才能启动
# 缺少start_queue_runners程序将不能执行
threads = tf.train.start_queue_runners(sess=sess)
# 变量初始化
sess.run(tf.global_variables_initializer())
# 创建文件夹cifar10_data/raw/
if not os.path.exists('cifar10_data/raw/'):
os.makedirs('cifar10_data/raw/')
# 保存30张图片
for i in range(30):
# 每次sess.run(reshaped_image),都会取出一张图片
image_array = sess.run(reshaped_image)
# 将图片保存
scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)
2.2.3 训练模型
python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import time
import tensorflow as tf
import cifar10
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', "Directory where to write event logs and checkpoint.")
tf.app.flags.DEFINE_integer('max_steps', 1000000, "Number of batches to run.")
tf.app.flags.DEFINE_boolean('log_device_placement', False, "Whether to log device placement.")
tf.app.flags.DEFINE_integer('log_frequency', 10, "How often to log results to the console.")
def train():
"""
Train CIFAR-10 for a number of steps.
:return:
"""
with tf.Graph().as_default():
global_step = tf.contrib.framework.get_or_create_global_step()
# Get images and labels for CIFAR-10.
images, labels = cifar10.distorted_inputs()
# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images)
# Calculate loss.
loss = cifar10.loss(logits, labels)
# Build a Graph that trains the model with one batch of examples and
# updates the model parameters.
train_op = cifar10.train(loss, global_step)
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print(format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
train()
if __name__ == '__main__':
tf.app.run()
2.2.4 在TensorFlow 中查看训练进度
tensorboard --logdir cifar10_train/
2.2.5 测试模型效果
python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/
使用TensorBoard查看性能验证情况:
tensorboard --logdir cifar10_eval/ --port 6007
拓展阅读
- 关于CIFAR-10 数据集, 读者可以访问它的官方网站https://www.cs.toronto.edu/~kriz/cifar.html 了解更多细节。此外, 网站 http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html#43494641522d3130 中收集了在CIFAR-10 数据集上表 现最好的若干模型,包括这些模型对应的论文。
- ImageNet 数据集上的表现较好的几个著名的模型是深度学习的基石, 值得仔细研读。建议先阅读下面几篇论文:ImageNet Classification with Deep Convolutional Neural Networks(AlexNet 的提出)、Very Deep Convolutional Networks for Large-Scale Image Recognition (VGGNet)、Going Deeper with Convolutions(GoogLeNet)、Deep Residual Learning for Image Recognition(ResNet)
- 在第2.1.3 节中,简要介绍了TensorFlow的一种数据读入机制。事实上,目前在TensorFlow 中读入数据大致有三种方法:(1)用占位符(即placeholder)读入,这种方法比较简单;(2)用队列的形式建立文件到Tensor的映射;(3)用Dataset API 读入数据,Dataset API 是TensorFlow 1.3 版本新引入的一种读取数据的机制,可以参考这 篇中文教程:https://zhuanlan.zhihu.com/p/30751039。