Tensorflow读取并使用预训练模型:以inception_v3为例

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/AManFromEarth/article/details/79155926

在使用Tensorflow做读取并finetune的时候,发现在读取官方给的inception_v3预训练模型总是出现各种错误,现记录其正确的读取方式和各种错误做法:
关键代码如下:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3

.....................................................

# 读取网络
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)

....................................................

with tf.Session() as sess:
     # 先初始化所有变量,避免有些变量未读取而产生错误
     init = tf.global_variables_initializer()
     sess.run(init)
     #加载预训练模型
     print('Loading model check point from {:s}'.format(Pretrained_model_dir))

     #这里的exclusions是不需要读取预训练模型中的Logits,因为默认的类别数目是1000,当你的类别数目不是1000的时候,如果还要读取的话,就会报错
     exclusions = ['InceptionV3/Logits',
                   'InceptionV3/AuxLogits']
     #创建一个列表,包含除了exclusions之外所有需要读取的变量
     inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
     #建立一个从预训练模型checkpoint中读取上述列表中的相应变量的参数的函数
     init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)
     #运行该函数
     init_fn(sess)
     print('Loaded.')

其中的…………………………..省略了一些与本文无关的代码。

其中可能会出现的错误如下:
错误1

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [5] rhs shape= [1001]
     [[Node: save_1/Assign_8 = Assign[T=DT_FLOAT, _class=["loc:@InceptionV3/AuxLogits/Conv2d_2b_1x1/biases"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](InceptionV3/AuxLogits/Conv2d_2b_1x1/biases, save_1/RestoreV2_8/_2319)]]

原因:
预训练模型中的类别数class_num=1000,这里输入的class_num=5,当读取完整模型的时候当然会出错。
解决方案:
选择不读取包含类别数的Logits层和AuxLogits层:

exclusions = ['InceptionV3/Logits','InceptionV3/AuxLogits']
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)

错误2
Tensor name “xxxx” not found in checkpoint files

NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6c/Branch_2/Conv2d_0b_7x1/biases" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
     [[Node: save_1/RestoreV2_180 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_180/tensor_names, save_1/RestoreV2_180/shape_and_slices)]]
     [[Node: save_1/RestoreV2_277/_109 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_854_save_1/RestoreV2_277", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

这里的Tensor name可以是所有inception_v3中变量的名字,出现这种情况的各种原因和解决方案是:
1.创建图的时候没有用arg_scope,是这样创建的:

logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)

解决方案:
在这里加上arg_scope,里面调用的是库中自带的inception_v3_arg_scope

with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)

2.在读取checkpoint的时候未初始化所有变量,即未运行

init = tf.global_variables_initializer()
sess.run(init)

这样会导致有一些checkpoint中不存在的变量未被初始化,比如使用Momentum时的每一层的Momentum参数等。

3.使用slim.assign_from_checkpoint_fn()函数时,没有添加ignore_missing_vars=True属性,由于默认ignore_missing_vars=False,所以,当使用非SGD的optimizer的时候(如Momentum、RMSProp等)时,会提示Momentum或者RMSProp的参数在checkpoint中无法找到,如:
使用Momentum时:

NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6e/Branch_2/Conv2d_0c_1x7/BatchNorm/beta/Momentum" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
     [[Node: save_1/RestoreV2_397 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_397/tensor_names, save_1/RestoreV2_397/shape_and_slices)]]
     [[Node: save_1/RestoreV2_122/_2185 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_2096_save_1/RestoreV2_122", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

使用RMSProp时:

NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6b/Branch_1/Conv2d_0b_1x7/BatchNorm/beta/RMSProp" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
     [[Node: save_1/RestoreV2_257 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_257/tensor_names, save_1/RestoreV2_257/shape_and_slices)]]
     [[Node: save_1/Assign_463/_3950 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_3478_save_1/Assign_463", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

解决方法很简单,就是把ignore_missing_vars=True

init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)

注意:一定要在之前的步骤都完成之后才能设成True,不然如果变量名称全部出错的话,会忽视掉checkpoint中所有的变量,从而不读取任何参数。

以上就是我碰见的问题,希望有所帮助。

猜你喜欢

转载自blog.csdn.net/AManFromEarth/article/details/79155926