用TensorFlow 2.0实现ImageNet的训练

TensorFlow 2.0正式版于10月份推出,我也第一时间转向了这个新的版本。花了一些时间研究之后,我的结论是2.0版本确实挺简便易用的,不过也有个缺点,就是封装得太好了,你无法很好地理解里面实现的机制。例如我尝试了2.0推荐的用Keras搭建模型和训练的方法后,发现并没有之前1.x版本用低阶API直接训练收敛得快,而且似乎也达不到1.x的训练精度。又如采用了Batch Normalization的Keras层,如果直接用Keras的fit方法来训练和验证,官方文档中没有提到如何区分训练和预测,实际测试中发现模型收敛得很慢。查了网上的一些帖子,也提到了类似的问题,解决方案是调用tf.keras.backend.set_learning_phase,不过我发现调用与否差别不大,可能是网上的帖子基于TF 2.0的测试版本,和正式版本的行为有所不同。之后我也改用了Custom Training Loop的方式做比较,发现似乎比直接用model.fit的方式收敛得更快、更稳定一些,不过好像还是达不到1.x的精度。虽然目前我还没能很好地运用TF 2.0,不过整体感觉TF 2.0的易用性还是大大增强了,值得继续深入研究下去。下面记录一下我用TF 2.0进行ImageNet训练的过程。ImageNet的文件依然按照我之前博客提到的方式来生成训练集和验证集,这里不再重复。

模型定义

我采用的是MobileNet V2的模型,如以下代码:

import os
import time

import tensorflow as tf
# Short alias for the Keras layer namespace used by the model-building helpers.
l = tf.keras.layers
# Network input resolution in pixels (height, width).
imageHeight = 224
imageWidth = 224

def _activation(output, activation):
    """Apply the activation layer named by ``activation`` to ``output``.

    Recognised names are 'relu', 'relu6' (ReLU capped at 6) and 'leaky_relu'
    (alpha=0.1). Any other value, e.g. 'linear', returns ``output`` unchanged.
    """
    if activation == 'relu':
        return l.ReLU()(output)
    if activation == 'relu6':
        return l.ReLU(max_value=6)(output)
    if activation == 'leaky_relu':
        return l.LeakyReLU(alpha=0.1)(output)
    return output


def _conv(inputs, filters, kernel_size, strides, padding, bias=False, normalize=True, activation='relu'):
    """Conv2D, optional BatchNormalization, then an activation.

    Args:
        inputs: channels-last 4-D tensor (BatchNormalization below uses axis=3).
        filters: number of output channels.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        padding: if > 0, pad symmetrically with this many zeros and use a
            'valid' convolution; otherwise use 'same' padding.
        bias: whether the convolution has a bias term.
        normalize: append a BatchNormalization layer.
        activation: activation name, see ``_activation``.

    Note: the 7th positional parameter is ``normalize``, not ``activation`` --
    always pass ``activation`` by keyword.
    """
    output = inputs
    padding_str = 'same'
    if padding > 0:
        output = l.ZeroPadding2D(padding=padding)(output)
        padding_str = 'valid'
    output = l.Conv2D(filters, kernel_size, strides, padding_str, use_bias=bias,
                      kernel_initializer='he_normal',
                      kernel_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
    if normalize:
        output = l.BatchNormalization(axis=3)(output)
    return _activation(output, activation)


def _dwconv(inputs, filters, kernel_size, strides, padding, bias=False, activation='relu'):
    """DepthwiseConv2D + BatchNormalization + activation.

    ``filters`` is accepted for signature symmetry with ``_conv`` but unused:
    a depthwise convolution keeps the input channel count. ``padding`` has
    the same meaning as in ``_conv``.
    """
    output = inputs
    padding_str = 'same'
    if padding > 0:
        output = l.ZeroPadding2D(padding=(padding, padding))(output)
        padding_str = 'valid'
    output = l.DepthwiseConv2D(kernel_size, strides, padding_str, use_bias=bias,
                               depthwise_initializer='he_uniform',
                               depthwise_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
    output = l.BatchNormalization(axis=3)(output)
    return _activation(output, activation)


def _bottleneck(inputs, in_filters, out_filters, kernel_size, strides, bias=False, activation='relu6', t=1):
    """MobileNetV2 inverted residual: 1x1 expand -> depthwise -> 1x1 linear projection.

    Args:
        inputs: channels-last 4-D tensor with ``in_filters`` channels.
        in_filters: input channel count; expansion width is ``in_filters * t``.
        out_filters: output channel count.
        kernel_size: depthwise kernel size.
        strides: depthwise stride.
        bias: whether the convolutions have bias terms.
        activation: activation after the expansion and depthwise convs.
        t: expansion factor.

    A residual connection is added when ``strides == 1`` and the input and
    output channel counts match.

    NOTE(review): unlike the MobileNetV2 paper, the 1x1 expansion conv is also
    built when t == 1; kept so the layer layout is unchanged.
    """
    expanded = in_filters * t
    # Keyword arguments are required: passing the activation as the 7th
    # positional argument would bind it to `normalize` and silently fall
    # back to ReLU (which also broke the linear projection).
    output = _conv(inputs, expanded, 1, 1, 0, bias=bias, activation=activation)
    # Strided depthwise convs use explicit symmetric padding (1 for a 3x3
    # kernel); stride-1 convs use 'same' padding.
    padding = kernel_size // 2 if strides == 2 else 0
    output = _dwconv(output, expanded, kernel_size, strides, padding, bias=bias, activation=activation)
    # Linear bottleneck: no non-linearity after the projection.
    output = _conv(output, out_filters, 1, 1, 0, bias=bias, activation='linear')
    if strides == 1 and inputs.get_shape().as_list()[3] == output.get_shape().as_list()[3]:
        output = l.add([output, inputs])
    return output

def mobilenet_model_v2():
    """Build a MobileNetV2-style ImageNet classifier.

    Input: (imageHeight, imageWidth, 3) channels-last images, named 'input_1'.
    Output: 1000 raw logits (no softmax).

    NOTE(review): deviations from the MobileNetV2 paper are kept as in the
    original so existing checkpoints stay loadable: the 96-channel stage has
    4 blocks (paper: 3) and the 1280-channel head conv is 3x3 (paper: 1x1).
    """
    # Explicit name: the input pipeline feeds {'input_1': image}, and Keras'
    # auto-generated Input names depend on how many Inputs were created
    # earlier in the process.
    image = tf.keras.Input(shape=(imageHeight, imageWidth, 3), name='input_1')    # 224x224x3
    # `activation` must be a keyword: the 7th positional parameter of _conv
    # is `normalize`.
    net = _conv(image, 32, 3, 2, 1, bias=False, activation='relu6')            # 112x112x32

    # (in_filters, out_filters, stride, expansion t); all use 3x3 depthwise.
    bottlenecks = [
        (32, 16, 1, 1),      # 112x112x16
        (16, 24, 2, 6),      # 56x56x24
        (24, 24, 1, 6),      # 56x56x24
        (24, 32, 2, 6),      # 28x28x32
        (32, 32, 1, 6),      # 28x28x32
        (32, 32, 1, 6),      # 28x28x32
        (32, 64, 2, 6),      # 14x14x64
        (64, 64, 1, 6),      # 14x14x64
        (64, 64, 1, 6),      # 14x14x64
        (64, 64, 1, 6),      # 14x14x64
        (64, 96, 1, 6),      # 14x14x96
        (96, 96, 1, 6),      # 14x14x96
        (96, 96, 1, 6),      # 14x14x96
        (96, 96, 1, 6),      # 14x14x96
        (96, 160, 2, 6),     # 7x7x160
        (160, 160, 1, 6),    # 7x7x160
        (160, 160, 1, 6),    # 7x7x160
        (160, 320, 1, 6),    # 7x7x320
    ]
    for in_filters, out_filters, strides, t in bottlenecks:
        net = _bottleneck(net, in_filters, out_filters, 3, strides, False, 'relu6', t)

    net = _conv(net, 1280, 3, 1, 0, bias=False, activation='relu6')            # 7x7x1280
    net = l.AveragePooling2D(7)(net)
    net = l.Flatten()(net)
    logits = l.Dense(1000, kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1/1000))(net)
    return tf.keras.Model(inputs=image, outputs=logits)

构建训练集和验证集

imageDepth = 3
batch_size = 64
# Shorter image side after resizing, before cropping to imageHeight x imageWidth.
resize_min = 256

_TRAIN_TFRECORD_DIR = '/AI/train_tf/'
_VALID_TFRECORD_DIR = '/AI/valid_tf/'
# os.listdir order is arbitrary; sort so the file order is reproducible.
train_files_names = sorted(os.listdir(_TRAIN_TFRECORD_DIR))
train_files = [os.path.join(_TRAIN_TFRECORD_DIR, name) for name in train_files_names]
valid_files_names = sorted(os.listdir(_VALID_TFRECORD_DIR))
valid_files = [os.path.join(_VALID_TFRECORD_DIR, name) for name in valid_files_names]

# TFRecord schema shared by the training and validation parsers. Only "image"
# and "label" are used below.
# NOTE(review): confirm labels are 0-based (0..999); the model has 1000 logits.
_TFRECORD_FEATURES = {
    "image": tf.io.FixedLenFeature([], tf.string, default_value=""),
    "height": tf.io.FixedLenFeature([1], tf.int64, default_value=[0]),
    "width": tf.io.FixedLenFeature([1], tf.int64, default_value=[0]),
    "channels": tf.io.FixedLenFeature([1], tf.int64, default_value=[3]),
    "colorspace": tf.io.FixedLenFeature([], tf.string, default_value=""),
    "img_format": tf.io.FixedLenFeature([], tf.string, default_value=""),
    "label": tf.io.FixedLenFeature([1], tf.int64, default_value=[0]),
    "bbox_xmin": tf.io.VarLenFeature(tf.float32),
    "bbox_xmax": tf.io.VarLenFeature(tf.float32),
    "bbox_ymin": tf.io.VarLenFeature(tf.float32),
    "bbox_ymax": tf.io.VarLenFeature(tf.float32),
    "text": tf.io.FixedLenFeature([], tf.string, default_value=""),
    "filename": tf.io.FixedLenFeature([], tf.string, default_value=""),
}


def _decode_and_resize(example_proto):
    """Parse a record, decode its JPEG and resize so the shorter side is resize_min.

    Returns:
        (image, label): float32 HWC image with values in [0, 1] and the aspect
        ratio preserved, and the scalar int64 label.
    """
    parsed_features = tf.io.parse_single_example(example_proto, _TFRECORD_FEATURES)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    shape = tf.shape(image_decoded, out_type=tf.int64)
    height, width = shape[0], shape[1]
    short_side = tf.minimum(height, width)
    # Integer arithmetic: the shorter side becomes exactly resize_min and the
    # longer side floor(long * resize_min / short); float64 division could
    # land just below an integer and lose a pixel.
    new_size = tf.cast(tf.stack([height, width]) * resize_min // short_side, tf.int32)
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    resized = tf.image.resize(image_float, new_size)
    return resized, parsed_features["label"][0]


# Parse TFRECORD and distort the image for train
def _parse_function(example_proto):
    """Training parser: random crop, random horizontal flip, per-image standardization."""
    resized, label = _decode_and_resize(example_proto)
    cropped = tf.image.random_crop(resized, [imageHeight, imageWidth, 3])
    flipped = tf.image.random_flip_left_right(cropped)
    image_train = tf.image.per_image_standardization(flipped)
    # Keep channels-last (HWC): the model input is (imageHeight, imageWidth, 3).
    return {'input_1': image_train}, label


def train_input_fn():
    """Shuffled, augmented, batched training dataset repeated 10 times."""
    dataset_train = tf.data.TFRecordDataset(train_files)
    # Shuffle the small serialized records *before* decoding: a 10000-element
    # buffer of decoded 224x224x3 float32 crops would need ~6 GB of RAM.
    dataset_train = dataset_train.shuffle(10000)
    dataset_train = dataset_train.map(_parse_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset_train = dataset_train.repeat(10)
    dataset_train = dataset_train.batch(batch_size)
    # The prefetch buffer counts batches here; let tf.data tune its depth.
    dataset_train = dataset_train.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset_train


def _parse_test_function(example_proto):
    """Validation parser: center crop + per-image standardization (no randomness)."""
    resized, label = _decode_and_resize(example_proto)
    shape = tf.shape(resized)
    crop_top = (shape[0] - imageHeight) // 2
    crop_left = (shape[1] - imageWidth) // 2
    image_cropped = tf.slice(resized, [crop_top, crop_left, 0], [imageHeight, imageWidth, -1])
    image_valid = tf.image.per_image_standardization(image_cropped)
    # Keep channels-last (HWC) to match the model input.
    return {'input_1': image_valid}, label


def val_input_fn():
    """Batched validation dataset (single pass, unshuffled)."""
    dataset_valid = tf.data.TFRecordDataset(valid_files)
    dataset_valid = dataset_valid.map(_parse_test_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset_valid = dataset_valid.batch(batch_size)
    dataset_valid = dataset_valid.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset_valid

定义模型的回调函数

主要作用是根据训练的步数来调整优化器的学习率,以及在每个训练EPOCH完成后打印输出验证集的指标,如以下代码:

boundaries = [1000, 5000, 60000, 80000]
values = [0.001, 0.1, 0.01, 0.001, 0.0001]
# Step-based LR: 0.001 up to step 1000 (warm-up), 0.1 up to 5000,
# 0.01 up to 60000, 0.001 up to 80000, then 0.0001.
learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)


class LRCallback(tf.keras.callbacks.Callback):
    """Drive the optimizer LR from ``learning_rate_fn`` and print progress.

    Every ``log_every`` optimizer steps it prints the step, the LR used for
    those steps, the training loss and the elapsed time. It then sets the LR
    for the following steps from the schedule. At each epoch end it prints
    validation top-1/top-5 accuracy (when present in ``logs``) and the
    epoch time.
    """

    log_every = 100

    def __init__(self, starttime):
        super(LRCallback, self).__init__()
        # time.time() reference points for epoch and 100-step timings.
        self.epoch_starttime = starttime
        self.batch_starttime = starttime

    def _set_scheduled_lr(self, step):
        """Set the optimizer LR to the schedule's value at ``step``."""
        tf.keras.backend.set_value(self.model.optimizer.lr, learning_rate_fn(step))

    def on_train_begin(self, logs=None):
        # Align the LR with the schedule immediately. This matters when
        # resuming from a checkpoint whose step count is not a multiple of
        # log_every.
        self._set_scheduled_lr(tf.keras.backend.get_value(self.model.optimizer.iterations))

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        step = tf.keras.backend.get_value(self.model.optimizer.iterations)
        if step % self.log_every == 0:
            elapsed_time = time.time() - self.batch_starttime
            self.batch_starttime = time.time()
            # LR that was in effect for the steps being reported.
            lr = tf.keras.backend.get_value(self.model.optimizer.lr)
            self._set_scheduled_lr(step)
            print("Steps:{}, LR:{:6.4f}, Loss:{:4.2f}, Time:{:4.1f}s"
                  .format(step, lr, logs.get('loss', float('nan')), elapsed_time))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        epoch_elapsed_time = time.time() - self.epoch_starttime
        if 'val_top_1_accuracy' in logs and 'val_top_5_accuracy' in logs:
            print("Epoch:{}, Top-1 Accuracy:{:5.3f}, Top-5 Accuracy:{:5.3f}, Time:{:5.1f}s"
                  .format(epoch, logs['val_top_1_accuracy'], logs['val_top_5_accuracy'], epoch_elapsed_time))
        else:
            # No validation this epoch (e.g. fit() without validation_data).
            print("Epoch:{}, Time:{:5.1f}s".format(epoch, epoch_elapsed_time))

    # NOTE(review): fit()/evaluate() pass the training flag to the model
    # themselves; these global learning-phase switches are kept as in the
    # original but are likely no-ops in TF 2.x.
    def on_epoch_begin(self, epoch, logs=None):
        tf.keras.backend.set_learning_phase(True)
        self.epoch_starttime = time.time()

    def on_test_begin(self, logs=None):
        tf.keras.backend.set_learning_phase(False)


tensorboard_cbk = tf.keras.callbacks.TensorBoard(log_dir='mobilenet/logs')
checkpoint_cbk = tf.keras.callbacks.ModelCheckpoint(filepath='mobilenet/test_{epoch}.h5', verbose=1)

编译模型

对模型进行编译,定义LOSS函数,选择优化器,选择验证指标。

model = mobilenet_model_v2()
# The model emits raw logits (no softmax), hence from_logits=True. Metric
# names yield the 'val_top_1_accuracy' / 'val_top_5_accuracy' log keys
# printed by LRCallback.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name='top_1_accuracy'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top_5_accuracy'),
    ],
)

训练和验证模型

最后就是开始进行训练和验证了。注意callbacks参数中填入我们之前定义的回调函数,它们可以很方便地帮我们调整学习率、打印验证结果和保存模型。之后如果需要加载模型,只需调用tf.keras.models.load_model即可,不需要再对模型进行编译。

train_data = train_input_fn()
val_data = val_input_fn()
# Two epochs of 5000 steps each; LRCallback handles the LR schedule and the
# console output (verbose=0 silences Keras' own progress bar).
_ = model.fit(
    train_data,
    validation_data=val_data,
    epochs=2,
    verbose=0,
    callbacks=[
        LRCallback(time.time()),
        tensorboard_cbk,
        checkpoint_cbk,
    ],
    steps_per_epoch=5000,
)

自定义训练过程(Custom Training Loop)

从以上的代码可见,用Keras model.compile和model.fit的方式可以很方便地对模型进行训练。唯一美中不足的是,我发现这个过程封装得太像黑盒子了,里面的一些细节都被掩盖掉了,如果你需要对训练过程做一些额外的控制,可能不太方便(当然理论上也可以在回调函数中来做)。不过对我来说,最大的问题是模型训练时似乎收敛得太慢了,最后的精度也不是很令人满意,具体的原因我还不是特别确定。为此我特意用Custom Training Loop的方式写了一版进行对比。如果用这种方式,那么以上代码从模型的编译开始,将被以下的代码所替代。可见代码量稍微多一些,不过从我实际训练的效果来看似乎要更好一些:

train_data = train_input_fn()
val_data = val_input_fn()
START_EPOCH = 0      # epoch index used in checkpoint file names (raise when resuming)
NUM_EPOCH = 1        # epochs to run in this invocation
STEPS_EPOCH = 5000   # training steps per epoch (matches steps_per_epoch of the fit() version)
STEPS_OFFSET = 0     # added to optimizer.iterations for the LR schedule and logging
LOG_EVERY = 100      # steps between progress prints / LR updates
with tf.device('/GPU:0'):
    model = mobilenet_model_v2()
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
    # To resume, load a saved checkpoint instead, e.g.:
    # model = tf.keras.models.load_model('model/darknet53_custom_training_12.h5')

    boundaries = [1000, 5000, 65000, 100000]
    values = [0.001, 0.1, 0.01, 0.001, 0.0001]
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)

    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Created once so the tf.function below always updates the same
    # variables; they are reset before each validation pass.
    top1_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    top5_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)

    @tf.function
    def train_step(inputs, labels):
        """One SGD step on cross-entropy + layer regularization losses; returns the total loss."""
        with tf.GradientTape() as tape:
            predictions = model(inputs, training=True)
            pred_loss = loss_object(labels, predictions)
            # L2 penalties registered by the kernel/depthwise regularizers.
            regularization_loss = tf.math.add_n(model.losses) if model.losses else 0.0
            total_loss = pred_loss + regularization_loss
        gradients = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return total_loss

    @tf.function
    def val_step(inputs, labels):
        """Accumulate top-1/top-5 accuracy for one validation batch."""
        logits = model(inputs, training=False)
        # Accuracy and top-k depend only on the ranking of the scores, so
        # raw logits give the same result as softmax probabilities.
        top1_metric.update_state(labels, logits)
        top5_metric.update_state(labels, logits)

    for epoch in range(NUM_EPOCH):
        step = tf.keras.backend.get_value(optimizer.iterations) + STEPS_OFFSET
        # Start from the scheduled LR; otherwise the constructor's 0.01 would
        # be used until the first update, i.e. during the 0.001 warm-up.
        tf.keras.backend.set_value(optimizer.lr, learning_rate_fn(step))
        loss_sum = 0.0
        window_steps = 0
        start_time = time.time()
        for inputs, labels in train_data.take(STEPS_EPOCH):
            loss_sum += train_step(inputs, labels)
            window_steps += 1
            step = tf.keras.backend.get_value(optimizer.iterations) + STEPS_OFFSET
            if step % LOG_EVERY == 0:
                elapsed_time = time.time() - start_time
                # LR that was in effect for the steps being reported.
                lr = tf.keras.backend.get_value(optimizer.lr)
                print("Step:{}, Loss:{:4.2f}, LR:{:5f}, Time:{:3.1f}s".format(
                    step, float(loss_sum) / window_steps, lr, elapsed_time))
                loss_sum = 0.0
                window_steps = 0
                tf.keras.backend.set_value(optimizer.lr, learning_rate_fn(step))
                start_time = time.time()
        # NOTE(review): the file name says darknet53 but the model is MobileNetV2.
        os.makedirs('model', exist_ok=True)
        model.save('model/darknet53_custom_training_' + str(START_EPOCH + epoch) + '.h5')
        top1_metric.reset_states()
        top5_metric.reset_states()
        for inputs, labels in val_data:
            val_step(inputs, labels)
        print("Top-1 Accuracy:%f, Top-5 Accuracy:%f" % (top1_metric.result().numpy(), top5_metric.result().numpy()))

猜你喜欢

转载自blog.csdn.net/gzroy/article/details/102596014