今天介绍如何利用 TensorFlow 和预训练模型进行 fine-tune,以完成其他的分类任务,这里以 MobileNet V2 为例。
首先,加载相关模块
import tensorflow as tf
import numpy as np
import glob
from nets.mobilenet import mobilenet_v2
slim = tf.contrib.slim
定义一个预处理函数
def mobi_parse_fun(x_in, y_in):
    """Load one JPEG from disk and convert it into a network input tensor.

    Intended to be used with ``tf.data.Dataset.map``.

    Args:
        x_in: String tensor holding the path of a JPEG file.
        y_in: Label belonging to that image; passed through unchanged.

    Returns:
        A tuple ``(img, y_in)`` where ``img`` is a float32 tensor with
        3 channels, resized to 224x224 and rescaled by ``x / 127.5 - 1``.
    """
    raw_bytes = tf.read_file(x_in)
    img_decode = tf.io.decode_jpeg(raw_bytes, channels=3)
    img = tf.image.resize_images(img_decode, [224, 224])
    # Shift 8-bit pixel values from [0, 255] to [-1, 1].
    img = tf.cast(img, tf.float32) / 127.5 - 1.0
    return img, y_in
定义 fine-tune 的参数,比如类别数、训练的 batch size、学习率、训练周期数等
# Fine-tuning hyperparameters.
Num_class = 10         # number of target classes (width of the new Logits layer)
batch_size = 10        # number of images per training step
train_epoches = 100    # number of passes over the training set
learning_rate = 0.001  # step size handed to the gradient-descent optimizer
利用 TensorFlow 中的 dataset API,进行数据流的加载
# Placeholders for the image file paths and their integer class ids; they are
# fed when the iterator is initialised.
X_in = tf.placeholder(tf.string, shape=[None])
Y_in = tf.placeholder(tf.int32, shape=[None])

# Input pipeline: one element per (path, label) pair -> decode/resize ->
# group into batches -> repeat indefinitely.
train_data = tf.data.Dataset.from_tensor_slices((X_in, Y_in))
train_data = train_data.map(mobi_parse_fun)
train_data = train_data.batch(batch_size).repeat()

# The iterator is built from the dataset structure and bound to the dataset
# by running `train_init_op`.
iter_ = tf.data.Iterator.from_structure(train_data.output_types,
                                        train_data.output_shapes)
x_batch, y_batch = iter_.get_next()
y_cast = tf.cast(y_batch, tf.int32)
# The comma between the two arguments was missing, which is a syntax error.
y_ohe = tf.one_hot(y_cast, Num_class)
train_init_op = iter_.make_initializer(train_data)

# Build MobileNet V2 (depth multiplier 1.4) with a Num_class-way output.
# NOTE(review): is_training=False is passed to training_scope even though the
# graph is used for fine-tuning - presumably to keep the pretrained layers in
# inference mode; confirm this is intended.
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
    y_score, endpoints = mobilenet_v2.mobilenet(
        x_batch, num_classes=Num_class, depth_multiplier=1.4)
定义我们要 fine tune 的层数和参数
# Checkpoint to restore from. The name must be bound before the
# assign_from_checkpoint_fn() call below, which takes it as an argument.
# Raw string so the backslashes of the Windows path stay literal.
model_path = r'D:\Python_Code\Models\mobilenet_v2_1.4_224\mobilenet_v2_1.4_224.ckpt'

# mobilenet_v2.mobilenet() is called without a `scope` argument, so its
# variables live under slim's default scope 'MobilenetV2' (no underscore).
# With a non-matching prefix the Logits layer would not be excluded from the
# restore and the list of trainable variables would come back empty.
Net_name = 'MobilenetV2/'
logits_scope = Net_name + 'Logits'

# Restore everything except the classifier, whose size depends on Num_class.
variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=[logits_scope])
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

# The classifier variables are not restored, so they get their own initializer.
fc_variables = tf.contrib.framework.get_variables(logits_scope)
fc_init = tf.variables_initializer(fc_variables)

# Variables of every layer listed here are collected for the optional
# "FC + some conv layers" fine-tuning mode.
layer_name = ['Logits', 'expanded_conv_16']
train_var_list = []
for name_ in layer_name:
    fc_var = tf.contrib.framework.get_variables(Net_name + name_)
    # extend, not append: `var_list` of Optimizer.minimize needs a flat list
    # of variables, not a list of lists.
    train_var_list.extend(fc_var)
定义交叉熵,loss 函数,以及优化器
# Mean softmax cross-entropy between the one-hot labels and the logits.
# The label tensor is named `y_ohe`; `y_one_hot` is not defined anywhere.
entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_ohe, logits=y_score)
loss = tf.reduce_mean(entropy)
fc_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Fine-tune only the final FC (Logits) layer.
fc_train_op = fc_optimizer.minimize(loss, var_list=fc_variables)
# Alternative: fine-tune the FC layer together with some convolutional layers.
# fc_train_op = fc_optimizer.minimize(loss, var_list=train_var_list)

# Accuracy: fraction of samples whose arg-max class equals the label.
y_pro = tf.nn.softmax(y_score)
# Both sides are requested as int32 because tf.equal needs matching dtypes
# and tf.argmax returns int64 by default.
prediction = tf.argmax(y_pro, 1, output_type=tf.int32)
correct_prediction = tf.equal(prediction, tf.argmax(y_ohe, 1, output_type=tf.int32))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
设置预训练模型(checkpoint)的路径,并创建用于保存模型的 Saver
# Path of the pretrained checkpoint. Written as a raw string so the
# backslashes of the Windows path are taken literally instead of relying on
# invalid escape sequences (\P, \M, \m), which newer Python versions warn about.
# NOTE(review): model_path is read where the checkpoint restore function is
# created, so this assignment has to execute before that statement.
model_path = r'D:\Python_Code\Models\mobilenet_v2_1.4_224\mobilenet_v2_1.4_224.ckpt'
# Saver used at the end of training to write the fine-tuned model to disk.
saver = tf.train.Saver()
创建 session,开始 fine-tune,并将 fine-tune 后的模型保存下来
# Run the fine-tuning loop and save the resulting model.
# NOTE(review): `train_img_list` (image paths) and `train_label` (class ids)
# are not defined in this part of the walkthrough; they must be prepared
# beforehand.
with tf.Session() as sess:
    # Restore the pretrained weights, then initialise the new classifier.
    init_fn(sess)
    sess.run(fc_init)
    # Bind the iterator to the training data.
    sess.run(train_init_op, feed_dict={X_in: train_img_list, Y_in: train_label})

    train_num = len(train_img_list)
    # Integer division: a trailing partial batch is not counted.
    train_batches = train_num // batch_size

    for epoch_ in range(train_epoches):
        total_loss = 0
        # range() is required: train_batches is an int and cannot be iterated.
        for batch_id in range(train_batches):
            loss_, _ = sess.run([loss, fc_train_op])
            total_loss = total_loss + loss_
        print('the total loss is: ', total_loss)

    # Tag the checkpoint with the number of epochs actually trained.
    saver.save(sess, 'ft_model', global_step=train_epoches)