Implementing ImageNet image recognition with the TensorFlow Estimator API

TensorFlow introduced the Estimator API in version 1.3 and has steadily improved its support for this high-level programming style in later releases; Estimator also makes it very convenient to set up multi-GPU training. In my previous blog posts I used the low-level API to build and train models. The advantage of that approach is that it is more flexible and exposes the underlying details of the model; the disadvantage is that it requires a large amount of code and is more cumbersome, because many details have to be handled by hand. For this reason, I tried the high-level API on the latest TensorFlow release, 1.14, to test whether it is really easy to use and whether it can achieve the same performance as the low-level API.

My tests are based on the ImageNet image classification data; for how to prepare the ImageNet data, please refer to my previous blog post. The specific code is as follows. It defines two models, MobileNet V1 and MobileNet V2; the Estimator model function below trains MobileNet V2.

import tensorflow as tf
import horovod.tensorflow as hvd
import os
import random
import time
import numpy as np

# Network input geometry (pixels / channels).
imageWidth, imageHeight, imageDepth = 224, 224, 3
# Number of examples per batch in every input pipeline.
batch_size = 128
# Length, in pixels, that the shorter image side is scaled to before cropping.
resize_min = 256

# Every directory entry is treated as one TFRecord file of the split.
train_files_names = os.listdir('/data/AI/train_tf/')
train_files = [os.path.join('/data/AI/train_tf/', name) for name in train_files_names]
valid_files_names = os.listdir('/data/AI/valid_tf/')
valid_files = [os.path.join('/data/AI/valid_tf/', name) for name in valid_files_names]

def _parse_function(example_proto):
    """Parse one serialized training record into ``(features, label)``.

    The JPEG is decoded to 3 channels, scaled so that its shorter side becomes
    ``resize_min`` pixels (the longer side is scaled by the same ratio and
    truncated to an integer), randomly cropped to ``imageHeight`` x
    ``imageWidth``, randomly mirrored left/right and standardized per image.

    Args:
        example_proto: serialized example as read from a TFRecord file.

    Returns:
        A tuple ``({'images': image}, label)`` where ``label`` is the first
        entry of the record's "label" feature.
    """
    def _string_feature():
        return tf.FixedLenFeature([], tf.string, default_value="")

    def _int_feature(default):
        return tf.FixedLenFeature([1], tf.int64, default_value=[default])

    def _float_list_feature():
        return tf.VarLenFeature(tf.float32)

    record_spec = {
        "image": _string_feature(),
        "height": _int_feature(0),
        "width": _int_feature(0),
        "channels": _int_feature(3),
        "colorspace": _string_feature(),
        "img_format": _string_feature(),
        "label": _int_feature(0),
        "bbox_xmin": _float_list_feature(),
        "bbox_xmax": _float_list_feature(),
        "bbox_ymin": _float_list_feature(),
        "bbox_ymax": _float_list_feature(),
        "text": _string_feature(),
        "filename": _string_feature(),
    }
    record = tf.parse_single_example(example_proto, record_spec)
    rgb = tf.image.decode_jpeg(record["image"], channels=3)

    # Aspect-preserving resize: the shorter side is mapped to resize_min.
    dims = tf.shape(rgb)
    src_h = dims[0]
    src_w = dims[1]

    def _scale_long_side(long_side, short_side):
        long_as_float = tf.cast(long_side, tf.float64)
        ratio = tf.divide(resize_min, short_side)
        return tf.cast(tf.multiply(long_as_float, ratio), tf.int32)

    def _height_is_shorter():
        return (resize_min, _scale_long_side(src_w, src_h))

    def _width_is_shorter_or_equal():
        return (_scale_long_side(src_h, src_w), resize_min)

    new_h, new_w = tf.cond(src_h < src_w, _height_is_shorter, _width_is_shorter_or_equal)
    rgb_float = tf.image.convert_image_dtype(rgb, tf.float32)
    scaled = tf.image.resize_images(rgb_float, [new_h, new_w])

    # Training-time augmentation: random crop followed by a random mirror.
    patch = tf.random_crop(scaled, [imageHeight, imageWidth, 3])
    patch = tf.image.random_flip_left_right(patch)

    standardized = tf.image.per_image_standardization(patch)
    return {'images': standardized}, record["label"][0]

def train_input_fn():
    """Build the training input pipeline handed to ``Estimator.train``.

    Reads the TFRecord files listed in ``train_files``, parses and augments
    each record with ``_parse_function`` (4 parallel calls), repeats the data
    10 times and groups it into batches of ``batch_size`` examples.

    Returns:
        A ``tf.data.Dataset`` yielding ``({'images': ...}, labels)`` batches.
    """
    # NOTE(review): the pipeline never shuffles; this presumably relies on the
    # TFRecord files having been written in shuffled order - confirm.
    dataset_train = tf.data.TFRecordDataset(train_files)
    dataset_train = dataset_train.map(_parse_function, num_parallel_calls=4)
    dataset_train = dataset_train.repeat(10)
    dataset_train = dataset_train.batch(batch_size)
    # prefetch() counts elements of the dataset it is applied to, and after
    # batch() one element is a whole batch. The previous prefetch(batch_size)
    # therefore allowed up to 128 batches (roughly 77 MB each as float32) to be
    # buffered in host memory; one batch is enough to overlap preprocessing
    # with the training step.
    dataset_train = dataset_train.prefetch(1)
    return dataset_train

def _parse_test_function(example_proto):
    """Parse one serialized validation record into ``(features, label)``.

    The JPEG is decoded to 3 channels, scaled so that its shorter side becomes
    ``resize_min`` pixels (the longer side is scaled by the same ratio and
    truncated to an integer), center-cropped to ``imageHeight`` x
    ``imageWidth`` and standardized per image. No random distortion is applied.

    Args:
        example_proto: serialized example as read from a TFRecord file.

    Returns:
        A tuple ``({'images': image}, label)`` where ``label`` is the first
        entry of the record's "label" feature.
    """
    def _string_feature():
        return tf.FixedLenFeature([], tf.string, default_value="")

    def _int_feature(default):
        return tf.FixedLenFeature([1], tf.int64, default_value=[default])

    def _float_list_feature():
        return tf.VarLenFeature(tf.float32)

    record_spec = {
        "image": _string_feature(),
        "height": _int_feature(0),
        "width": _int_feature(0),
        "channels": _int_feature(3),
        "colorspace": _string_feature(),
        "img_format": _string_feature(),
        "label": _int_feature(0),
        "bbox_xmin": _float_list_feature(),
        "bbox_xmax": _float_list_feature(),
        "bbox_ymin": _float_list_feature(),
        "bbox_ymax": _float_list_feature(),
        "text": _string_feature(),
        "filename": _string_feature(),
    }
    record = tf.parse_single_example(example_proto, record_spec)
    rgb = tf.image.decode_jpeg(record["image"], channels=3)

    # Aspect-preserving resize: the shorter side is mapped to resize_min.
    dims = tf.shape(rgb)
    src_h = dims[0]
    src_w = dims[1]

    def _scale_long_side(long_side, short_side):
        long_as_float = tf.cast(long_side, tf.float64)
        ratio = tf.divide(resize_min, short_side)
        return tf.cast(tf.multiply(long_as_float, ratio), tf.int32)

    def _height_is_shorter():
        return (resize_min, _scale_long_side(src_w, src_h))

    def _width_is_shorter_or_equal():
        return (_scale_long_side(src_h, src_w), resize_min)

    new_h, new_w = tf.cond(src_h < src_w, _height_is_shorter, _width_is_shorter_or_equal)
    rgb_float = tf.image.convert_image_dtype(rgb, tf.float32)
    scaled = tf.image.resize_images(rgb_float, [new_h, new_w])

    # Center crop: split the surplus rows/columns evenly (floor division).
    scaled_dims = tf.shape(scaled)
    scaled_h = scaled_dims[0]
    scaled_w = scaled_dims[1]
    top = (scaled_h - imageHeight) // 2
    left = (scaled_w - imageWidth) // 2
    center = tf.slice(scaled, [top, left, 0], [imageHeight, imageWidth, -1])

    standardized = tf.image.per_image_standardization(center)
    return {'images': standardized}, record["label"][0]

def val_input_fn():
    """Build the validation input pipeline handed to ``Estimator.evaluate``.

    Reads the TFRecord files listed in ``valid_files`` once, parses each record
    with ``_parse_test_function`` (4 parallel calls) and groups the examples
    into batches of ``batch_size``.

    Returns:
        A ``tf.data.Dataset`` yielding ``({'images': ...}, labels)`` batches.
    """
    dataset_valid = tf.data.TFRecordDataset(valid_files)
    dataset_valid = dataset_valid.map(_parse_test_function, num_parallel_calls=4)
    dataset_valid = dataset_valid.batch(batch_size)
    # prefetch() counts elements of the dataset it is applied to, and after
    # batch() one element is a whole batch. The previous prefetch(batch_size)
    # therefore allowed up to 128 batches to be buffered in host memory; one
    # batch is enough to overlap preprocessing with the evaluation step.
    dataset_valid = dataset_valid.prefetch(1)
    return dataset_valid

def predict_input_fn():
    """Build an input pipeline that yields a single batch of validation data.

    Only the first ``batch_size`` parsed records of ``valid_files`` are kept,
    and they are emitted as one batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``({'images': ...}, labels)`` batches.
    """
    return (
        tf.data.TFRecordDataset(valid_files)
        .map(_parse_test_function, num_parallel_calls=4)
        .take(batch_size)
        .batch(batch_size)
    )

# Short alias for the Keras layers namespace used by the model-building helpers.
l = tf.keras.layers
def _conv(inputs, filters, kernel_size, strides, padding, bias=False, normalize=True, activation='relu'):
    """Apply a 2-D convolution with optional batch norm and activation.

    Args:
        inputs: input tensor; batch normalization is applied along axis 3.
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        padding: when > 0, that many zeros are added explicitly with
            ZeroPadding2D and the convolution uses 'valid' padding; otherwise
            the convolution uses 'same' padding.
        bias: whether the convolution has a bias term.
        normalize: when truthy, follow the convolution with BatchNormalization.
        activation: 'relu', 'relu6' or 'leaky_relu'; any other value leaves
            the output without an activation.

    Returns:
        The output tensor of the last layer applied.
    """
    pad_explicitly = padding > 0
    x = l.ZeroPadding2D(padding=padding)(inputs) if pad_explicitly else inputs

    conv_layer = l.Conv2D(
        filters,
        kernel_size,
        strides,
        'valid' if pad_explicitly else 'same',
        use_bias=bias,
        kernel_initializer='he_normal',
        kernel_regularizer=tf.keras.regularizers.l2(l=5e-4),
    )
    x = conv_layer(x)

    if normalize:
        x = l.BatchNormalization(axis=3)(x)

    if activation == 'relu':
        x = l.ReLU()(x)
    elif activation == 'relu6':
        x = l.ReLU(max_value=6)(x)
    elif activation == 'leaky_relu':
        x = l.LeakyReLU(alpha=0.1)(x)
    return x

def _dwconv(inputs, filters, kernel_size, strides, padding, bias=False, activation='relu'):
    """Apply a depthwise 2-D convolution, batch norm and optional activation.

    Args:
        inputs: input tensor; batch normalization is applied along axis 3.
        filters: not used by this function; kept so the call sites mirror
            ``_conv``.
        kernel_size: depthwise kernel size.
        strides: depthwise stride.
        padding: when > 0, that many zeros are added explicitly with
            ZeroPadding2D and the convolution uses 'valid' padding; otherwise
            the convolution uses 'same' padding.
        bias: whether the depthwise convolution has a bias term.
        activation: 'relu', 'relu6' or 'leaky_relu'; any other value leaves
            the output without an activation.

    Returns:
        The output tensor of the last layer applied.
    """
    pad_explicitly = padding > 0
    x = inputs
    if pad_explicitly:
        x = l.ZeroPadding2D(padding=(padding, padding))(x)

    depthwise_layer = l.DepthwiseConv2D(
        kernel_size,
        strides,
        'valid' if pad_explicitly else 'same',
        use_bias=bias,
        depthwise_initializer='he_uniform',
        depthwise_regularizer=tf.keras.regularizers.l2(l=5e-4),
    )
    x = depthwise_layer(x)
    x = l.BatchNormalization(axis=3)(x)

    if activation == 'relu':
        x = l.ReLU()(x)
    elif activation == 'relu6':
        x = l.ReLU(max_value=6)(x)
    elif activation == 'leaky_relu':
        x = l.LeakyReLU(alpha=0.1)(x)
    return x

def _bottleneck(inputs, in_filters, out_filters, kernel_size, strides, bias=False, activation='relu6', t=1):
    """Build a bottleneck block: 1x1 expand, depthwise conv, linear 1x1 project.

    Args:
        inputs: input tensor with channels on axis 3.
        in_filters: base channel count; the expansion uses ``in_filters * t``.
        out_filters: number of channels produced by the projection.
        kernel_size: kernel size of the depthwise convolution.
        strides: stride of the depthwise convolution; a stride of 2 uses
            explicit zero padding of 1, any other stride uses 'same' padding.
        bias: accepted for interface compatibility; every convolution in the
            block is created without a bias term regardless of this value.
        activation: activation for the expansion and depthwise stages.
        t: expansion factor.

    Returns:
        The block output. The input is added back as a residual only when
        ``strides == 1`` and the input and output channel counts match.
    """
    # `activation` must be passed by keyword: _conv's 7th positional parameter
    # is `normalize`, so the previous positional call silently bound the
    # activation string to `normalize` and fell back to the default 'relu'.
    output = _conv(inputs, in_filters * t, 1, 1, 0, bias=False, activation=activation)
    padding = 1 if strides == 2 else 0
    output = _dwconv(output, in_filters * t, kernel_size, strides, padding, bias=False, activation=activation)
    # 'linear' matches none of _conv's activation branches, so the projection
    # is batch-normalized but gets no non-linearity.
    output = _conv(output, out_filters, 1, 1, 0, bias=False, activation='linear')
    if strides == 1 and inputs.get_shape().as_list()[3] == output.get_shape().as_list()[3]:
        output = l.add([output, inputs])
    return output

def mobilenet_model_v1():
    """Build the MobileNet V1 classifier as a Keras functional model.

    A strided 3x3 stem convolution is followed by thirteen depthwise +
    pointwise pairs, global average pooling and a 1000-way dense layer.

    Returns:
        A ``tf.keras.Model`` mapping ``imageHeight`` x ``imageWidth`` x 3
        images to 1000 logits.
    """
    image = tf.keras.Input(shape=(imageHeight, imageWidth, 3))
    net = _conv(image, 32, 3, 2, 1)

    # One entry per depthwise/pointwise pair:
    # (channels entering the depthwise conv, depthwise stride, pointwise filters)
    stages = [
        (32, 1, 64),
        (64, 2, 128),
        (128, 1, 128),
        (128, 2, 256),
        (256, 1, 256),
        (256, 2, 512),
    ]
    stages += [(512, 1, 512)] * 5
    stages += [
        (512, 2, 1024),
        (1024, 1, 1024),
    ]
    for dw_channels, dw_stride, pw_filters in stages:
        # Stride-2 depthwise convs use explicit padding of 1, the rest 'same'.
        dw_padding = 1 if dw_stride == 2 else 0
        net = _dwconv(net, dw_channels, 3, dw_stride, dw_padding)
        net = _conv(net, pw_filters, 1, 1, 0)

    net = l.GlobalAveragePooling2D()(net)
    net = l.Flatten()(net)
    classifier = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1/1000))
    logits = classifier(net)
    return tf.keras.Model(inputs=image, outputs=logits)

def mobilenet_model_v2():
    """Build the MobileNet V2 classifier as a Keras functional model.

    A strided 3x3 stem convolution is followed by eighteen bottleneck blocks,
    a 1280-filter convolution, 7x7 average pooling and a 1000-way dense layer.

    Returns:
        A ``tf.keras.Model`` mapping ``imageHeight`` x ``imageWidth`` x 3
        images to 1000 logits.
    """
    # Input Layer
    image = tf.keras.Input(shape=(imageHeight, imageWidth, 3))        # 224*224*3
    # `activation` must be passed by keyword: _conv's 7th positional parameter
    # is `normalize`, so the previous positional 'relu6' was bound to
    # `normalize` and the layer silently used the default 'relu'.
    net = _conv(image, 32, 3, 2, 1, bias=False, activation='relu6')   # 112*112*32

    # (in_filters, out_filters, strides, t) for each bottleneck, in order.
    bottlenecks = [
        (32, 16, 1, 1),      # 112*112*16
        (16, 24, 2, 6),      # 56*56*24
        (24, 24, 1, 6),      # 56*56*24
        (24, 32, 2, 6),      # 28*28*32
        (32, 32, 1, 6),      # 28*28*32
        (32, 32, 1, 6),      # 28*28*32
        (32, 64, 2, 6),      # 14*14*64
        (64, 64, 1, 6),      # 14*14*64
        (64, 64, 1, 6),      # 14*14*64
        (64, 64, 1, 6),      # 14*14*64
        (64, 96, 1, 6),      # 14*14*96
        (96, 96, 1, 6),      # 14*14*96
        (96, 96, 1, 6),      # 14*14*96
        (96, 96, 1, 6),      # 14*14*96
        (96, 160, 2, 6),     # 7*7*160
        (160, 160, 1, 6),    # 7*7*160
        (160, 160, 1, 6),    # 7*7*160
        (160, 320, 1, 6),    # 7*7*320
    ]
    for in_filters, out_filters, strides, t in bottlenecks:
        net = _bottleneck(net, in_filters, out_filters, 3, strides, False, 'relu6', t)

    # NOTE(review): this final convolution uses a 3x3 kernel; confirm that is
    # intended rather than a 1x1 kernel. Left unchanged because it determines
    # the shape of the layer's weights.
    net = _conv(net, 1280, 3, 1, 0, bias=False, activation='relu6')   # 7*7*1280
    net = l.AveragePooling2D(7)(net)
    net = l.Flatten()(net)
    logits = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1/1000))(net)
    model = tf.keras.Model(inputs=image, outputs=logits)
    return model

def mobilenet(features, labels, mode, params):
    """Estimator ``model_fn`` that trains and evaluates MobileNet V2.

    Args:
        features: dict holding the image batch under the key "images".
        labels: integer class ids for the batch (not used in PREDICT mode).
        mode: one of the ``tf.estimator.ModeKeys`` values.
        params: accepted because the Estimator is constructed with ``params``;
            not read by this function.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for the requested mode: predictions
        and an export output for PREDICT, loss and train op for TRAIN, loss
        and streaming metrics for EVAL.
    """
    model = mobilenet_model_v2()
    training = (mode == tf.estimator.ModeKeys.TRAIN)
    images = tf.reshape(features["images"], [-1, imageHeight, imageWidth, 3])
    logits = model(images, training=training)

    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=-1),
        # Named so the softmax tensor can be looked up in the graph by name.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions,
            export_outputs={'classify': tf.estimator.export.PredictOutput(predictions)})

    # Calculate Loss (for both TRAIN and EVAL modes)
    # NOTE(review): only the cross-entropy term is minimized here; the L2
    # kernel regularizers configured in _conv/_dwconv are not added to this
    # loss - confirm whether weight decay was intended to take effect.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Configure the Training Op (for TRAIN mode)
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        # Step-wise learning-rate schedule: values[i] applies up to boundaries[i].
        boundaries = [5000, 60000, 80000]
        values = [0.1, 0.01, 0.001, 0.0001]
        learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries, values)
        tf.summary.scalar('learning_rate', learning_rate)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
        # Run the model's update ops (e.g. batch-norm moving statistics) that
        # depend on this input batch together with every optimizer step.
        with tf.control_dependencies(model.get_updates_for(features)):
            train_op = optimizer.minimize(loss=loss, global_step=global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # Add evaluation metrics (for EVAL mode). Only metrics listed in
    # eval_metric_ops are aggregated over the whole evaluation set and returned
    # by Estimator.evaluate, so top-5 accuracy is reported as a streaming
    # metric here instead of through a tf.summary.scalar call.
    top5_hits = tf.cast(tf.nn.in_top_k(predictions=logits, targets=labels, k=5), tf.float32)
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions["classes"]),
        "top_5_accuracy": tf.metrics.mean(top5_hits),
    }
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

def main(_):
    """Alternate 5000 training steps with a full evaluation, ten times over.

    The evaluation results of every round are printed. The positional argument
    supplied by ``tf.app.run`` is ignored.
    """
    image_column = tf.feature_column.numeric_column(
        key='images', shape=(imageHeight, imageWidth, 3))
    imagenet_classifier = tf.estimator.Estimator(
        model_fn=mobilenet,
        model_dir="/home/roy/AI/imagenet_model_mobilenet_v2/",
        params={'feature_columns': [image_column]},
    )

    rounds_left = 10
    while rounds_left > 0:
        imagenet_classifier.train(input_fn=train_input_fn, steps=5000)
        print(imagenet_classifier.evaluate(input_fn=val_input_fn))
        rounds_left -= 1

# Script entry point: hand `main` to tf.app.run when executed directly.
if __name__ == "__main__":
  tf.app.run(main)

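As the code above shows, building and training a model with Estimator is still very convenient. Note that if Batch Normalization is used, the model must be called with training=True during training, and the Batch Normalization update ops must be run as part of each training step (here they are attached to the train op through tf.control_dependencies).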

Guess you like

Origin blog.csdn.net/gzroy/article/details/98950060