slim 读取并使用预训练模型 inception_v3 迁移学习

转自:https://blog.csdn.net/amanfromearth/article/details/79155926#commentBox

在使用Tensorflow做读取并finetune的时候,发现在读取官方给的inception_v3预训练模型总是出现各种错误,现记录其正确的读取方式和各种错误做法: 

关键代码如下:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3
from research.slim.preprocessing import inception_preprocessing
Pretrained_model_dir = '/Users/apple/tensorflow_model/models-master/research/slim/pre_train/inception_v3.ckpt'

image_size = 299

# 读取网络
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    imgPath = 'test.jpg'
    testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
    testImage = tf.image.decode_jpeg(testImage_string, channels=3)
    processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)
    processed_images = tf.expand_dims(processed_image, 0)
    logits, end_points = inception_v3.inception_v3(processed_images, num_classes=128, is_training=False)
    w1 = tf.Variable(tf.truncated_normal([128, 5], stddev=tf.sqrt(1/128)))
    b1 = tf.Variable(tf.zeros([5]))
    logits = tf.nn.leaky_relu(tf.matmul(logits, w1) + b1)

with tf.Session() as sess:
     # 先初始化所有变量,避免有些变量未读取而产生错误
     init = tf.global_variables_initializer()
     sess.run(init)
     #加载预训练模型
     print('Loading model check point from {:s}'.format(Pretrained_model_dir))

     #这里的exclusions是不需要读取预训练模型中的Logits,因为默认的类别数目是1000,当你的类别数目不是1000的时候,如果还要读取的话,就会报错
     exclusions = ['InceptionV3/Logits',
                   'InceptionV3/AuxLogits']
     #创建一个列表,包含除了exclusions之外所有需要读取的变量
     inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
     #建立一个从预训练模型checkpoint中读取上述列表中的相应变量的参数的函数
     init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)
     #运行该函数
     init_fn(sess)
     print('Loaded.')
     out = sess.run(logits)
     print(out.shape)
     print(out)


其中可能会出现的错误如下: 
错误1

  • 1
  • 2
  • 3

原因: 
预训练模型中的类别数class_num=1000,这里输入的class_num=5,当读取完整模型的时候当然会出错。 
解决方案: 
选择不读取包含类别数的Logits层和AuxLogits层:

  • 1
  • 2

错误2 
Tensor name “xxxx” not found in checkpoint files 

  • 1
  • 2
  • 3
  • 4

这里的Tensor name可以是所有inception_v3中变量的名字,出现这种情况的各种原因和解决方案是: 
1.创建图的时候没有用arg_scope,是这样创建的:

  • 1

解决方案: 
在这里加上arg_scope,里面调用的是库中自带的inception_v3_arg_scope

  • 1
  • 2

2.在读取checkpoint的时候未初始化所有变量,即未运行

  • 1
  • 2

这样会导致有一些checkpoint中不存在的变量未被初始化,比如使用Momentum时的每一层的Momentum参数等。

3.使用slim.assign_from_checkpoint_fn()函数时,没有添加ignore_missing_vars=True属性,由于默认ignore_missing_vars=False,所以,当使用非SGD的optimizer的时候(如Momentum、RMSProp等)时,会提示Momentum或者RMSProp的参数在checkpoint中无法找到,如: 
使用Momentum时:

  • 1
  • 2
  • 3
  • 4

使用RMSProp时:

  • 1
  • 2
  • 3
  • 4

解决方法很简单,就是把ignore_missing_vars=True

  • 1

注意:一定要在之前的步骤都完成之后才能设成True,不然如果变量名称全部出错的话,会忽视掉checkpoint中所有的变量,从而不读取任何参数。

以上就是我碰见的问题,希望有所帮助。

猜你喜欢

转载自blog.csdn.net/qq_41419675/article/details/80887966