tensorflow利用slim进行迁移学习

本文采用tensorflow的slim库进行迁移学习,网站为:github-slim

参考:TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)

源代码涉及了多个.py文件,对于初学者来说不便于阅读,对于不同的训练对象要修改的参数遍布较多,不太方便,因此这里将整个迁移学习分为三个.py,其中creat_tfrecord.py用于将样本转化为tensorflow的tfrecord格式;input_data.py用于读取生成的tfrecord格式数据并以队列的形式提供样本;finetune_mydata.py是主要的demo,其中调用有上述两个.py文件,要修改的一些参数都已经放在py文件的前端。

迁移学习主代码

根据自己的数据库要修改的参数已经放在了代码的最前端。

from nets import vgg
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import input_data
import os
import creat_tfrecord

slim = tf.contrib.slim
DATA_DIR = './datasets/data/Sample'#数据集的路径
NUM_CLASSES = 2  #输出类别
NUM_TRAIN = 3600 #训练集的总数
NUM_VAL = 1200  #验证集的总数
IMAGE_SIZE = vgg.vgg_16.default_image_size #获取图片大小


checkpoint_file = './model/vgg_16.ckpt' #官方下载的检查点文件路径
save_dir = './result/vgg16/fine_tune' #训练后模型保存路径
trained_file = 'ships_fine_tune.ckpt'  #训练后模型名称
log_dir = './logs/train'        #训练日志路径

#设置训练参数
batch_size  = 64
learning_rate = 0.0001
training_epochs = 10    #迭代轮数
display_epoch = 1       #每回合显示一次 
train_num_batch = int(np.ceil(NUM_TRAIN / batch_size)) #batch的数目
val_num_batch = int(np.ceil(NUM_VAL / batch_size)) #batch的数目
    

def Ships_fine_tuning():
    '''
    演示一个VGG16的例子 
    微调 这里只调整VGG16最后一层全连接层,把1000类改为5类 
    对网络进行训练
    '''
        
    '''
    1.设置参数,并加载数据
    '''
    if not tf.gfile.Exists(save_dir):
        tf.gfile.MakeDirs(save_dir)
        
    #调用creat_tfrecord.py对数据划分为训练集和验证集,分别生成TF格式数据
    creat_tfrecord.run(DATA_DIR, NUM_VAL)  
                         
    #生成batch   
    train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES,
                                                                       True, IMAGE_SIZE, IMAGE_SIZE)          
    test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES,
                                                                       False, IMAGE_SIZE, IMAGE_SIZE)          

    #获取模型参数的命名空间
    arg_scope = vgg.vgg_arg_scope()
    #arg_scope = resnet_v1.resnet_arg_scope()

    #创建网络
    with  slim.arg_scope(arg_scope):
        
        '''
        2.定义占位符和网络结构
        '''   
        
        #输入图片
        input_images = tf.placeholder(dtype=tf.float32,shape = [None,IMAGE_SIZE,IMAGE_SIZE,3])
            #图片标签
        input_labels = tf.placeholder(dtype=tf.float32,shape = [None,NUM_CLASSES])        
            #训练还是测试?测试的时候弃权参数会设置为1.0
        is_training = tf.placeholder(dtype = tf.bool)
        
        
        #创建vgg16网络  如果想冻结所有层,可以指定slim.conv2d中的 trainable=False
        logits,end_points =  vgg.vgg_16(input_images, is_training=is_training,num_classes = NUM_CLASSES)
        #print(end_points)  #每个元素都是以vgg_16/xx命名
            
        # Restore only the convolutional layers: 从检查点载入当前图除了fc8层之外所有变量的参数
        params = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
        #用于恢复模型  如果使用这个保存或者恢复的话,只会保存或者恢复指定的变量
        restorer = tf.train.Saver(params) 
        
        '''
        #从当前图中搜索指定scope的变量,然后从检查点文件中恢复这些变量(即vgg_16网络中定义的部分变量)  
        #如果指定了恢复检查点文件中不存在的变量,则会报错 如果不知道检查点文件有哪些变量,我们可以打印检查点文件查看变量名
        params = []
        conv1 = slim.get_variables(scope="vgg_16/conv1")
        params.extend(conv1)            
        conv2 = slim.get_variables(scope="vgg_16/conv2")
        params.extend(conv2)
        conv3 = slim.get_variables(scope="vgg_16/conv3")
        params.extend(conv3)
        conv4 = slim.get_variables(scope="vgg_16/conv4")
        params.extend(conv4)
        conv5 = slim.get_variables(scope="vgg_16/conv5")
        params.extend(conv5)
        fc6 = slim.get_variables(scope="vgg_16/fc6")
        params.extend(fc6)
        fc7 = slim.get_variables(scope="vgg_16/fc7")
        params.extend(fc7)          
        '''
        
        '''
        3 定义代价函数和优化器
        '''         
        #代价函数                    
        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=input_labels,logits=logits)) 
        loss_summary = tf.summary.scalar('loss',cost) 
        #设置优化器    
        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)    
         #求准确率
        pred = tf.argmax(logits,axis=1)         #预测标签     
        correct = tf.equal(pred,tf.argmax(input_labels,1))        #返回一个数组 表示统计预测正确或者错误 
        accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))    
        acc_summary = tf.summary.scalar('accuracy',accuracy)
                            
             
        #用于保存检查点文件 
        save = tf.train.Saver(max_to_keep=training_epochs) 
        
        #恢复模型
        with tf.Session() as sess: 
            merged = tf.summary.merge([loss_summary, acc_summary])  #合并
            train_writer = tf.summary.FileWriter(log_dir,sess.graph) #将训练日志写入到logs文件夹下
            sess.run(tf.global_variables_initializer())
                    
            #检查最近的检查点文件
            ckpt = tf.train.latest_checkpoint(save_dir)
            if ckpt != None:
                save.restore(sess,ckpt)
                print('从上次训练保存后的模型继续训练!')
            else:
                restorer.restore(sess, checkpoint_file)                
                print('从官方模型加载训练!')

                                   
            coord = tf.train.Coordinator()  #创建一个协调器,管理线程
            threads = tf.train.start_queue_runners(sess=sess,coord=coord)  #启动QueueRunner, 此时文件名才开始进队。                    
            
            '''
            #4 查看预处理之后的图片
            imgs, labs = sess.run([train_images, train_labels])                  
            print('原始训练图片信息:',imgs.shape,labs.shape)
            show_img = np.array(imgs[0],dtype=np.uint8)
            plt.imshow(show_img)                                
            plt.title('Original train image')   
            plt.show()

                                    
            imgs, labs = sess.run([test_images, test_labels])                  
            print('原始测试图片信息:',imgs.shape,labs.shape)
            show_img = np.array(imgs[0],dtype=np.uint8)
            plt.imshow(show_img)                                
            plt.title('Original test image')   
            plt.show()
            '''                        
            print('开始训练!')
            for epoch in range(training_epochs):                
                train_loss = 0.0  
                for i in range(train_num_batch):
                    imgs, labs, = sess.run([train_images, train_labels])                                                
                    _, loss, train_summaries = sess.run([optimizer, cost, merged],feed_dict={input_images:imgs,input_labels:labs,is_training:True})   
                    train_writer.add_summary(train_summaries, (i+1)+epoch*batch_size)  #将每一个batch的训练结果保存至日志文件   
                    train_loss += loss 

                #打印信息
                if epoch % display_epoch == 0:          
                    train_accuracy = sess.run(accuracy,feed_dict={input_images:imgs,input_labels:labs,is_training:False})                   
                    print('Epoch {}/{}  average cost {:.9f}  train accuracy {:.2f}'.format(epoch+1, training_epochs, train_loss/train_num_batch, train_accuracy))
                    
                #进行测试
                val_accuracy = 0.0
                val_loss = 0.0
                for j in range(val_num_batch):
                    imgs, labs = sess.run([test_images, test_labels])                                                                    
                    cost_values,accuracy_values = sess.run([cost, accuracy],feed_dict = {input_images:imgs,input_labels:labs,is_training:False})
                    val_accuracy  += accuracy_values
                    val_loss  += cost_values
                
                print('Epoch {}/{}  Test cost {:.9f} Test accuracy {:.2f}'.format(epoch+1,training_epochs,val_loss/val_num_batch, val_accuracy/val_num_batch))
                
                #保存模型
                save.save(sess,os.path.join(save_dir,trained_file),global_step = epoch)
                print('Epoch {}/{}  模型保存成功'.format(epoch+1,training_epochs))
                
            print('训练完成')
                    
            #终止线程
            coord.request_stop()
            coord.join(threads)  
           
if __name__ == '__main__':
    tf.reset_default_graph()     #移除存在的图
    Ships_fine_tuning()

读取样本生成tfrecord文件

当你的样本数量过大时,要调高_NUM_SHARDS

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import pdb

import tensorflow as tf

from datasets import dataset_utils


# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
# 如果你的样本数量过大,要调高该参数
_NUM_SHARDS = 2


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    #pdb.set_trace()
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.

  Args:
    dataset_dir:包括多个子文件夹,每一个子文件夹是一个类,以类名命名,
    其中存放该类的样本.

  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
  directories = []
  class_names = []
  for filename in os.listdir(dataset_dir):
    path = os.path.join(dataset_dir, filename)
    if os.path.isdir(path):
      directories.append(path)  #directories里面是每一类文件夹路径
      class_names.append(filename)
      
  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)    
      photo_filenames.append(path)   

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  output_filename = 'myimage_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  return os.path.join(dataset_dir, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset, either 'train' or 'validation'.
    filenames: A list of absolute paths to png or jpg images.
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.
  """
  assert split_name in ['train', 'validation']

  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:

      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i+1, len(filenames), shard_id))
            sys.stdout.flush()

            # Read the filename:
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            print(filenames[i])
            height, width = image_reader.read_image_dims(sess, image_data)

            class_name = os.path.basename(os.path.dirname(filenames[i]))
            class_id = class_names_to_ids[class_name]

            example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, class_id)
            tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()


def _dataset_exists(dataset_dir):
  for split_name in ['train', 'validation']:
    for shard_id in range(_NUM_SHARDS):
      output_filename = _get_dataset_filename(
          dataset_dir, split_name, shard_id)
      if not tf.gfile.Exists(output_filename):
        return False
  return True


def run(dataset_dir,_NUM_VALIDATION):
  """
  读取样本,划分为训练集和验证集并转换为TF格式.
  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  if _dataset_exists(dataset_dir):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
  class_names_to_ids = dict(zip(class_names, range(len(class_names))))

  # Divide into train and test:
  random.seed(_RANDOM_SEED)
  random.shuffle(photo_filenames)
  training_filenames = photo_filenames[_NUM_VALIDATION:]
  validation_filenames = photo_filenames[:_NUM_VALIDATION]

  # First, convert the training and validation sets.
  _convert_dataset('train', training_filenames, class_names_to_ids,
                   dataset_dir)
  _convert_dataset('validation', validation_filenames, class_names_to_ids,
                   dataset_dir)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  print('\nFinished converting the dataset!')

读取tfrecord文件并以队列形式供给

要调整的参数已放在最前端。

另外,在对数据进行预处理时,我选用的是tensorflow默认的图片放缩函数以及标准化函数,用这两个函数处理过的训练结果较好,见后面。如果采用slim库的预处理函数反而训练结果不好

import tensorflow as tf
import os
from preprocessing import vgg_preprocessing
from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'myimage_%s_*.tfrecord'  #
SPLITS_TO_SIZES = {'train': 3600, 'validation': 1200} #修改为你的数据库的样本大小,这里3600代表3600张样本
_NUM_CLASSES = 2  #修改为你的类别数

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)


def read_image_and_label(dataset_dir,is_training=False):
    '''  
    读取tf格式数据             
    args:
        dataset_dir:数据集所在的目录
        is_training:设置为TRue,表示加载训练数据集,否则加载验证集
    return:
        image,label:返回随机读取的一张图片,和对应的标签
    '''    
    #选择数据集train
    if is_training:        
        dataset = get_split(split_name = 'train',dataset_dir=dataset_dir)
    else:
        dataset = get_split(split_name = 'validation',dataset_dir=dataset_dir)
    
    #创建一个数据provider
    provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    
    #通过provider的get随机获取一条样本数据 返回的是两个张量
    [image,label] = provider.get(['image','label'])

    return image,label



def get_batch_images_and_label(dataset_dir,batch_size,num_classes,is_training=False,output_height=224, output_width=224,num_threads=10):
    '''
    每次取出batch_size个样本
    
    注意:这里预处理调用的是slim库图片预处理的函数,例如:如果你使用的vgg网络,就调用vgg网络的图像预处理函数
          如果你使用的是自己定义的网络,则可以自己写适合自己图像的预处理函数,比如归一化处理也可以使用其他网络已经写好的预处理函数
    
    args:
         dataset_dir:数据集所在的目录
         batch_size:一次取出的样本数量
         num_classes:输出的类别 用于对标签one_hot编码
         is_training:设置为TRue,表示加载训练数据集,否则加载验证集
         output_height:输出图片高度
         output_width:输出图片宽
         
     return:
        images,labels:返回随机读取的batch_size张图片,和对应的标签one_hot编码
    '''
    #获取单张图像和标签
    image,label = read_image_and_label(dataset_dir, is_training)   
    
    #image = vgg_preprocessing.preprocess_image(image, output_height, output_width,is_training=is_training)
    #这里没有采用slim库图片预处理的函数,而采用tensorflow原始的放缩的方式以及标准化方式进行预处理,method=0为双线性插值
    crop_image = tf.image.resize_images(image, [output_width,output_height],method=0)  
    image = tf.image.per_image_standardization(crop_image)   # 标准化数据
    
    #缩放处理
    #image = tf.image.convert_image_dtype(image, dtype=tf.float32)  
    #image = tf.image.resize_image_with_crop_or_pad(image, output_height, output_width)
    
    #  shuffle_batch 函数会将数据顺序打乱
    #  bacth 函数不会将数据顺序打乱    
    images, labels = tf.train.batch(
                [image, label],
                batch_size = batch_size,
                capacity=5 * batch_size, 
                num_threads = num_threads)    
        
    #one-hot编码
    labels = slim.one_hot_encoding(labels,num_classes)
    
    return images,labels

运行

这里我训练10个epoch,结果如下,可以说训练集的准确率达到了1,验证集的准确率也已达到95%,取得了不错的结果。

最后用tensorboard可以查看训练的loss和accuracy图,代码如下

tensorboard --logdir=logs/train

测试单张图片

def test_on_image_tf():
    '''
    使用微调好的网络测试单张图片(原始tensorflow形式)
    '''          
    TEST_DIR =  './test/6.jpg'     #数据路径                   
    org_image = tf.image.decode_jpeg(tf.read_file(TEST_DIR), channels=3) #加载数据  
    crop_image = tf.image.resize_images(org_image, [IMAGE_SIZE, IMAGE_SIZE],method=0)
    image = tf.image.per_image_standardization(crop_image)   # 标准化数据
    image = tf.reshape(image,[1,IMAGE_SIZE, IMAGE_SIZE, 3])  #reshape以满足输入要求
    
      #原始tensotflow,需要占位符
    input_images = tf.placeholder(dtype=tf.float32,shape = [None,IMAGE_SIZE,IMAGE_SIZE,3])    #输入图片       
    is_training = tf.placeholder(dtype = tf.bool)    #训练还是测试?测试的时候弃权参数会设置为1.0
    
    #获取模型参数的命名空间
    arg_scope = vgg.vgg_arg_scope()

    #创建网络
    with  slim.arg_scope(arg_scope):
        logits,end_points =  vgg.vgg_16(input_images, is_training=is_training,num_classes = NUM_CLASSES)  

    #预测标签
    pred = tf.argmax(logits,axis=1)             
        
    restorer = tf.train.Saver() 
    #恢复模型
    with tf.Session() as sess:      
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.latest_checkpoint(save_dir)
        if ckpt != None:
            #恢复模型
            restorer.restore(sess,ckpt)
            print("Model restored.")
                       
        org_imgs, standard_imgs = sess.run([org_image, image])                                        
        pred_value = sess.run(pred, feed_dict = {input_images:standard_imgs,is_training:False})
        
        plt.imshow(org_imgs)                                
        plt.title('Original test image')   
        plt.show()
        if pred_value == 0:   
            print('预测结果为:非船舶')
        else:
            print('预测结果为:船舶')
def test_on_image_slim():
    '''
    使用微调好的网络测试单张图片,利用slim,无需占位符送入数据,更简便
    '''          
    TEST_DIR =  './test/6.jpg'     #数据路径                   
    org_image = tf.image.decode_jpeg(tf.read_file(TEST_DIR), channels=3) #加载数据  
    crop_image = tf.image.resize_images(org_image, [IMAGE_SIZE, IMAGE_SIZE],method=0)
    image = tf.image.per_image_standardization(crop_image)   # 标准化数据
    image = tf.reshape(image,[1,IMAGE_SIZE, IMAGE_SIZE, 3])  #reshape以满足输入要求
    

    #获取模型参数的命名空间
    arg_scope = vgg.vgg_arg_scope()

    #创建网络并送入数据
    with  slim.arg_scope(arg_scope):
        logits,end_points =  vgg.vgg_16(image, is_training=False, num_classes = NUM_CLASSES)
        
    #预测标签        
    pred = tf.argmax(logits,axis=1)       
        
    restorer = tf.train.Saver()  
    #恢复模型
    with tf.Session() as sess:      
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.latest_checkpoint(save_dir)
        if ckpt != None:
            #恢复模型
            restorer.restore(sess,ckpt)
            print("Model restored.")
                       
        org_imgs, logit, pred_value= sess.run([org_image, pred])                                           
        plt.imshow(org_imgs)                                
        plt.title('Original test image')   
        plt.show()
        if pred_value == 0:   
            print('预测结果为:非船舶')
        else:
            print('预测结果为:船舶')

上述两种代码皆可

猜你喜欢

转载自blog.csdn.net/Mr_health/article/details/81285119