Machine Learning in Practice: Fine-tuning Google's Pretrained VGG Model to Classify Cats and Dogs

1. Read the data and convert it to TFRecord format

import tensorflow as tf 
import numpy as np 
import os
import cv2

#Location of the dataset and of the output TFRecord files
DATASET_DIR='/home/zhou/Desktop/kaggle/train/'
TRAIN_OUTPUT_FILENAME='/home/zhou/Desktop/kaggle/train.tfrecords'
VAL_OUTPUT_FILENAME='/home/zhou/Desktop/kaggle/val.tfrecords'

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


#Read the training image paths and their corresponding class names
def get_filename_and_class(dataset_dir):
    photo_names = []
    class_names = []
    for filename in os.listdir(dataset_dir):
        path = dataset_dir + filename
        #Kaggle files are named like 'cat.0.jpg' / 'dog.0.jpg', so the prefix is the class name
        label = filename.split('.')[0]
        photo_names.append(path)
        class_names.append(label)
    return photo_names, class_names
#Convert class names to 0/1 integer labels (dog -> 0, cat -> 1).
#Note: this helper is not used below; the one-hot dict label_one_hot is used instead.
def convert_class_names_to_label(class_names):
    label = [0 if names == 'dog' else 1 for names in class_names]
    return label
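#Example (a quick illustration, not used below): convert_class_names_to_label(['dog', 'cat', 'dog']) returns [0, 1, 0]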

#Read the images and write them into a TFRecord file (raw pixel bytes from cv2.imread)
def convert_data_to_tfrecords(photo_names, class_names, output_filename):
    with tf.python_io.TFRecordWriter(output_filename) as writer:
        for filename, class_name in zip(photo_names, class_names):
            try:
                image_data = cv2.imread(filename)
                if image_data is None:
                    #cv2.imread returns None (not an IOError) for unreadable files
                    print("Could not read:", filename)
                    print("Skip it\n")
                    continue
                image_shape = image_data.shape
                image_data = image_data.tostring()
                features = {
                    'filename': _bytes_feature(filename.encode('utf-8')),
                    'rows': _int64_feature(image_shape[0]),
                    'cols': _int64_feature(image_shape[1]),
                    'channels': _int64_feature(image_shape[2]),
                    #label_one_hot is the module-level {class_name: one-hot vector} dict defined below
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label_one_hot[class_name])),
                    'image_data': _bytes_feature(image_data),
                }
                example = tf.train.Example(features=tf.train.Features(feature=features))
                writer.write(example.SerializeToString())
            except IOError as e:
                print("Could not read:", filename)
                print("Error:", e)
                print("Skip it\n")

#Same as above, but store the encoded image bytes read via tf.gfile.GFile;
#the resulting file is much smaller than with the previous method
def small_convert_data_to_tfrecords(photo_names, class_names, output_filename):
    with tf.python_io.TFRecordWriter(output_filename) as writer:
        for filename, class_name in zip(photo_names, class_names):
            try:
                image = cv2.imread(filename)
                if image is None:
                    #cv2.imread returns None (not an IOError) for unreadable files
                    print("Could not read:", filename)
                    print("Skip it\n")
                    continue
                image_shape = image.shape
                with tf.gfile.GFile(filename, 'rb') as fid:
                    image_data = fid.read()
                features = {
                    'filename': _bytes_feature(filename.encode('utf-8')),
                    'rows': _int64_feature(image_shape[0]),
                    'cols': _int64_feature(image_shape[1]),
                    'channels': _int64_feature(image_shape[2]),
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label_one_hot[class_name])),
                    'image_data': _bytes_feature(image_data),
                }
                example = tf.train.Example(features=tf.train.Features(feature=features))
                writer.write(example.SerializeToString())
            except IOError as e:
                print("Could not read:", filename)
                print("Error:", e)
                print("Skip it\n")


#Below, the data is split on photo_names into two parts: 80% for training and
#20% for validation (used for tuning to find good hyperparameters).
#A dedicated function could also be written to build this one-hot dict
#(a small sketch of such a helper follows).
label_one_hot={'dog':[0,1],'cat':[1,0]}
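#A possible helper for the comment above (a sketch only; build_one_hot_dict is not part
#of the original code). For ['cat', 'dog'] it produces the same mapping as the dict above:
#{'cat': [1, 0], 'dog': [0, 1]}.
def build_one_hot_dict(class_names):
    unique_names = sorted(set(class_names))
    return {name: [1 if i == j else 0 for j in range(len(unique_names))]
            for i, name in enumerate(unique_names)}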
photo_names,class_names=get_filename_and_class(DATASET_DIR)
length = len(photo_names)
#Split into training and validation sets (note: os.listdir order is arbitrary;
#shuffle photo_names and class_names together first if a truly random split is needed)
train_photo_names=photo_names[:int(0.8*length)]
train_class_names=class_names[:int(0.8*length)]
val_photo_names=photo_names[int(0.8*length):]
val_class_names=class_names[int(0.8*length):]
small_convert_data_to_tfrecords(train_photo_names,train_class_names,TRAIN_OUTPUT_FILENAME)
small_convert_data_to_tfrecords(val_photo_names,val_class_names,VAL_OUTPUT_FILENAME)
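
Before moving on to training, it can be worth checking that the TFRecord files were written correctly. The sketch below uses the same TF 1.x tf.python_io API as the writer above (the helper name inspect_first_record is just an example); it reads back the first serialized example and prints its metadata.

import tensorflow as tf

def inspect_first_record(tfrecords_path):
    #Iterate over the serialized tf.train.Example protos and parse the first one
    for record in tf.python_io.tf_record_iterator(tfrecords_path):
        example = tf.train.Example.FromString(record)
        feature = example.features.feature
        print('filename:', feature['filename'].bytes_list.value[0].decode('utf-8'))
        print('rows/cols/channels:',
              feature['rows'].int64_list.value[0],
              feature['cols'].int64_list.value[0],
              feature['channels'].int64_list.value[0])
        print('label:', list(feature['label'].int64_list.value))
        break

#inspect_first_record(TRAIN_OUTPUT_FILENAME)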

2. Read data from the TFRecord files, fine-tune and train the VGG model

import tensorflow as tf 
import numpy as np 
import cv2
import tensorflow.contrib.slim.nets
import tensorflow.contrib.slim as slim
from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.ops import arg_scope
from tensorflow.contrib.layers.python.layers import layers as layers_lib
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.framework import ops

SMALL_OUTPUT_FILENAME = '/home/zhou/Desktop/kaggle/small_train.tfrecords'
OUTPUT_FILENAME = '/home/zhou/Desktop/kaggle/train.tfrecords'
VAL_FILENAME = '/home/zhou/Desktop/kaggle/val.tfrecords'
MODEL_PATH = '/home/zhou/Desktop/kaggle/vgg_16.ckpt'
VGG_MEAN = [123.68, 116.78, 103.94]
learn_rating=0.0001
def extract_fn(tfrecords):
    features = {
            'filename': tf.FixedLenFeature([], tf.string),
            'rows': tf.FixedLenFeature([], tf.int64),
            'cols': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
            'image_data': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([2], tf.int64)
    }
    sample = tf.parse_single_example(tfrecords, features)
    #Decode the image bytes. When working with large files, shuffle the serialized
    #records first and decode afterwards, otherwise the shuffle buffer blows up.
    #If the records store raw bytes written from cv2.imread, use tf.decode_raw here instead;
    #records written via tf.gfile.GFile contain encoded image data, which is what
    #tf.image.decode_image expects.
    image = tf.image.decode_image(sample['image_data'])
    img_shape = tf.stack([sample['rows'], sample['cols'], sample['channels']])
    image = tf.reshape(image,img_shape)
    #Any image processing (resize etc.) can go here. Note that with resize_images, every
    #method except ResizeMethod.NEAREST_NEIGHBOR produces fractional pixel values; cast the
    #result back with a[0].astype(np.uint8) before displaying it with imshow and the like.
    image = tf.image.resize_images(image, [224,224], method=tf.image.ResizeMethod.BILINEAR ,align_corners=True)
    #VGG does not normalize the images; instead the per-channel mean is subtracted,
    #where VGG_MEAN = [123.68, 116.78, 103.94]
    means = tf.reshape(tf.constant(VGG_MEAN),[1,1,3])
    image = tf.subtract(image,means)
    #image = image - means
    #image = tf.image.resize_image_with_pad(image,224,224)
    label = sample['label']
    filename = sample['filename']
    return [image,label,filename]
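
#Optional sanity-check sketch (not part of the original pipeline; preview_image is a
#hypothetical helper): extract_fn subtracts VGG_MEAN, so an image array obtained via
#sess.run will not display correctly until the mean is added back and the values are
#cast to uint8. Note that tf.image.decode_image yields RGB while cv2.imshow expects
#BGR, so the colors may look swapped.
def preview_image(image_array, window_name='preview'):
    restored = (image_array + np.array(VGG_MEAN)).clip(0, 255).astype(np.uint8)
    cv2.imshow(window_name, restored)
    cv2.waitKey(0)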

def check_accuracy(sess, correct_prediction, is_training, dataset_init_op):
    """
    Check the accuracy of the model on either train or val (depending on dataset_init_op).
    """
    # Initialize the correct dataset
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction, {is_training: False})
            num_correct += correct_pred.sum()
            num_samples += correct_pred.shape[0]
        except tf.errors.OutOfRangeError:
            break

    # Return the fraction of datapoints that were correctly classified
    acc = float(num_correct) / num_samples
    return acc

#Pretrained model parameters are available on the TensorFlow website; just download the
#checkpoint file for the corresponding model.
#Fine-tuning can be done by modifying the vgg_network structure; VGG was trained with an image size of 224*224.
#tf.contrib.layers.conv2d(inputs,num_outputs,kernel_size,stride=1,padding='SAME',data_format=None,rate=1,activation_fn=tf.nn.relu,
#normalizer_fn=None,normalizer_params=None,weights_initializer=initializers.xavier_initializer(),weights_regularizer=None,
#biases_initializer=tf.zeros_initializer(),biases_regularizer=None,reuse=None,variables_collections=None,outputs_collections=None,
#trainable=True,scope=None) conv2d assigns initial values to the weights and biases, so we do not need to specify them
def vgg_network(inputs,num_classes=1000,is_training=True,dropout_keep_prob=0.5,spatial_squeeze=True,scope='vgg_16'):
    with variable_scope.variable_scope(scope, 'vgg_16', [inputs]) as sc:
        end_points_collection = sc.original_name_scope + '_end_points'
        # Collect outputs for conv2d, fully_connected and max_pool2d.
        with arg_scope([layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d],outputs_collections=end_points_collection):
            net = layers_lib.repeat(inputs, 2, layers.conv2d, 64, [3, 3], scope='conv1')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool1')
            net = layers_lib.repeat(net, 2, layers.conv2d, 128, [3, 3], scope='conv2')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool2')
            net = layers_lib.repeat(net, 3, layers.conv2d, 256, [3, 3], scope='conv3')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool3')
            net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv4')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool4')
            net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv5')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool5')
            # Use conv2d instead of fully_connected layers.
            net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
            net = layers_lib.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout6')
            net = layers.conv2d(net, 4096, [1, 1], scope='fc7')
            net = layers_lib.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout7')
            net = layers.conv2d(net,num_classes, [1, 1],activation_fn=None,normalizer_fn=None,scope='fc8')
            # Convert end_points_collection into a end_point dict.
            end_points = utils.convert_collection_to_dict(end_points_collection)
            with tf.name_scope('weights'):
                weights = tf.get_default_graph().get_tensor_by_name("vgg_16/conv1/conv1_1/weights:0")
                print(weights)
            tf.summary.histogram('weights', weights)
            #print(end_points)
            if spatial_squeeze:
                net = array_ops.squeeze(net, [1, 2], name='fc8/squeezed')
                end_points[sc.name + '/fc8'] = net
            return net, end_points
def main(train_filename,val_filename,model_path,num_epochs,batch_size):
    # map applies extract_fn to every element of the dataset:
    # map(map_func, num_parallel_calls=None) maps map_func across the elements of this
    # dataset and returns a new dataset containing the transformed elements, in the
    # same order as they appeared in the input.
    with tf.Graph().as_default() as g :
        is_training = tf.placeholder(tf.bool)
        #Read the training data
        train_dataset = tf.data.TFRecordDataset(train_filename)
        train_dataset = train_dataset.shuffle(buffer_size=26000)
        #train_dataset = train_dataset.repeat(num_epochs)
        train_dataset = train_dataset.map(extract_fn)
        train_dataset = train_dataset.batch(batch_size)
        #Read the validation data
        val_dataset = tf.data.TFRecordDataset(val_filename)
        val_dataset = val_dataset.map(extract_fn)
        val_dataset = val_dataset.batch(batch_size)
        #Create one reinitializable iterator shared by both datasets; output_types and output_shapes are dataset properties
        iterator = tf.data.Iterator.from_structure(train_dataset.output_types,train_dataset.output_shapes)
        #Get the next element from the dataset
        image,label,_ = iterator.get_next()
        #Ops that (re)initialize the iterator with either dataset
        train_init_op = iterator.make_initializer(train_dataset)
        val_init_op = iterator.make_initializer(val_dataset)
        logits, _ = vgg_network(image, num_classes=2, is_training=is_training,dropout_keep_prob=0.5)
        #Get the variables to restore from the pretrained checkpoint. The last layer must be
        #initialized from scratch (our task has only 2 classes while the default VGG has 1000,
        #so fc8 cannot be restored from the checkpoint); it is initialized by the call to
        #sess.run(tf.global_variables_initializer()) below.
        variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc8'])
        fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
        #Calling init_fn(sess) will load all the pretrained weights
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)
        label=tf.cast(label,tf.float32)
        #Note: for sparse_softmax_cross_entropy the labels are not one-hot, but integer class indices in [0, num_classes):
        #tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        with tf.name_scope('loss'):
            loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label,logits=logits))
        tf.summary.scalar('loss', loss)
        #Update only the weights and biases of the modified last layer (fc8)
        fc8_train_steps = tf.train.AdamOptimizer(learn_rating).minimize(loss,var_list=fc8_variables)
        #Update the parameters of the whole network
        full_train_steps = tf.train.AdamOptimizer(learn_rating).minimize(loss)
        #Compare the index of the largest logit with the index of the one-hot label
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        #print(variables_to_restore)
        #tf.get_default_graph().finalize()
        with tf.Session() as sess:
            merged = tf.summary.merge_all()
            train_writer = tf.summary.FileWriter('/home/zhou/Desktop/kaggle',sess.graph)
            test_writer = tf.summary.FileWriter('/home/zhou/Desktop/kaggle')
            sess.run(tf.global_variables_initializer())
            init_fn(sess)  # load the pretrained weights
            #Intended first stage: update only the parameters of the modified fc8 layer
            #(note: the active line below actually runs full_train_steps; the fc8-only op is commented out)
            i = 0
            for epochs in range (num_epochs):
                sess.run(train_init_op)
                while True:
                    try:
                    #summary,_ = sess.run([merged,fc8_train_steps], {is_training: True})
                        summary,_ = sess.run([merged,full_train_steps], {is_training: True})
                        train_writer.add_summary(summary, i)
                    # if i % 10 == 0:
                    #   loss_result = sess.run(loss, {is_training: True})
                    #   print(loss_result)
                        i  = i + 1
                    except tf.errors.OutOfRangeError:
                        break
                train_acc = check_accuracy(sess,correct_prediction,is_training,train_init_op)
                val_acc = check_accuracy(sess,correct_prediction,is_training,val_init_op)
                print('Train accuracy: %f' % train_acc)
                print('Val accuracy: %f\n' % val_acc)

            #Second stage: update the parameters of all layers, including fc8
            sess.run(train_init_op)
            while True:
                try:
                    _ = sess.run(full_train_steps, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break
            train_acc = check_accuracy(sess,correct_prediction,is_training,train_init_op)
            val_acc = check_accuracy(sess,correct_prediction,is_training,val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)
main(OUTPUT_FILENAME,VAL_FILENAME,MODEL_PATH,100,8)
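
The script above fine-tunes the network and prints accuracies, but it never persists the fine-tuned weights (the loss and weight summaries written by the FileWriter can be viewed with tensorboard --logdir /home/zhou/Desktop/kaggle). A minimal sketch for saving a checkpoint, to be placed inside the with tf.Session() as sess: block after the training loops (the output path is only an example):

#Sketch: add inside the session in main(), after training, to persist the fine-tuned variables.
#The checkpoint path below is a hypothetical example.
saver = tf.train.Saver()
saver.save(sess, '/home/zhou/Desktop/kaggle/vgg16_finetuned.ckpt')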


Reposted from www.cnblogs.com/zyy-summary/p/10889041.html