TensorFlow入门教程之6:模型微调(Finetune)

版权声明:人工智能/机器学习/深度学习交流QQ群:811460433 , 微信公众号:程序员深度学习 https://blog.csdn.net/sinat_24143931/article/details/86482374
人工智能/机器学习/深度学习交流QQ群:116270156
也可以扫一扫下面二维码加入微信群,如果二维码失效,可以添加博主个人微信,拉你进群

模型微调主要包括以下两个过程:

  1. 构建图结构,截取目标张量,添加新层;
  2. 加载目标张量权重,训练新层,全局微调;

1. 构建图结构,截取目标张量,添加新层

这个步骤中的图结构,是通过“先构建图结构,再加载权重”方法得到的mobilenet计算图结构。

tf.reset_default_graph()
# 构建计算图
images = tf.placeholder(tf.float32,(None,224,224,3))
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
    logits, endpoints = mobilenet_v2.mobilenet(images,depth_multiplier=1.4)

# 获取目标张量,添加新层
with tf.variable_scope("finetune_layers"):
    # 获取目标张量,取出mobilenet中指定层的张量
    mobilenet_tensor = tf.get_default_graph().get_tensor_by_name("MobilenetV2/expanded_conv_14/output:0")
    # 将张量向新层传递
    x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_1")(mobilenet_tensor)
    # 观察新层权重是否更新 tf.summary.histogram("conv2d_1",x)
    x = tf.nn.relu(x,name="relu_1")
    x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_2")(x)
    x = tf.layers.Conv2D(10,3,name="conv2d_3")(x)
    predictions = tf.reshape(x, (-1,10))

计算图结构:
finetune网络结构

红色框内的是Mobilenet网络结构,由上至下的第二个紫色节点为"MobilenetV2/expanded_conv_14/output"节点,可以看出直接与finetune_layers相接。

2. 加载目标权重,训练新层

# one-hot编码
def to_categorical(data, nums):
    return np.eye(nums)[data]
# 随机生成数据
x_train = np.random.random(size=(141,224,224,3))
y_train = to_categorical(label_fake,10)

# 训练条件配置
## label占位符
y_label = tf.placeholder(tf.int32, (None,10))
## 收集变量作用域finetune_layers内的变量,仅更新添加层的权重
train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="finetune_layers")
## 定义loss
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_label,logits=predictions)
## 定义优化方法,用var_list指定需要更新的权重,此时仅更新train_var权重
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss,var_list=train_var)
## 观察新层权重是否更新 
tf.summary.histogram("mobilenet_conv8",tf.get_default_graph().get_tensor_by_name('MobilenetV2/expanded_conv_8/depthwise/depthwise_weights:0'))
tf.summary.histogram("mobilenet_conv9",tf.get_default_graph().get_tensor_by_name('MobilenetV2/expanded_conv_9/depthwise/depthwise_weights:0'))

## 合并所有summary
merge_all = tf.summary.merge_all()

## 设定迭代次数和批量大学
epochs = 10
batch_size = 16

# 获取指定变量列表var_list的函数
def get_var_list(target_tensor=None):
    '''获取指定变量列表var_list的函数'''
    if target_tensor==None:
        target_tensor = r"MobilenetV2/expanded_conv_14/output:0"
    target = target_tensor.split("/")[1]
    all_list = []
    all_var = []
    # 遍历所有变量,node.name得到变量名称
    # 不使用tf.trainable_variables(),因为batchnorm的moving_mean/variance不属于可训练变量
    for var in tf.global_variables():
        if var != []:
            all_list.append(var.name)
            all_var.append(var)
    try:
        all_list = list(map(lambda x:x.split("/")[1],all_list))
        # 查找对应变量作用域的索引
        ind = all_list[::-1].index(target)
        ind = len(all_list) -  ind - 1
        print(ind)
        del all_list
        return all_var[:ind+1]
    except:
        print("target_tensor is not exist!")

# 目标张量名称,要获取一个需要从文件中加载权重的变量列表var_list
target_tensor = "MobilenetV2/expanded_conv_14/output:0"
var_list = get_var_list(target_tensor)
saver = tf.train.Saver(var_list=var_list)

# 加载文件内的权重,并训练新层
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    writer = tf.summary.FileWriter(r"./logs", sess.graph)
    ## 初始化参数:从文件加载权重 train_var使用初始化函数
    sess.run(tf.variables_initializer(var_list=train_var))
    saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))
    
    for i in range(2000):
        start = (i*batch_size) % x_train.shape[0]
        end = min(start+batch_size, x_train.shape[0])
        _, merge, losses = sess.run([train_step,merge_all,loss],\
                             feed_dict={images:x_train[start:end],\
                                        y_label:y_train[start:end]})
        if i%100==0:
           writer.add_summary(merge, i)

权重初始化注意事项:

  1. 先利用全局初始化tf.global_variables_initializer(),再利用saver.restore顺序不能错,否则加载的权重会被重新初始化 。
sess.run(tf.global_variables_initializer())
saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))
  1. 先利用saver.restore从模型中加载权重,再利用tf.variable_initializaer()初始化指定的var_list,顺序可以调换.
saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))
sess.run(tf.variables_initializer(var_list=train_var))
  1. 前两种方法会对无用的节点也进行变量初始化,并且需要提前进行saver.restore操作,也就是说需要两次save.restore操作,才能保证finetune过程不会报错。现在可以通过筛选出需要从文件中加载权重的所有变量组成var_list,然后定义saver=tf.train.Saver(var_list),选择性的加载变量.

参考:[Tensorflow] 网络局部restore 以及 网络局部训练

3. 最后

欢迎大家扫一扫下面二维码加入微信交流群,如果二维码失效,可以添加博主个人微信,拉你进群

传送门----->Tensorflow系列教程

猜你喜欢

转载自blog.csdn.net/sinat_24143931/article/details/86482374
今日推荐