机器学习:Fine tune 神经网络 Mobilenet V2

今天介绍如何利用 Tensorflow 和预训练的模型进行 fine-tune,做其他的分类任务,这里以 mobilenet V2 为例。

首先,加载相关模块

import tensorflow as tf
import numpy as np
import glob

# `nets.mobilenet` comes from the TensorFlow models/research/slim repository;
# it must be on the PYTHONPATH for this script to run.
from nets.mobilenet import mobilenet_v2

# Shorthand alias for TF-Slim (TF 1.x contrib API). NOTE(review): the code
# below also calls tf.contrib.slim directly, so this alias is not strictly used.
slim = tf.contrib.slim

定义一个预处理函数

def mobi_parse_fun(x_in, y_in):
    """tf.data map function: load the JPEG at path `x_in`, resize it to
    224x224 and rescale pixels from [0, 255] to [-1, 1] (the MobileNet
    input convention); the label `y_in` passes through unchanged."""
    raw_bytes = tf.read_file(x_in)
    decoded = tf.io.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize_images(decoded, [224, 224])
    scaled = tf.cast(resized, tf.float32) / 127.5 - 1.0
    return scaled, y_in

定义 fine-tune 的参数,比如类别,训练的 batch size,训练周期数等

# Fine-tuning hyper-parameters.
batch_size = 10        # images per training step
Num_class  = 10        # number of classes for the new classification head
learning_rate = 0.001  # SGD step size
train_epoches = 100    # passes over the training set

利用 TensorFlow 中的 dataset API,进行数据流的加载

# Feed image paths (strings) and integer labels through placeholders so the
# same input pipeline can be re-initialised with different data.
X_in = tf.placeholder(tf.string, shape=[None])
Y_in = tf.placeholder(tf.int32, shape=[None])

train_data = tf.data.Dataset.from_tensor_slices((X_in, Y_in))
train_data = train_data.map(mobi_parse_fun)
train_data = train_data.batch(batch_size).repeat()

# Re-initialisable iterator: get_next() ops stay fixed while the iterator can
# later be bound to any dataset with the same types/shapes.
iter_ = tf.data.Iterator.from_structure(train_data.output_types,
                                        train_data.output_shapes)

x_batch, y_batch = iter_.get_next()
y_cast = tf.cast(y_batch, tf.int32)
# BUG FIX: the original read `tf.one_hot(y_cast Num_class)` — a missing comma,
# i.e. a SyntaxError.
y_ohe = tf.one_hot(y_cast, Num_class)

train_init_op = iter_.make_initializer(train_data)

# Build MobileNet V2 (width multiplier 1.4) with a fresh Logits head sized for
# Num_class classes. is_training=False freezes batch-norm statistics and
# disables dropout; set it to True if conv layers are also fine-tuned.
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
    y_score, endpoints = mobilenet_v2.mobilenet(x_batch, num_classes=Num_class, depth_multiplier=1.4)

定义我们要 fine tune 的层数和参数

# BUG FIX: in the original article `model_path` was first defined *after* its
# use below, raising NameError when the snippets are run top-to-bottom. A raw
# string keeps the Windows backslashes from being parsed as escape sequences.
model_path = r'D:\Python_Code\Models\mobilenet_v2_1.4_224\mobilenet_v2_1.4_224.ckpt'

# Restore every pre-trained variable except the classification head, whose
# shape differs (different number of classes).
variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['Mobilenet_V2/Logits'])
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

# The new head is initialised from scratch instead of being restored.
fc_variables = tf.contrib.framework.get_variables('Mobilenet_V2/Logits')
fc_init = tf.variables_initializer(fc_variables)

# Layers to fine-tune when training more than just the head.
Net_name = 'Mobilenet_V2/'
layer_name = ['Logits', 'expanded_conv_16']
train_var_list = []

for name_ in layer_name:
    fc_var = tf.contrib.framework.get_variables(Net_name + name_)
    # BUG FIX: get_variables() returns a *list* of variables; the original
    # append() produced a list of lists, which minimize(var_list=...) rejects.
    train_var_list.extend(fc_var)

定义交叉熵,loss 函数,以及优化器

# BUG FIX: the original referenced the undefined name `y_one_hot`; the
# one-hot label tensor defined above is `y_ohe`.
entropy  = tf.nn.softmax_cross_entropy_with_logits(labels = y_ohe, logits = y_score)
loss = tf.reduce_mean(entropy)

fc_optimizer = tf.train.GradientDescentOptimizer(learning_rate)

# Fine-tune only the final FC (Logits) head.
fc_train_op = fc_optimizer.minimize(loss, var_list=fc_variables)

# Alternative: fine-tune the FC head plus selected conv layers.
# (BUG FIX: the original comment referenced a non-existent `fc8_optimizer`.)
# fc_train_op = fc_optimizer.minimize(loss, var_list=train_var_list)

y_pro = tf.nn.softmax(y_score)
prediction = tf.to_int32(tf.argmax(y_pro, 1))
# BUG FIX: tf.argmax returns int64 while `prediction` is int32; tf.equal
# requires matching dtypes, so cast the label argmax to int32 as well.
correct_prediction = tf.equal(prediction, tf.to_int32(tf.argmax(y_ohe, 1)))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

加载预训练模型

# BUG FIX: use a raw string so '\P', '\M', '\m' in the Windows path are not
# treated as (deprecated) escape sequences.
model_path = r'D:\Python_Code\Models\mobilenet_v2_1.4_224\mobilenet_v2_1.4_224.ckpt'
# Saver over all graph variables, used below to store the fine-tuned model.
saver = tf.train.Saver()

设置 session,开始 fine tune ,并且将fine tune 后的模型进行存储

with tf.Session() as sess:

    # Restore the pre-trained backbone, then initialise the fresh head.
    init_fn(sess)
    sess.run(fc_init)

    # NOTE(review): train_img_list / train_label are never defined in this
    # article — they must be prepared beforehand (e.g. via glob) as a list of
    # image paths and a matching list of integer labels.
    sess.run(train_init_op, feed_dict={X_in: train_img_list, Y_in: train_label})

    train_num = len(train_img_list)
    train_batches = train_num // batch_size

    for epoch_ in range(train_epoches):

        total_loss = 0

        # BUG FIX: the original wrote `for batch_id in train_batches:` —
        # iterating an int raises TypeError; iterate over range() instead.
        for batch_id in range(train_batches):
            loss_, _ = sess.run([loss, fc_train_op])
            total_loss = total_loss + loss_

        print('the total loss is: ', total_loss)

    # Persist the fine-tuned weights once training finishes.
    saver.save(sess, 'ft_model', global_step = 100)

发布了219 篇原创文章 · 获赞 211 · 访问量 113万+

猜你喜欢

转载自blog.csdn.net/shinian1987/article/details/88805496