TensorFlow Notes: The VGG Network

This time we use slim to build a somewhat larger network: VGG16. VGG16 and VGG19 are nearly identical in design, so the example code here uses VGG16 for a 5-class flower classification task.

Compared with earlier networks such as LeNet and AlexNet, VGG introduces the following ideas:

1. Stacked 3×3 convolution kernels replace the larger 5×5 and 7×7 kernels.

A 5×5 kernel has a large receptive field, but also more parameters. Two stacked 3×3 convolutions cover the same receptive field as one 5×5 convolution while applying two non-linear transformations instead of one. In short: compared with a large kernel, a stack of small kernels both reduces the parameter count and performs more non-linear mappings, which increases the network's expressive power (see the sketch after this list).

2. Deeper networks. Setting aside for now the drawbacks of depth, such as networks being harder to train and gradients vanishing, a deep network can fit the data distribution better than a shallow one in terms of feature abstraction and expressive power.

3. The original VGG paper also introduced tricks such as data augmentation and image preprocessing.
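
To make the parameter claim in point 1 concrete, here is a standalone back-of-the-envelope sketch (mine, not from the original post). It assumes C input channels, C output channels, and ignores biases:

def conv_weights(kernel_size, in_channels, out_channels):
    # weight count of a single conv layer, biases ignored
    return kernel_size * kernel_size * in_channels * out_channels

C = 256
stacked_3x3 = 2 * conv_weights(3, C, C)  # 2 * 9 * C^2 = 18 * C^2 = 1179648
single_5x5 = conv_weights(5, C, C)       # 25 * C^2 = 1638400
print(stacked_3x3, single_5x5)           # the 3x3 stack uses 28% fewer weights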

On to the code. The project consists of three files:

vgg.py: builds the 16-layer VGG network.

import tensorflow as tf
import tensorflow.contrib.slim as slim


def build_vgg(rgb, num_classes, keep_prob, train=True):
    # Shared settings for every conv/fc layer: ReLU activation plus batch norm.
    # `train` drives both the dropout branches below and batch norm's is_training.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params={'is_training': train}):
        # block_1
        net = slim.repeat(rgb, 2, slim.conv2d, 64, [3, 3], padding='SAME', scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')

        # block_2
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], padding='SAME', scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')

        # block_3
        net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], padding='SAME', scope='conv3')
        net = slim.max_pool2d(net, [2, 2], scope='pool3')

        # block_4
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv4')
        net = slim.max_pool2d(net, [2, 2], scope='pool4')

        # block_5
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv5')
        net = slim.max_pool2d(net, [2, 2], scope='pool5')

        # flatten: for a 224x224 input, pool5 is 7x7x512, so the flattened size is 25088
        feature_shape = net.get_shape()
        flattened_shape = feature_shape[1].value * feature_shape[2].value * feature_shape[3].value
        pool5_flatten = tf.reshape(net, [-1, flattened_shape])

        # fc6
        net = slim.fully_connected(pool5_flatten, 4096, scope='fc6')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout6')

        # fc7
        net = slim.fully_connected(net, 4096, scope='fc7')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout7')

        # fc8: return raw logits (no softmax or batch norm here);
        # train.py applies the softmax inside the cross-entropy loss
        net = slim.fully_connected(net, num_classes, activation_fn=None, normalizer_fn=None, scope='fc8')
    return net
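
A quick shape sanity check for the graph above (my own sketch, not part of the original project; the import path assumes vgg.py lives under model/ as in train.py):

import tensorflow as tf
from model.vgg import build_vgg

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits = build_vgg(images, num_classes=5, keep_prob=0.5, train=False)
print(logits.get_shape())  # expected: (?, 5)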

tfrecords.py: encodes and decodes the dataset. Unlike previous posts, this example does not feed data to the network with feed_dict; instead it encodes the dataset into TensorFlow's own TFRecord format.

import tensorflow as tf
import numpy as np
import os
import glob
from PIL import Image

path_tfrecord = '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/'

def convert_to_tfrecord(images, labels, filename):
    print("Converting data into %s ..." % filename)
    writer = tf.python_io.TFRecordWriter(path_tfrecord + filename)
    for index, img in enumerate(images):
        img_raw = Image.open(img)
        if img_raw.mode != "RGB":
            # skip grayscale/CMYK images so every record decodes to exactly 3 channels
            continue
        img_raw = img_raw.resize((256, 256))
        img_raw = img_raw.tobytes()
        label = int(labels[index])
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())
    writer.close()

def read_and_decode(filename, is_train=None):
    # num_epochs (kept in sync with FLAGS.num_epochs in train.py) creates local
    # variables, so the session must run tf.local_variables_initializer()
    filename_queue = tf.train.string_input_producer([filename], num_epochs=400)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [256, 256, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

    if is_train:
        img = tf.random_crop(img, [224, 224, 3])
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=63)
        img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
        img = tf.image.per_image_standardization(img)
    else:
        img = tf.image.resize_image_with_crop_or_pad(img, 224, 224)
        img = tf.image.per_image_standardization(img)

    label = tf.cast(features['label'], tf.int32)
    return img, label

def get_file(path):
    cate = [path+x for x in os.listdir(path) if os.path.isdir(path+x)]
    images = []
    labels = []
    for idx, folder in enumerate(cate):
        for img in glob.glob(folder+'/*.jpg'):
            print('reading the images:%s' % (img))
            images.append(img)
            labels.append(idx)
    image_list = np.asarray(images, np.string_)
    label_list = np.asarray(labels, np.int32)

    # shuffle
    num_example = image_list.shape[0]
    arr = np.arange(num_example)
    np.random.shuffle(arr)
    image_list = image_list[arr]
    label_list = label_list[arr]

    # divide train_data and val_data
    num_example = image_list.shape[0]
    split = int(num_example * 0.8)  # 80/20 train/val split
    train_images = image_list[:split]
    train_labels = label_list[:split]
    val_images = image_list[split:]
    val_labels = label_list[split:]
    return train_images, train_labels, val_images, val_labels


if __name__ == '__main__':
    train_images, train_labels, val_images, val_labels = get_file('/home/danny/chenwei/CSDN_blog/VGG/datasets/')
    convert_to_tfrecord(images=train_images, labels=train_labels, filename="train.tfrecords")
    convert_to_tfrecord(images=val_images, labels=val_labels, filename="test.tfrecords")
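
To sanity-check the conversion, the records that were actually written can be counted back (a small check of my own; tf.python_io.tf_record_iterator is the TF 1.x API used here):

for name in ["train.tfrecords", "test.tfrecords"]:
    count = sum(1 for _ in tf.python_io.tf_record_iterator(path_tfrecord + name))
    print("%s: %d records" % (name, count))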

train.py: the training script. The difference from previous posts is that it reads data for training with multiple threads through queues.

# -*- coding: utf-8 -*-
import tensorflow as tf
from utils.tfrecords import *
from model.vgg import *

tf.app.flags.DEFINE_integer('num_classes', 5, 'classification number.')
tf.app.flags.DEFINE_integer('crop_width', 224, 'width of the cropped input image (matches read_and_decode).')
tf.app.flags.DEFINE_integer('crop_height', 224, 'height of the cropped input image (matches read_and_decode).')
tf.app.flags.DEFINE_integer('channels', 3, 'channel number of image.')
tf.app.flags.DEFINE_integer('batch_size', 2, 'num of each batch')
tf.app.flags.DEFINE_integer('num_epochs', 400, 'number of epoch')
tf.app.flags.DEFINE_bool('continue_training', False, 'whether is continue training')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('dataset_path', './datasets/', 'path of dataset')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
tf.app.flags.DEFINE_string('train_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/train.tfrecords', 'train tfrecord')
tf.app.flags.DEFINE_string('test_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/test.tfrecords', 'test tfrecord')

FLAGS = tf.app.flags.FLAGS

def main(_):
    # data process
    train_images, train_labels = read_and_decode(FLAGS.train_tfrecords, True)
    val_images, val_labels = read_and_decode(FLAGS.test_tfrecords, False)

    train_labels = tf.one_hot(indices=tf.cast(train_labels, tf.int32), depth=FLAGS.num_classes)
    train_images_batch, train_labels_batch = tf.train.shuffle_batch([train_images, train_labels], batch_size=FLAGS.batch_size, capacity=20000, min_after_dequeue=3000, num_threads=16)  # num_threads sets the number of reader threads

    val_labels = tf.one_hot(indices=tf.cast(val_labels, tf.int32), depth=FLAGS.num_classes)
    val_images_batch, val_labels_batch = tf.train.shuffle_batch([val_images, val_labels], batch_size=FLAGS.batch_size, capacity=20000, min_after_dequeue=3000, num_threads=16)  # num_threads sets the number of reader threads

    # define network input
    input = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width, FLAGS.channels], name='input')
    output = tf.placeholder(tf.int32, shape=[FLAGS.batch_size, FLAGS.num_classes], name='output')

    # control GPU resource utilization
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # build network (returns raw logits; the loss below applies the softmax)
    logits = build_vgg(input, FLAGS.num_classes, 0.5, True)

    # loss
    cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=output))
    regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))  # 0.0 unless a weights_regularizer is set in the arg_scope
    loss = cross_entropy_loss + regularization_loss

    # optimizer: batch norm's moving-average updates live in UPDATE_OPS
    # and must run together with each training step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)

    # calculate correct
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(output, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with sess.as_default():

        # init all paramters
        saver = tf.train.Saver(max_to_keep=1000)
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # restore weight
        if FLAGS.continue_training:
            saver.restore(sess, FLAGS.checkpoints)

        # begin training
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        epoch = 0  # incremented once per batch, so this counter really tracks steps
        try:
            while not coord.should_stop():
                # begin training
                train_images, train_labels = sess.run([train_images_batch, train_labels_batch])
                _, err, acc = sess.run([train_op, loss, accuracy], feed_dict={input: train_images, output: train_labels})
                print("[Train] Step: %d, loss: %.4f, accuracy: %.4f%%" % (epoch, err, acc))
                epoch += 1

                if epoch % 10 == 0 or (epoch + 1) == FLAGS.num_epochs:
                    # note: the graph was built with train=True, so dropout stays active here
                    val_images, val_labels = sess.run([val_images_batch, val_labels_batch])
                    val_err, val_acc = sess.run([loss, accuracy], feed_dict={input: val_images, output: val_labels})
                    print("[validation] Step: %d, loss: %.4f, accuracy: %.4f" % (epoch, val_err, val_acc))

                # save a checkpoint every 1000 steps
                if (epoch + 1) % 1000 == 0:
                    saver.save(sess, save_path=FLAGS.checkpoints, global_step=epoch)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()


if __name__ == '__main__':
    tf.app.run()
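
Once a checkpoint exists it can be restored for single-image inference. Below is a minimal sketch of my own (the image path is hypothetical); it mirrors the eval-branch preprocessing of read_and_decode in NumPy:

import numpy as np
import tensorflow as tf
from PIL import Image
from model.vgg import build_vgg

# load and preprocess one image like the eval branch of read_and_decode
img = Image.open('/path/to/flower.jpg').convert('RGB').resize((224, 224))  # hypothetical path
arr = np.asarray(img, np.float32) / 255.0 - 0.5
arr = (arr - arr.mean()) / max(arr.std(), 1.0 / np.sqrt(arr.size))  # per-image standardization

x = tf.placeholder(tf.float32, [1, 224, 224, 3])
logits = build_vgg(x, num_classes=5, keep_prob=1.0, train=False)
probs = tf.nn.softmax(logits)

with tf.Session() as sess:
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('./checkpoints/'))
    print(sess.run(tf.argmax(probs, 1), feed_dict={x: arr[None]}))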

Training results: validation accuracy is around 96%.

[Train] Step: 19985, loss: 1.1098, accuracy: 1.0000
[Train] Step: 19986, loss: 1.1302, accuracy: 1.0000
[Train] Step: 19987, loss: 1.1232, accuracy: 1.0000
[Train] Step: 19988, loss: 1.1299, accuracy: 1.0000
[Train] Step: 19989, loss: 1.1220, accuracy: 1.0000
[validation] Step: 19990, loss: 1.1634, accuracy: 0.9688

Reposted from blog.csdn.net/neil3611244/article/details/81210024