TensorFlow学习实践(一):使用TFRecord格式数据和队列进行模型训练和预测

本文以mnist为例,介绍如何使用TFRecord格式数据和队列进行模型训练和预测。

参考:

1、cifar10

2、https://tensorflow.google.cn/guide/datasets

TFRecord格式数据的制作参见《将mnist数据转成原始图片数据再转成TFRecord格式》一文。

目录

一、输入数据的解析和预处理

二、定义模型

三、计算损失并定义训练操作

四、模型训练

五、模型验证

六、对单张图片进行预测


一、输入数据的解析和预处理

def read_mnist_tfrecords(filename_queue):
    """Read one serialized MNIST example from a TFRecord filename queue.

    Each record is expected to carry an 'img_raw' bytes feature (raw uint8
    pixels) and an int64 'label' feature.

    Args:
      filename_queue: queue of TFRecord file names, e.g. from
        tf.train.string_input_producer.

    Returns:
      image: uint8 tensor reshaped to [FLAGS.image_height, FLAGS.image_width, 1].
      label: scalar int64 tensor.
    """
    feature_spec = {
        'img_raw': tf.FixedLenFeature([], tf.string, ''),
        'label': tf.FixedLenFeature([], tf.int64, 0),
    }
    _, record = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(record, features=feature_spec)

    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    label = tf.cast(parsed['label'], tf.int64)
    image = tf.reshape(pixels, [FLAGS.image_height, FLAGS.image_width, 1])
    return image, label


def inputs(filenames, examples_num, batch_size, shuffle):
    """Build a batched (images, labels) input pipeline over TFRecord files.

    Args:
      filenames: list of TFRecord file paths.
      examples_num: number of examples in the files; only used to size the
        prefetch queue.
      batch_size: number of examples per batch.
      shuffle: if True, shuffle both the file order and the examples (use for
        training); if False, read files in the given order and batch without
        example shuffling (use for evaluation).

    Returns:
      images: float32 batch of shape [batch_size, FLAGS.image_height,
        FLAGS.image_width, 1].
      labels: int64 batch of shape [batch_size].

    Raises:
      ValueError: if any file in `filenames` does not exist.
    """
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    with tf.name_scope('inputs'):
        # string_input_producer shuffles file names by default; tie it to the
        # caller's flag so shuffle=False really means "no shuffling".
        filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle)
        image, label = read_mnist_tfrecords(filename_queue)
        image = tf.cast(image, tf.float32)

        # Keep ~40% of the dataset buffered so shuffle_batch mixes examples well.
        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(min_fraction_of_examples_in_queue * examples_num)
        num_process_threads = 16
        capacity = min_queue_examples + batch_size * 3
        if shuffle:
            images, labels = tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                                    num_threads=num_process_threads,
                                                    capacity=capacity,
                                                    min_after_dequeue=min_queue_examples)
        else:
            images, labels = tf.train.batch([image, label], batch_size=batch_size,
                                            num_threads=num_process_threads,
                                            capacity=capacity)
        return images, labels

处理之后,返回的是批量的image和对应的label。

二、定义模型

def inference(images, training):
    """Build the CNN and return its logits and softmax probabilities.

    Architecture: conv(5x5, 32, relu) -> maxpool(2x2, stride 2)
    -> conv(5x5, 64, relu) -> maxpool(2x2, stride 2)
    -> dense(1024, relu) + dropout(rate 0.4) -> dense(10).

    Args:
      images: float batch [batch, H, W, C]. The flatten in 'fc1' hard-codes
        7*7*64, which matches 28x28 inputs (28 -> 14 -> 7 after the two
        stride-2 poolings). NOTE(review): other spatial sizes may not be
        rejected by the -1 in the reshape -- confirm callers feed 28x28.
      training: Python bool or boolean scalar tensor; dropout is only applied
        when it is True.

    Returns:
      logits: [batch, 10] unnormalized scores; use these for the loss.
      predict: softmax(logits), same shape.

    Variable names (and therefore checkpoint compatibility) depend on the
    variable scopes and the order in which the layers are created below.
    """
    with tf.variable_scope('conv1'):
        conv1 = tf.layers.conv2d(inputs=images,
                                 filters=32,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)      # [batch, 14, 14, 32] for 28x28 input

    with tf.variable_scope('conv2'):
        conv2 = tf.layers.conv2d(inputs=pool1,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)      # [batch, 7, 7, 64] for 28x28 input

    with tf.variable_scope('fc1'):
        # Flatten; the 7*7*64 size assumes 28x28 inputs (see docstring).
        pool2_flat = tf.reshape(pool2, [-1, 7*7*64])
        fc1 = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
        dropout1 = tf.layers.dropout(inputs=fc1, rate=0.4, training=training)

    with tf.variable_scope('logits'):
        logits = tf.layers.dense(inputs=dropout1, units=10)     # used to compute the cross-entropy loss
        predict = tf.nn.softmax(logits)

    return logits, predict

模型定义采用tf.layers API,返回值中的logits用于计算损失。

三、计算损失并定义训练操作

def loss(logits, labels):
    """Return the batch-mean sparse softmax cross-entropy.

    Args:
      logits: [batch, num_classes] unnormalized scores.
      labels: [batch] integer class ids (cast to int64 here).

    Returns:
      Scalar float tensor: mean cross-entropy over the batch.
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(labels, tf.int64),
        logits=logits,
        name='cross_entropy')
    return tf.reduce_mean(per_example)


def train(total_loss, global_step):
    """Build the training op: Adam with a staircase-decayed learning rate.

    The learning rate starts at 0.001 and is multiplied by 0.1 every 10
    epochs. One epoch is TRAIN_EXAMPLES_NUM / FLAGS.batch_size steps.

    Args:
      total_loss: scalar loss tensor to minimize.
      global_step: global step variable; incremented by the returned op.

    Returns:
      The op that computes the gradients and applies one optimizer step.
    """
    # Explicit float division so Python 2 and 3 compute the same schedule.
    num_batches_per_epoch = float(TRAIN_EXAMPLES_NUM) / FLAGS.batch_size
    # Clamp to >= 1: a zero decay period would make exponential_decay divide
    # by zero when the dataset is tiny relative to the batch size.
    decay_steps = max(1, int(num_batches_per_epoch * 10))

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(learning_rate=0.001,
                                    global_step=global_step,
                                    decay_steps=decay_steps,
                                    decay_rate=0.1,
                                    staircase=True)

    opt = tf.train.AdamOptimizer(learning_rate=lr)
    grads_and_vars = opt.compute_gradients(total_loss)
    apply_grad_op = opt.apply_gradients(grads_and_vars, global_step)

    return apply_grad_op

学习率初始值为0.001,采用阶梯式指数衰减(staircase=True):每过10个epoch衰减一次,变为上一阶段的1/10。

四、模型训练

def train():
    """Train the MNIST CNN on train_img.tfrecords.

    Logs the loss and writes a checkpoint to FLAGS.train_dir every 100 steps,
    plus a final checkpoint when FLAGS.max_step is not a multiple of 100.
    Queue-runner threads are always stopped and joined, even if a step raises.
    """
    images, labels = mnist.inputs(['train_img.tfrecords'], mnist.TRAIN_EXAMPLES_NUM,
                                  FLAGS.batch_size, shuffle=True)
    global_step = tf.train.get_or_create_global_step()

    logits, _ = mnist.inference(images, training=True)
    total_loss = mnist.loss(logits, labels)
    train_op = mnist.train(total_loss, global_step)
    saver = tf.train.Saver()
    ckpt = os.path.join(FLAGS.train_dir, 'model.ckpt')

    with tf.Session() as sess:
        init_op = tf.group(
            tf.local_variables_initializer(),
            tf.global_variables_initializer())
        sess.run(init_op)

        coord = tf.train.Coordinator()
        # Without the queue runners the input queues stay empty and sess.run blocks.
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            step = 0
            for step in range(1, FLAGS.max_step + 1):
                _, train_loss = sess.run([train_op, total_loss])
                if step % 100 == 0:
                    print('step: {}, loss: {}'.format(step, train_loss))
                    saver.save(sess, ckpt, global_step=step)
            # Don't lose the last (max_step % 100) steps of training.
            if step % 100 != 0:
                saver.save(sess, ckpt, global_step=step)
        finally:
            coord.request_stop()
            coord.join(threads)

训练时通过shuffle=True参数对数据进行shuffle处理。注意必须调用tf.train.start_queue_runners(sess, coord=coord)启动队列线程,否则队列中始终没有数据,sess.run会一直阻塞。

五、模型验证

def eval_once(saver, top_k_op):
    """Restore the latest checkpoint and report accuracy over one validation pass.

    Runs ceil(VALIDATION_EXAMPLES_NUM / batch_size) batches of `top_k_op`.
    It then prints the fraction of correct predictions among all evaluated
    samples (the last batch may wrap around into already-seen examples).

    Args:
      saver: tf.train.Saver used to restore the model variables.
      top_k_op: boolean tensor [batch]; True where the prediction is correct.
    """
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoint files are named '<prefix>-<global_step>'.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('no checkpoint file')
            return

        coord = tf.train.Coordinator()
        # Bound before the try so coord.join() below cannot hit a NameError
        # when start_queue_runners itself fails.
        threads = []
        try:
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # Float division so ceil() also rounds up under Python 2.
            iter_per_epoch = int(math.ceil(
                float(mnist.VALIDATION_EXAMPLES_NUM) / FLAGS.batch_size))

            total_sample = iter_per_epoch * FLAGS.batch_size
            correct_predict = 0
            step = 0

            while step < iter_per_epoch and not coord.should_stop():
                predict = sess.run(top_k_op)
                correct_predict += np.sum(predict)
                step += 1

            # Float division: np.int64 / int floors under Python 2.
            precision = float(correct_predict) / total_sample
            print('step: {}, model: {}, precision: {}'.format(global_step, ckpt.model_checkpoint_path, precision))

        except Exception as e:
            print('exception: ', e)
            coord.request_stop(e)
        finally:
            coord.request_stop()
        coord.join(threads)


def evaluation():
    """Build the validation graph and evaluate checkpoints repeatedly.

    Evaluates once when FLAGS.run_once is set. Otherwise it re-evaluates the
    latest checkpoint every FLAGS.eval_interval_secs seconds, forever.
    Validation data is read without shuffling.
    """
    images, labels = mnist.inputs(['./validation_img.tfrecords'], mnist.VALIDATION_EXAMPLES_NUM,
                                  batch_size=FLAGS.batch_size, shuffle=False)
    logits, _ = mnist.inference(images, training=False)
    # True for each example whose top-1 prediction equals its label.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    saver = tf.train.Saver()

    while True:
        eval_once(saver, top_k_op)
        if FLAGS.run_once:
            return
        time.sleep(FLAGS.eval_interval_secs)

模型验证时不需要对数据进行shuffle(shuffle=False)。

六、对单张图片进行预测

def pred(filename, train_dir):
    """Classify a single 28x28 grayscale digit image with the latest checkpoint.

    Pixels are fed as raw float values in [0, 255], matching the training
    pipeline, which only casts uint8 pixels to float32.

    Args:
      filename: path of the image file (read as grayscale).
      train_dir: directory holding the training checkpoints.

    Raises:
      ValueError: if the image cannot be read or is not 28x28.
    """
    img = cv2.imread(filename, flags=cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise ValueError('Failed to find file: ' + filename)
    # Reject other sizes instead of letting the -1 reshape silently split the
    # image into several bogus 28x28 "examples".
    if img.shape != (28, 28):
        raise ValueError('Expected a 28x28 image, got shape {}: {}'.format(img.shape, filename))
    img = tf.cast(img, tf.float32)
    img = tf.reshape(img, [-1, 28, 28, 1])

    logits, predict = mnist.inference(img, training=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('no checkpoint file')
            return
        pre = sess.run(predict)
        print('model:{}, file:{}, label: {} ({:.2f}%)'.
              format(ckpt.model_checkpoint_path, filename, np.argmax(pre[0]), np.max(pre[0]) * 100))


if __name__ == '__main__':
    # Classify one sample image using the newest checkpoint saved under ./train.
    pred(filename='./img_test/2_2098.jpg', train_dir='./train')

输出:

model:./train\model.ckpt-1000, file:./img_test/2_2098.jpg, label: 2 (96.27%)

最后:完整代码

https://github.com/buptlj/learn_tf

猜你喜欢

转载自blog.csdn.net/qiumokucao/article/details/82083253