tensorflow入门笔记(二十)使用slim模型库训练自己的数据

1、概述

上一节,我们使用python3爬取了百度图片的一些图片数据,这一节,我们就使用这些爬取下来的图片,训练我们自己的模型,用来识别猪、蛇、狗、大象、老虎这五种动物。在这里吐嘲一下百度图片搜索结果,真是不敢恭维,前面几页结果相关度还好,越到后面越不准确。所以,要对上一节中下载的图片数据进行筛选,将一些无关的图片删除。

2、将数据集转成TFRecord格式

2.1、模仿slimFlowers数据集

学习的一大技能就是从模仿开始。第十六讲《tensorflow入门笔记(十六)使用slim模型库对图片分类》讲到,Flowers数据集共有2500张训练图片和2500张测试图片,共5个分类,分别是菊(daisy)、蒲公英(dandelion)、玫瑰(roses)、向日葵(sunflowers)、郁金香(tulips)。原始数据集目录如下图,

每个子文件夹下是对应品种的图片,如daisy


使用

python download_and_convert_data.py \

 --dataset_name=flowers --dataset_dir=images_data/flowers

命令将数据集下载并转成TFRecord格式,所以,我们就从这作为入口,开始模仿。

2.2、整理数据

打开download_and_convert_data.py文件,看到以下代码,

可知,flowers数据集运行的是,download_and_convert_flowers.run(FLAGS.dataset_dir)

打开datasets/download_and_convert_flowers.py文件,看到数据集的URL


搜索一下变量_DATA_URL

这里会将这个下载来的数据集删除了?搜索一下_clean_up_temporary_files

函数,


可以看到,处理完数据以后才将原始数据集文件删除的,那么,将

_clean_up_temporary_files(dataset_dir)

注释掉,再运行,我们应该就可以得到原始的数据集了。运行结果如下,

对比第十六讲得到的结果,


跟我们设想的一样,那么,进入flower_photos文件夹看看,

里面有5个文件夹,每个文件夹下有相应的原始图片,


好了,那么我们也模仿它这样排列。

首先回顾一下上一讲我们下载的数据的排列方法,

五个文件下,每个文件下代表一个种类,用相应的英文名做文件名,以Dog文件夹为例,其下的每个子文件夹包含50张相应的图片,总共有40个子文件夹,我们对图片进行筛选以后,每个子文件夹下就不一定有50张图片了,


现在,我们将Dog的子文件夹下的图片剪切到Dog目录下,然后将所有的子文件夹删除。剪切命令如下,

cd Dog

find Dog* -name "*.jpg*" -exec mv {} . \;

其他几个以此类推,得到结构如下,


最后,将它们拷贝到images_data/animals/animals_photos目录下。

2.3、将图片转成TFRecord格式

数据整理好以后,就模仿代码咯,复制download_and_convert_flowers.py并将文件名改为convert_animals.py。修改后的源码如下,

#encoding:utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys

import tensorflow as tf

from datasets import dataset_utils

# The URL where the Flowers data can be downloaded.
_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

# The number of images in the validation set.
_NUM_VALIDATION = 350

# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.

  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.

  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
  #将flower_photos改为animals_photos
  flower_root = os.path.join(dataset_dir, 'animals_photos')
  directories = []
  class_names = []
  for filename in os.listdir(flower_root):
    path = os.path.join(flower_root, filename)
    if os.path.isdir(path):
      directories.append(path)
      class_names.append(filename)

  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  #修改文件名,将flowersg改为animals
  output_filename = 'animals_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  return os.path.join(dataset_dir, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset, either 'train' or 'validation'.
    filenames: A list of absolute paths to png or jpg images.
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.
  """
  assert split_name in ['train', 'validation']

  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:

      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i+1, len(filenames), shard_id))
            sys.stdout.flush()

            # Read the filename:
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            height, width = image_reader.read_image_dims(sess, image_data)

            class_name = os.path.basename(os.path.dirname(filenames[i]))
            class_id = class_names_to_ids[class_name]

            example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, class_id)
            tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()


def _clean_up_temporary_files(dataset_dir):
  """Removes temporary files used to create the dataset.

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = _DATA_URL.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)
  tf.gfile.Remove(filepath)
  
  # 将flower_photos改为animals_photos
  tmp_dir = os.path.join(dataset_dir, 'animals_photos')
  tf.gfile.DeleteRecursively(tmp_dir)


def _dataset_exists(dataset_dir):
  for split_name in ['train', 'validation']:
    for shard_id in range(_NUM_SHARDS):
      output_filename = _get_dataset_filename(
          dataset_dir, split_name, shard_id)
      if not tf.gfile.Exists(output_filename):
        return False
  return True


def run(dataset_dir):
  """Runs the download and conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  if _dataset_exists(dataset_dir):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  #因为我们不需要下载,所以这行注释掉
  # dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
  photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
  class_names_to_ids = dict(zip(class_names, range(len(class_names))))

  # Divide into train and test:
  random.seed(_RANDOM_SEED)
  random.shuffle(photo_filenames)
  training_filenames = photo_filenames[_NUM_VALIDATION:]
  validation_filenames = photo_filenames[:_NUM_VALIDATION]

  # First, convert the training and validation sets.
  _convert_dataset('train', training_filenames, class_names_to_ids,
                   dataset_dir)
  _convert_dataset('validation', validation_filenames, class_names_to_ids,
                   dataset_dir)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  #将这行注释掉,要不然转换完以后,原始数据会被删除
  # _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the Flowers dataset!')

再修改download_and_convert_data.py,添加

from datasets import convert_animals

再在

elif FLAGS.dataset_name == 'mnist':
  download_and_convert_mnist.run(FLAGS.dataset_dir)

后添加

elif FLAGS.dataset_name == 'animals':
    convert_animals.run(FLAGS.dataset_dir)

如下图所示,


然后运行命令,

python download_and_convert_data.py \

--dataset_name=animals --dataset_dir=images_data/animals

运行结果,

查看目录images_data/animals/

似乎成功了,已经生成tfrecord文件,接着验证一下,使用第十六讲的方法,从tfrecord文件获取一张图片并显示看看。先来回顾一下第十六讲读取tfrecord的代码,

#encoding:utf-8
from datasets import flowers
import tensorflow as tf
import pylab
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
slim = tf.contrib.slim

#flowers数据集目录
DATA_DIR = 'images_data/flowers/'

#指定获取“validation”下的数据
dataset = flowers.get_split('validation', DATA_DIR)

# Creates a TF-Slim DataProvider which reads the dataset in the background
# during both training and testing.
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
[image, label] = provider.get(['image', 'label'])

#在session下读取数据,并用pylab显示图片
with tf.Session() as sess:
    #初始化变量
    sess.run(tf.global_variables_initializer())
    #启动队列
    tf.train.start_queue_runners()
    image_batch,label_batch = sess.run([image, label])
    #显示图片
    pylab.imshow(image_batch)
    pylab.show()

从代码中看到,需要导入from datasets import flowers

显然我们不能直接用这个,所以,继续模仿。

3、定义datasets文件

将datasets/flowers.py复制并重命名为animals.py ,将

_FILE_PATTERN = 'flowers_%s_*.tfrecord'

改为

_FILE_PATTERN = 'animals_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train'3320'validation'350}

改为

SPLITS_TO_SIZES = {'train'4383'validation'350}

其中,train代表训练的图片张数,validation代表验证使用的图片张数。我们爬取的图片整理以后,只剩下4733张有效图片,验证用了350张,所以训练只有4383张。

完整代码如下,

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/research/slim/datasets/convert_animals.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'animals_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 4383, 'validation': 350}

_NUM_CLASSES = 5

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)

所以,将读取tfrecord的代码的改为以下代码,

#encoding:utf-8
from datasets import animals
import tensorflow as tf
import pylab
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
slim = tf.contrib.slim

#flowers数据集目录
DATA_DIR = 'images_data/animals/'

#指定获取“validation”下的数据
dataset = animals.get_split('validation', DATA_DIR)

# Creates a TF-Slim DataProvider which reads the dataset in the background
# during both training and testing.
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
[image, label] = provider.get(['image', 'label'])

#在session下读取数据,并用pylab显示图片
with tf.Session() as sess:
    #初始化变量
    sess.run(tf.global_variables_initializer())
    #启动队列
    tf.train.start_queue_runners()
    image_batch,label_batch = sess.run([image, label])
    #显示图片
    pylab.imshow(image_batch)
    pylab.show()

运行得到结果为:


说明,我们的TFRecord格式转换正确。

4、开始训练

接下来还要改哪里呢?不好意思,我也不知道,那么,试着运行训练的命令看看咯?执行以下命令看看有什么提示再说。

python train_image_classifier.py \

  --train_dir=saver/inv3_animals \

  --dataset_name=animals \

  --dataset_split_name=train \

  --dataset_dir=images_data/animals/ \

  --model_name=inception_v3 \

  --batch_size=5 \

  --learning_rate=0.0001 \

  --learning_rate_decay_type=fixed \

  --save_interval_secs=60 \

  --save_summaries_secs=60 \

  --log_every_n_steps=10 \

  --optimizer=rmsprop \

  --weight_decay=0.00004

运行结果,

很不幸,没有通过,看提示。datasets/dataset_factory.py文件下的get_dataset函数,看看做了什么?


如果name不在datasets_map里就打印错误,那么看看这个datasets_map里是什么?

似乎就是我们数据集的名称,我们将animals加上去试试,改成

编辑工具提示animals有错误,网上搜索flowers看看,


说明这里还需要导包,

所以,改为,


再运行看看,运行结果,

可以看到,我们的模型就这样开始训练了,很简单啊~~~

我靠,得意得太早了,马上出错,来看看报什么错。

INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, image_size must contain 3 elements[4]

图片的长度必须包含3个元素?那就是有可能有些图片不是RGB格式咯。所以,先找到这些不是RGB格式的图片,用其他图片替换掉它们,或者删除它们,或者将它们转成RGB格式。

作为程序猿,当然不可能手动去找啦,查找非RGB格式图片代码如下,

#encoding:utf-8
from PIL import Image
import os

def get_not_rgb_images(rootdir):
    list = os.listdir(rootdir)
    for i in range(0, len(list)):
        filename = os.path.join(rootdir, list[i])
        # print(filename)
        if os.path.isfile(filename):
            img = Image.open(filename)
            pixels = img.getpixel((0, 0))

            if type(pixels) == int:
                print('单通道:' + filename)
            elif type(pixels) == tuple:
                if  len(pixels) != 3:
                    print('非RGB的多通道:' +filename)
        else:
            get_not_rgb_images(filename)



if __name__ == '__main__':
    rootdir = 'images_data/animals/animals_photos/'
    get_not_rgb_images(rootdir)

运行结果:


只有一个文件,那随便下载一张图片将其替换掉好了。替换以后,再运行一下代码检查一下是否还有不是RGB格式的图片,没有的话,删除原来的TFRecord文件,再运行下面命令重新将数据集转成TFRecord格式。

python download_and_convert_data.py --dataset_name=animals --dataset_dir=images_data/animals

再运行训练的命令,

python train_image_classifier.py \

  --train_dir=saver/inv3_animals \

  --dataset_name=animals \

  --dataset_split_name=train \

  --dataset_dir=images_data/animals/ \

  --model_name=inception_v3 \

  --batch_size=5 \

  --learning_rate=0.0001 \

  --learning_rate_decay_type=fixed \

  --save_interval_secs=60 \

  --save_summaries_secs=60 \

  --log_every_n_steps=10 \

  --optimizer=rmsprop \

  --weight_decay=0.00004

去洗个澡回来再看看有没有出错~

好了,洗澡回来,程序没出错,

5、第十六讲训练结果

这里插讲一下第十六讲的训练结果,地十六讲中,我们只训练几百步,所以得到的准确率并不高,后来我又加大训练次数,从头开始训练的模型,我训练了694490步,得到的准确率如下,


微调的模型我训练了542700步,准确率如下,

总结:

TFRecord的具体格式怎么样,它怎么存储数据的,我们现在处于初学阶段,并不需要去深究,等以后需要去了解的时候再去学,现在要做的,是大概将一些tensorflow主流的示例都学一遍。下一讲,我们继续往下学习,学习深度学习中的目标检测。



猜你喜欢

转载自blog.csdn.net/rookie_wei/article/details/80796009