Building Your Own Image Recognition Model with TensorFlow

1. Goal

2. Understanding How Fine-Tuning Works

3. Data Preparation

4. Fine-Tuning the Model with TensorFlow Slim

1) Download the TensorFlow Slim source code

2) Define your own datasets file

3) Prepare the training folder

4) Start training

5) Validate the model's accuracy

6) Export the model and classify a single image

5. Summary of Issues

This article contains my study notes for Chapter 3 of the book 《21个项目玩转深度学习:基于TensorFlow的实践详解》 (*21 Projects in Deep Learning: Hands-On Practice with TensorFlow*).

1. Goal

        Train a deep learning model on your own image data with TensorFlow. Specifically, we fine-tune a model that has already been trained on ImageNet.

2. Understanding How Fine-Tuning Works

        Choosing the number of layers, the filter sizes, the pooling configuration, and the other hyperparameters of a neural network takes a great deal of tuning experiments, and that process is anything but easy. For this reason, most models used in practice borrow an architecture that someone has already trained, such as the well-known AlexNet, VGG16, VGG19, GoogLeNet, Inception-v3, or ResNet, and change the number of units in the final fully connected layer before training. The training itself comes in two forms: full training, which runs the complete training procedure on the new dataset from scratch, and partial training, which fine-tunes a model that has already been trained. With fine-tuning you start from a pretrained model and reuse its large stock of already-trained convolutional filters on your own dataset, which saves training time and improves classifier performance.

        Let's use the VGG16 network to understand how fine-tuning works:

The figure above shows the structures of the networks in the VGG family; taking VGG16 on its own, its structure is shown in the figure below:

        VGG16 consists of convolutional layers followed by fully connected layers. As the figures show, the convolutional part falls into 5 groups with 13 layers in total (conv1_1, conv1_2, conv2_1, conv2_2, conv3_1, conv3_2, conv3_3, conv4_1, conv4_2, conv4_3, conv5_1, conv5_2, and conv5_3 in the figure), and the fully connected part consists of the three layers fc6, fc7, and fc8. fc8 takes the features of fc7 as input and outputs probabilities over 1000 classes, the 1000 categories of ImageNet. Our own dataset generally will not have 1000 classes, so the final fully connected layer fc8 has to be removed and replaced with a new fc8 that matches our dataset. For example, if the dataset has 5 classes, the new fc8 should output 5 classes as well.
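To make this concrete, below is a minimal sketch of the idea (not code from the book; it assumes the TF-Slim models repository is on the Python path and an ImageNet checkpoint vgg_16.ckpt has been downloaded): a new fc8 is built for 6 classes, and every pretrained weight except fc8 is restored.

import tensorflow as tf
from nets import vgg  # network definitions from the TF-Slim models repo

slim = tf.contrib.slim

# Build VGG16 with a new fc8 that outputs 6 classes instead of ImageNet's 1000.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits, _ = vgg.vgg_16(images, num_classes=6, is_training=True)

# Restore every pretrained variable except the replaced fc8 layer.
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
init_fn = slim.assign_from_checkpoint_fn('vgg_16.ckpt', variables_to_restore)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init_fn(sess)  # conv1_1 ... fc7 now hold the ImageNet weights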

3. Data Preparation

1) Split the data into a training set and a validation set

This experiment uses the satellite-image dataset provided with the book. The raw data is stored as shown below: train holds the training set and validation holds the validation set. As you can see, there are six classes of images: wood, water, rock, wetland, glacier, and urban.

data_prepare/
    pic/
        train/
            wood/
            water/
            rock/
            wetland/
            glacier/
            urban/
        validation/
            wood/
            water/
            rock/
            wetland/
            glacier/
            urban/

What do the images look like? Here is one image picked at random from each of the six classes used in this example:

2) Convert the data to TFRecord format

        For large datasets, TensorFlow expects the data to be converted into TFRecord files. TFRecord files likewise store data in binary form and are well suited to reading large amounts of data serially. Their advantages are better memory utilization and easier copying and moving, which fits the way the TensorFlow execution engine works. Converting data to TFRecord format usually means writing a small program that packs each sample into an Example object as defined by protocol buffers, serializes it to a string, and writes it out with tf.python_io.TFRecordWriter.
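As a minimal sketch of that process (the file names and the label value here are made up for illustration):

import tensorflow as tf

# Read the raw JPEG bytes of one sample image.
with open('pic/train/water/sample.jpg', 'rb') as f:
    encoded = f.read()

# Pack the bytes and a numeric label into an Example proto.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[encoded])),
    'image/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[1])),
}))

# Serialize the Example and write it to a TFRecord file.
writer = tf.python_io.TFRecordWriter('sample.tfrecord')
writer.write(example.SerializeToString())
writer.close()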

The data_prepare folder contains data_convert.py, which converts the images to TFRecord format. Its code is as follows (a stray debug print has been removed):

# coding:utf-8
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main


# Parse the command-line arguments and their defaults.
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
    parser.add_argument('--train-shards', default=2, type=int)
    parser.add_argument('--validation-shards', default=2, type=int)
    parser.add_argument('--num-threads', default=2, type=int)
    parser.add_argument('--dataset-name', default='satellite', type=str)
    return parser.parse_args()


if __name__ == '__main__':
    # Logging level for the progress messages.
    logging.basicConfig(level=logging.INFO)
    # Derive the working directories from the arguments.
    args = parse_args()
    args.tensorflow_dir = args.tensorflow_data_dir
    args.train_directory = os.path.join(args.tensorflow_dir, 'train')
    args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
    args.output_directory = args.tensorflow_dir
    args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
    # Generate label.txt from the training subdirectories if it is missing.
    if os.path.exists(args.labels_file) is False:
        logging.warning('Can\'t find label.txt. Now create it.')
        all_entries = os.listdir(args.train_directory)
        dirnames = []
        for entry in all_entries:
            if os.path.isdir(os.path.join(args.train_directory, entry)):
                dirnames.append(entry)
        with open(args.labels_file, 'w') as f:
            for dirname in dirnames:
                f.write(dirname + '\n')
    # Call main() in tfrecord.py to do the conversion.
    main(args)

To convert the images, run the following command:

python data_convert.py -t pic/ \
        --train-shards 2 \
        --validation-shards 2 \
        --num-threads 2 \
        --dataset-name satellite

        After this command finishes, five new files appear in the pic folder: the training data satellite_train_00000-of-00002.tfrecord and satellite_train_00001-of-00002.tfrecord, the validation data satellite_validation_00000-of-00002.tfrecord and satellite_validation_00001-of-00002.tfrecord, and a text file label.txt that records the mapping from the images' internal numeric labels to the real class-name strings, in order. For example, an image stored with label 0 in the tfrecords corresponds to the class on the first line of label.txt, label 1 corresponds to the second line, and so on.
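A quick sketch of reading that mapping back in Python (using the label.txt generated above):

# Line i of label.txt names the class stored as label i in the tfrecords.
with open('pic/label.txt') as f:
    label_map = {i: line.strip() for i, line in enumerate(f)}
print(label_map)  # e.g. {0: 'wood', 1: 'water', ...}; order depends on the folders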
 

        The main function that data_convert.py ultimately calls lives in tfrecord.py; following that function shows everything the TFRecord conversion does. The file's code is as follows:

# coding:utf-8
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts image data to TFRecords file format with Example protos.

The image data set is expected to reside in JPEG files located in the
following directory structure.

    data_dir/label_0/image0.jpeg
    data_dir/label_0/image1.jpg
    ...
    data_dir/label_1/weird-image.jpeg
    data_dir/label_1/my-image.jpeg
    ...

where the sub-directory is the unique label associated with these images.

This TensorFlow script converts the training and evaluation data into
a sharded data set consisting of TFRecord files

    train_directory/train-00000-of-01024
    train_directory/train-00001-of-01024
    ...
    train_directory/train-00127-of-01024

and

    validation_directory/validation-00000-of-00128
    validation_directory/validation-00001-of-00128
    ...
    validation_directory/validation-00127-of-00128

where we have selected 1024 and 128 shards for each data set. Each record
within the TFRecord file is a serialized Example proto. The Example proto
contains the following fields:

    image/encoded: string containing JPEG encoded image in RGB colorspace
    image/height: integer, image height in pixels
    image/width: integer, image width in pixels
    image/colorspace: string, specifying the colorspace, always 'RGB'
    image/channels: integer, specifying the number of channels, always 3
    image/format: string, specifying the format, always 'JPEG'
    image/filename: string containing the basename of the image file
        e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
    image/class/label: integer specifying the index in a classification layer,
        starting from "class_label_base"
    image/class/text: string specifying the human-readable version of the label
        e.g. 'dog'

If your data set involves bounding boxes, please look at build_imagenet_data.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import random
import sys
import threading

import numpy as np
import tensorflow as tf
import logging


def _int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    value = tf.compat.as_bytes(value)
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _convert_to_example(filename, image_buffer, label, text, height, width):
    """Build an Example proto for an example.

    Args:
        filename: string, path to an image file, e.g., '/path/to/example.JPG'
        image_buffer: string, JPEG encoding of RGB image
        label: integer, identifier for the ground truth for the network
        text: string, unique human-readable, e.g. 'dog'
        height: integer, image height in pixels
        width: integer, image width in pixels
    Returns:
        Example proto
    """
    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(colorspace),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(text),
        'image/format': _bytes_feature(image_format),
        'image/filename': _bytes_feature(os.path.basename(filename)),
        'image/encoded': _bytes_feature(image_buffer)}))
    return example


class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that converts PNG to JPEG data.
        self._png_data = tf.placeholder(dtype=tf.string)
        image = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        return self._sess.run(self._png_to_jpeg,
                              feed_dict={self._png_data: image_data})

    def decode_jpeg(self, image_data):
        image = self._sess.run(self._decode_jpeg,
                               feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


def _is_png(filename):
    """Determine if a file contains a PNG format image.

    Args:
        filename: string, path of the image file.
    Returns:
        boolean indicating if the image is a PNG.
    """
    return '.png' in filename


def _process_image(filename, coder):
    """Process a single image file.

    Args:
        filename: string, path to an image file e.g., '/path/to/example.JPG'.
        coder: instance of ImageCoder to provide TensorFlow image coding utils.
    Returns:
        image_buffer: string, JPEG encoding of RGB image.
        height: integer, image height in pixels.
        width: integer, image width in pixels.
    """
    # Read the image file.
    with open(filename, 'rb') as f:
        image_data = f.read()

    # Convert any PNG to JPEG's for consistency.
    if _is_png(filename):
        logging.info('Converting PNG to JPEG for %s' % filename)
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that image converted to RGB
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    assert image.shape[2] == 3

    return image_data, height, width


def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
                               texts, labels, num_shards, command_args):
    """Processes and saves list of images as TFRecord in 1 thread.

    Args:
        coder: instance of ImageCoder to provide TensorFlow image coding utils.
        thread_index: integer, unique batch to run index is within [0, len(ranges)).
        ranges: list of pairs of integers specifying ranges of each batches to
            analyze in parallel.
        name: string, unique identifier specifying the data set
        filenames: list of strings; each string is a path to an image file
        texts: list of strings; each string is human readable, e.g. 'dog'
        labels: list of integer; each integer identifies the ground truth
        num_shards: integer number of shards for this data set.
    """
    # Each thread produces N shards where N = int(num_shards / num_threads).
    # For instance, if num_shards = 128, and the num_threads = 2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = int(num_shards / num_threads)

    shard_ranges = np.linspace(ranges[thread_index][0],
                               ranges[thread_index][1],
                               num_shards_per_batch + 1).astype(int)
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s_%s_%.5d-of-%.5d.tfrecord' % (
            command_args.dataset_name, name, shard, num_shards)
        output_file = os.path.join(command_args.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            filename = filenames[i]
            label = labels[i]
            text = texts[i]

            image_buffer, height, width = _process_image(filename, coder)

            example = _convert_to_example(filename, image_buffer, label,
                                          text, height, width)
            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1

            if not counter % 1000:
                logging.info('%s [thread %d]: Processed %d of %d images in thread batch.' %
                             (datetime.now(), thread_index, counter, num_files_in_thread))
                sys.stdout.flush()

        writer.close()
        logging.info('%s [thread %d]: Wrote %d images to %s' %
                     (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()
        shard_counter = 0
    logging.info('%s [thread %d]: Wrote %d images to %d shards.' %
                 (datetime.now(), thread_index, counter, num_files_in_thread))
    sys.stdout.flush()


def _process_image_files(name, filenames, texts, labels, num_shards, command_args):
    """Process and save list of images as TFRecord of Example protos.

    Args:
        name: string, unique identifier specifying the data set
        filenames: list of strings; each string is a path to an image file
        texts: list of strings; each string is human readable, e.g. 'dog'
        labels: list of integer; each integer identifies the ground truth
        num_shards: integer number of shards for this data set.
    """
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    spacing = np.linspace(0, len(filenames), command_args.num_threads + 1).astype(np.int)
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])

    # Launch a thread for each batch.
    logging.info('Launching %d threads for spacings: %s' % (command_args.num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image codings.
    coder = ImageCoder()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames,
                texts, labels, num_shards, command_args)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    logging.info('%s: Finished writing all %d images in data set.' %
                 (datetime.now(), len(filenames)))
    sys.stdout.flush()


def _find_image_files(data_dir, labels_file, command_args):
    """Build a list of all images files and labels in the data set.

    Args:
        data_dir: string, path to the root directory of images.
            Assumes that the image data set resides in JPEG files located in
            the following directory structure.
                data_dir/dog/another-image.JPEG
                data_dir/dog/my-image.jpg
            where 'dog' is the label associated with these images.
        labels_file: string, path to the labels file.
            The list of valid labels are held in this file. Assumes that the file
            contains entries as such:
                dog
                cat
                flower
            where each line corresponds to a label. We map each label contained in
            the file to an integer starting with the integer 0 corresponding to the
            label contained in the first line.
    Returns:
        filenames: list of strings; each string is a path to an image file.
        texts: list of strings; each string is the class, e.g. 'dog'
        labels: list of integer; each integer identifies the ground truth.
    """
    logging.info('Determining list of input files and labels from %s.' % data_dir)
    unique_labels = [l.strip() for l in tf.gfile.FastGFile(
        labels_file, 'r').readlines()]

    labels = []
    filenames = []
    texts = []

    # Leave label index 0 empty as a background class.
    # Important: we start the labels at class_label_base (0 by default) so
    # that they match the dataset definition used later.
    label_index = command_args.class_label_base

    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        matching_files = tf.gfile.Glob(jpeg_file_path)

        labels.extend([label_index] * len(matching_files))
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)

        if not label_index % 100:
            logging.info('Finished finding files in %d of %d classes.' % (
                label_index, len(labels)))
        label_index += 1

    # Shuffle the ordering of all image files in order to guarantee
    # random ordering of the images with respect to label in the
    # saved TFRecord files. Make the randomization repeatable.
    shuffled_index = list(range(len(filenames)))
    random.seed(12345)
    random.shuffle(shuffled_index)

    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]

    logging.info('Found %d JPEG files across %d labels inside %s.' %
                 (len(filenames), len(unique_labels), data_dir))
    # print(labels)
    return filenames, texts, labels


def _process_dataset(name, directory, num_shards, labels_file, command_args):
    """Process a complete data set and save it as a TFRecord.

    Args:
        name: string, unique identifier specifying the data set.
        directory: string, root path to the data set.
        num_shards: integer number of shards for this data set.
        labels_file: string, path to the labels file.
    """
    filenames, texts, labels = _find_image_files(directory, labels_file, command_args)
    _process_image_files(name, filenames, texts, labels, num_shards, command_args)


def check_and_set_default_args(command_args):
    if not(hasattr(command_args, 'train_shards')) or command_args.train_shards is None:
        command_args.train_shards = 5
    if not(hasattr(command_args, 'validation_shards')) or command_args.validation_shards is None:
        command_args.validation_shards = 5
    if not(hasattr(command_args, 'num_threads')) or command_args.num_threads is None:
        command_args.num_threads = 5
    if not(hasattr(command_args, 'class_label_base')) or command_args.class_label_base is None:
        command_args.class_label_base = 0
    if not(hasattr(command_args, 'dataset_name')) or command_args.dataset_name is None:
        command_args.dataset_name = ''
    assert not command_args.train_shards % command_args.num_threads, (
        'Please make the command_args.num_threads commensurate with command_args.train_shards')
    assert not command_args.validation_shards % command_args.num_threads, (
        'Please make the command_args.num_threads commensurate with '
        'command_args.validation_shards')
    assert command_args.train_directory is not None
    assert command_args.validation_directory is not None
    assert command_args.labels_file is not None
    assert command_args.output_directory is not None


def main(command_args):
    """
    command_args must provide the following attributes:
        command_args.train_directory  Folder holding the training set; each
            subfolder is named after a label and contains that label's images.
        command_args.validation_directory  Folder holding the validation set,
            organized the same way.
        command_args.labels_file  A file with one label name per line.
        command_args.output_directory  Folder where the output is written.
        command_args.train_shards  How many shards to split the training set into.
        command_args.validation_shards  How many shards for the validation set.
        command_args.num_threads  Number of threads; must divide both shard counts.
        command_args.class_label_base  Important! The label number that the first
            class gets in the actual tfrecords; defaults to 0 (models/slim also
            starts from 0).
        command_args.dataset_name  String prefix for the output file names.
    No image may be corrupted, or a worker thread will exit early.
    """
    check_and_set_default_args(command_args)
    logging.info('Saving results to %s' % command_args.output_directory)

    # Run it!
    _process_dataset('validation', command_args.validation_directory,
                     command_args.validation_shards, command_args.labels_file, command_args)
    _process_dataset('train', command_args.train_directory,
                     command_args.train_shards, command_args.labels_file, command_args)

4. Fine-Tuning the Model with TensorFlow Slim

    TensorFlow Slim is an image classification toolkit released by Google. It not only defines a set of convenient interfaces but also provides network definitions and pretrained models for many architectures commonly used on the ImageNet dataset. As of July 2017, Slim offered the structures and pretrained weights of most popular models, including VGG16, VGG19, Inception V1 through V4, ResNet-50, ResNet-101, and MobileNet, and more models continue to be added.

1) Download the TensorFlow Slim source code

The most important parts of Slim's code structure, with a description of each, are shown below:

This experiment uses the pre-downloaded copy of the TensorFlow Slim code that comes with the book.

2) Define your own datasets file

To train on the tfrecord data created earlier, we must define a new dataset under datasets. Create a new file satellite.py in the datasets/ directory and copy the contents of flowers.py into it. Then modify satellite.py in the following places:

First: modify _FILE_PATTERN, SPLITS_TO_SIZES, and _NUM_CLASSES as follows:

_FILE_PATTERN = 'satellite_%s_*.tfrecord'
SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}
_NUM_CLASSES = 6
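These sizes have to match the actual number of images in each split. A quick sanity-check sketch (assuming the data_prepare layout from Section 3):

import glob

# 4800 training images and 1200 validation images are expected here.
print(len(glob.glob('pic/train/*/*')))       # should print 4800
print(len(glob.glob('pic/validation/*/*')))  # should print 1200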

Second: modify image/format as shown below. This is the images' default format; since the satellite images provided here are jpg, it has to be changed.

'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),

After modifying satellite.py, register the satellite dataset in dataset_factory.py in the same directory. The registered code is shown below; the two added lines are from datasets import satellite and 'satellite': satellite,.

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'satellite': satellite,
}

3) Prepare the training folder

With the dataset defined, create a new satellite directory under the slim folder and complete the following tasks inside it:

First: create a data directory and copy into it the tfrecord data converted in Section 3.

Second: create a train_dir directory to hold the logs and models saved during training.

Third: create a pretrained directory. On Slim's GitHub page, find the download address of the Inception V3 model, http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz; downloading and extracting it yields a file inception_v3.ckpt, which should be copied into the pretrained directory. Note: the link may be unreachable from some networks in mainland China; I simply searched for the inception_v3.ckpt file on CSDN and downloaded it there.

After these steps, the directory structure looks like this:

slim/
    satellite/
        data/
            satellite_train_00000-of-00002.tfrecord
            satellite_train_00001-of-00002.tfrecord
            satellite_validation_00000-of-00002.tfrecord
            satellite_validation_00001-of-00002.tfrecord
            label.txt
        pretrained/
            inception_v3.ckpt
        train_dir/

4) Start training

Run the following command in the slim folder to start training:

python train_image_classifier.py \
    --train_dir=satellite/train_dir \
    --dataset_name=satellite \
    --dataset_split_name=train \
    --dataset_dir=satellite/data \
    --model_name=inception_v3 \
    --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --max_number_of_steps=100000 \
    --batch_size=32 \
    --learning_rate=0.001 \
    --learning_rate_decay_type=fixed \
    --save_interval_secs=300 \
    --save_summaries_secs=2 \
    --log_every_n_steps=10 \
    --optimizer=rmsprop \
    --weight_decay=0.00004

The command above trains only the final layers, InceptionV3/Logits and InceptionV3/AuxLogits: checkpoint_exclude_scopes keeps those (re-sized) layers from being restored from the ImageNet checkpoint, while trainable_scopes restricts training to them. You can also train all layers by dropping the trainable_scopes flag, as in the following command:

python train_image_classifier.py \
    --train_dir=satellite/train_dir \
    --dataset_name=satellite \
    --dataset_split_name=train \
    --dataset_dir=satellite/data \
    --model_name=inception_v3 \
    --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --max_number_of_steps=100000 \
    --batch_size=32 \
    --learning_rate=0.001 \
    --learning_rate_decay_type=fixed \
    --save_interval_secs=300 \
    --save_summaries_secs=10 \
    --log_every_n_steps=1 \
    --optimizer=rmsprop \
    --weight_decay=0.00004

One thing to pay special attention to here: by default, Slim trains on the GPU. The corresponding code in train_image_classifier.py is:

tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                            'Use CPUs to deploy clones.')

        There are two ways to run Slim on the CPU: either change False to True for clone_on_cpu in the code above, or append the flag --clone_on_cpu=True when launching the command, which tells Slim to run on the CPU.
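For example, the end-layer training command from above can be run on the CPU by appending the flag, with everything else unchanged:

python train_image_classifier.py \
    --train_dir=satellite/train_dir \
    --dataset_name=satellite \
    --dataset_split_name=train \
    --dataset_dir=satellite/data \
    --model_name=inception_v3 \
    --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --max_number_of_steps=100000 \
    --batch_size=32 \
    --learning_rate=0.001 \
    --learning_rate_decay_type=fixed \
    --save_interval_secs=300 \
    --save_summaries_secs=2 \
    --log_every_n_steps=10 \
    --optimizer=rmsprop \
    --weight_decay=0.00004 \
    --clone_on_cpu=True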

On a device without GPU support, running without either of these fixes fails with the error below; apply one of the fixes and training will run on the CPU.

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation 'InceptionV3/Conv2d_1a_3x3/weights/Initializer/truncated_normal/TruncatedNormal': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/device:GPU:0'

Note that running the training command on the CPU is extremely slow, so I changed --max_number_of_steps=100000 to --max_number_of_steps=300 for this run.

5) Validate the model's accuracy

Use the eval_image_classifier.py program to validate the model's accuracy on the validation set by running:

python eval_image_classifier.py \
    --checkpoint_path=satellite/train_dir \
    --eval_dir=satellite/eval_dir \
    --dataset_name=satellite \
    --dataset_split_name=validation \
    --dataset_dir=satellite/data \
    --model_name=inception_v3

Since my machine does not support a GPU, I trained for only 300 steps on the CPU. Validating accuracy with that model gives the following results.

INFO:tensorflow:Evaluation [1/12]
INFO:tensorflow:Evaluation [2/12]
INFO:tensorflow:Evaluation [3/12]
INFO:tensorflow:Evaluation [4/12]
INFO:tensorflow:Evaluation [5/12]
INFO:tensorflow:Evaluation [6/12]
INFO:tensorflow:Evaluation [7/12]
INFO:tensorflow:Evaluation [8/12]
INFO:tensorflow:Evaluation [9/12]
INFO:tensorflow:Evaluation [10/12]
INFO:tensorflow:Evaluation [11/12]
INFO:tensorflow:Evaluation [12/12]
eval/Recall_5[0.9825] eval/Accuracy[0.6625]

Accuracy is the model's classification accuracy, while Recall_5 is the Top-5 accuracy: a prediction counts as correct as long as the true class falls within the 5 highest-scoring outputs. Since there are only a few classes here, Top-5 accuracy is not very informative; you can compute Top-2 or Top-3 accuracy instead by modifying the following part of eval_image_classifier.py:

# Define the metrics:
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    'Recall_5': slim.metrics.streaming_recall_at_k(
        logits, labels, 5),
})
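For instance, a Top-2 version would look like the sketch below; only the k argument and the metric's dictionary key change:

# Define the metrics:
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    'Recall_2': slim.metrics.streaming_recall_at_k(
        logits, labels, 2),
})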

The code of eval_image_classifier.py is as follows:

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic evaluation script that evaluates a model using a given dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import tensorflow as tf

from datasets import dataset_factory
from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_integer(
    'batch_size', 100, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'max_num_batches', None,
    'Max number of batches to evaluate by default use all.')

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'test', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average.'
    'If left as None, then moving averages are not used.')

tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
            'Recall_5': slim.metrics.streaming_recall_at_k(
                logits, labels, 5),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.logging.info('Evaluating %s' % checkpoint_path)

        slim.evaluation.evaluate_once(
            master=FLAGS.master,
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            variables_to_restore=variables_to_restore)

        # slim.evaluation.evaluation_loop(
        #     master=FLAGS.master,
        #     checkpoint_dir=FLAGS.checkpoint_path,
        #     logdir=FLAGS.eval_dir,
        #     num_evals=num_batches,
        #     eval_op=list(names_to_updates.values()),
        #     variables_to_restore=variables_to_restore,
        #     eval_interval_secs=300
        # )


if __name__ == '__main__':
    tf.app.run()

6) Export the model and classify a single image

Once the model is trained, the next step is to export it and use it to classify images. freeze_graph.py is provided here for exporting the recognition model, and classify_image_inception_v3.py is a script that uses the Inception V3 model to classify a single image.

Exporting the model:

TensorFlow Slim provides export_inference_graph.py, a script that exports the network structure. First run the following command in the slim folder:

python export_inference_graph.py \
    --alsologtostderr \
    --model_name=inception_v3 \
    --output_file=satellite/inception_v3_inf_graph.pb \
    --dataset_name satellite

This command generates an inception_v3_inf_graph.pb file in the satellite folder.

The code of export_inference_graph.py is as follows:

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a GraphDef containing the architecture of the model.

To use it, run something like this, with a model name defined by slim:

bazel build tensorflow_models/slim:export_inference_graph
bazel-bin/tensorflow_models/slim/export_inference_graph \
    --model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb

If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:

bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/tmp/inception_v3_inf_graph.pb \
    --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
    --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
    --output_node_names=InceptionV3/Predictions/Reshape_1

The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:

bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
    --in_graph=/tmp/inception_v3_inf_graph.pb

To run the resulting graph in C++, you can look at the label_image sample code:

bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
    --image=${HOME}/Pictures/flowers.jpg \
    --input_layer=input \
    --output_layer=InceptionV3/Predictions/Reshape_1 \
    --graph=/tmp/frozen_inception_v3.pb \
    --labels=/tmp/imagenet_slim_labels.txt \
    --input_mean=0 \
    --input_std=255 \
    --logtostderr
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to save.')

tf.app.flags.DEFINE_boolean(
    'is_training', False,
    'Whether to save out a training-focused version of the model.')

tf.app.flags.DEFINE_integer(
    'default_image_size', 224,
    'The image size to use if the model does not define it.')

tf.app.flags.DEFINE_string('dataset_name', 'imagenet',
                           'The name of the dataset to use with the model.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

tf.app.flags.DEFINE_string(
    'dataset_dir', '', 'Directory to save intermediate dataset files to')

FLAGS = tf.app.flags.FLAGS


def main(_):
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'validation',
                                              FLAGS.dataset_dir)
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=FLAGS.is_training)
        if hasattr(network_fn, 'default_image_size'):
            image_size = network_fn.default_image_size
        else:
            image_size = FLAGS.default_image_size
        placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                     shape=[1, image_size, image_size, 3])
        network_fn(placeholder)
        graph_def = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as f:
            f.write(graph_def.SerializeToString())


if __name__ == '__main__':
    tf.app.run()

Note: the inception_v3_inf_graph.pb file stores only the network structure of Inception V3, not the trained model parameters, which still need to be merged in from the checkpoint. That is done with the freeze_graph.py script (provided with the book). In the directory containing freeze_graph.py, run:

python freeze_graph.py \
    --input_graph slim/satellite/inception_v3_inf_graph.pb \
    --input_checkpoint slim/satellite/train_dir/model.ckpt-300 \
    --input_binary true \
    --output_node_names InceptionV3/Predictions/Reshape_1 \
    --output_graph slim/satellite/frozen_graph.pb
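As a quick sanity check that the weights really are baked in, here is a minimal sketch (the all-zeros dummy input is purely for illustration) that loads frozen_graph.pb and runs one forward pass; the classify_image_inception_v3.py script described below does the same thing with real preprocessing:

import numpy as np
import tensorflow as tf

# Load the frozen GraphDef; variables are already inlined as constants.
with tf.gfile.FastGFile('slim/satellite/frozen_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    # 'input' and 'InceptionV3/Predictions/Reshape_1' are the node names
    # used by export_inference_graph.py and freeze_graph.py above.
    output = sess.graph.get_tensor_by_name('InceptionV3/Predictions/Reshape_1:0')
    dummy = np.zeros((1, 299, 299, 3), dtype=np.float32)  # Inception V3 input size
    preds = sess.run(output, feed_dict={'input:0': dummy})
    print(preds.shape)  # (1, 6): one score per satellite class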

The final exported model files are shown below:

Predicting an image:

How do we use the exported frozen_graph.pb file to classify a single image? The provided classify_image_inception_v3.py script does exactly that. First, its usage:

python classify_image_inception_v3.py \
    --model_path slim/satellite/frozen_graph.pb \
    --label_path data_prepare/pic/label.txt \
    --image_file test_image.jpg

The prediction results are as follows; this image gets the highest score for water. (Note that the scores are unnormalized, which is why one of them is negative.)

water (score = 1.41468)
wood (score = 1.12560)
rock (score = 0.34318)
wetland (score = 0.31493)
urban (score = -1.02338)

The code of classify_image_inception_v3.py is as follows:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

FLAGS = None


class NodeLookup(object):
    def __init__(self, label_lookup_path=None):
        self.node_lookup = self.load(label_lookup_path)

    def load(self, label_lookup_path):
        node_id_to_name = {}
        with open(label_lookup_path) as f:
            for index, line in enumerate(f):
                node_id_to_name[index] = line.strip()
        return node_id_to_name

    def id_to_string(self, node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]


def create_graph():
    """Creates a graph from saved GraphDef file and returns a saver."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')


def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
    with tf.name_scope(scope, 'eval_image', [image, height, width]):
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # Crop the central region of the image with an area containing 87.5% of
        # the original image.
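The listing is cut off at this point in the source page. For reference, the standard Inception eval preprocessing that this function goes on to perform is sketched below (the usual central-crop, resize, and rescale-to-[-1, 1] recipe; a sketch, not necessarily the file's exact remainder):

        if central_fraction:
            image = tf.image.central_crop(image, central_fraction=central_fraction)
        if height and width:
            # Resize the image to the height and width the network expects.
            image = tf.expand_dims(image, 0)
            image = tf.image.resize_bilinear(image, [height, width],
                                             align_corners=False)
            image = tf.squeeze(image, [0])
        # Rescale pixel values from [0, 1] to [-1, 1], as Inception expects.
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        return image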


Reposted from blog.csdn.net/weixin_38246633/article/details/87969268