1.数据集标注
使用 SSD 训练自己的模型
参考：https://blog.csdn.net/u014696921/article/details/53353896
2. 用预训练模型进行微调，并根据自己数据集的类别数调整参数
如果仅仅修改程序中的类别数参数，加载预训练模型时就会出错（从头开始训练则不会报错）。这是因为预训练模型的类别数与修改后的类别数不一致，导致某些层（各 blockN_box 下的分类卷积层 conv_cls）的参数张量维度不同，因此无法加载。修正方法有两种：
方法一:
从检查点恢复参数时跳过这些层（即不加载它们的参数）：
tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes',
    # Default: the classification head (conv_cls biases and weights) of every
    # ssd_300_vgg/blockN_box scope, from block11 down to block1. These are the
    # layers whose output size depends on the number of classes, so they must
    # not be restored from a checkpoint trained with a different class count.
    ','.join(
        'ssd_300_vgg/block%d_box/conv_cls/%s' % (block, param)
        for block in range(11, 0, -1)
        for param in ('biases', 'weights')
    ),
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')
方法二:
修改预训练模型（检查点）中这些层的参数，使其张量维度与新的类别数保持一致：
import os
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Substrings identifying the class-prediction variables that must be trimmed.
_CLS_SCOPE_PARTS = ('/conv_cls/biases', '/conv_cls/weights')


def readcheckpoint(model_dir="../checkpoints/ssd_300_vgg.ckpt"):
    """Print the name and shape of every variable stored in a checkpoint.

    Args:
        model_dir: checkpoint path prefix to inspect.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key, shape in var_to_shape_map.items():
        print("tensor_name: ", key)
        # The shape map already has the shape, so avoid loading the whole
        # tensor just to print it. To dump values: print(reader.get_tensor(key)).
        print(tuple(shape))


def _trim_cls_channels(value, old_num_classes, new_num_classes):
    """Keep the first ``new_num_classes`` class channels of every anchor.

    The last axis of ``value`` is treated as ``num_anchors * old_num_classes``
    channels grouped per anchor (anchor-major). This assumes the
    SSD-Tensorflow ``ssd_multibox_layer`` layout, which reshapes the conv_cls
    output to ``[..., num_anchors, num_classes]``; confirm for other networks.
    ``num_anchors`` is inferred from the shape.

    Raises:
        ValueError: if the last dimension is not a multiple of
            ``old_num_classes``, or ``new_num_classes`` exceeds it.
    """
    if new_num_classes > old_num_classes:
        raise ValueError("new_num_classes (%d) exceeds old_num_classes (%d)"
                         % (new_num_classes, old_num_classes))
    channels = value.shape[-1]
    if channels % old_num_classes:
        raise ValueError("last dimension %d is not a multiple of %d classes"
                         % (channels, old_num_classes))
    num_anchors = channels // old_num_classes
    lead = tuple(value.shape[:-1])
    per_anchor = value.reshape(lead + (num_anchors, old_num_classes))
    # Slice per anchor, rather than taking a flat prefix of the last axis,
    # so that class slots from different anchors are not mixed together.
    trimmed = per_anchor[..., :new_num_classes]
    return trimmed.reshape(lead + (num_anchors * new_num_classes,))


def savecheckpoint(ckpt_path="../checkpoints/ssd_300_vgg.ckpt",
                   output_path="./test.ckpt",
                   old_num_classes=21,
                   new_num_classes=15):
    """Write a copy of a checkpoint whose conv_cls layers fit a new class count.

    Every variable in ``ckpt_path`` is loaded and re-created under the same
    name. Variables whose name contains ``/conv_cls/biases`` or
    ``/conv_cls/weights`` are trimmed along their last axis from
    ``num_anchors * old_num_classes`` to ``num_anchors * new_num_classes``
    channels. The result is saved to ``output_path``.

    Args:
        ckpt_path: pretrained checkpoint prefix to read.
        output_path: checkpoint prefix to write.
        old_num_classes: class count the pretrained checkpoint was built with.
            21 presumably matches the VOC-trained ssd_300_vgg checkpoint
            (20 classes + background); verify for other checkpoints.
        new_num_classes: class count of the model being fine-tuned.
            15 is the value the original script hard-coded.

    Raises:
        ValueError: propagated from ``_trim_cls_channels`` on a shape mismatch.
    """
    # Use a fresh graph so repeated calls do not collide with variables left in
    # the default graph (TensorFlow would silently rename them to "<name>_1").
    with tf.Graph().as_default(), tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(ckpt_path):
            var = tf.contrib.framework.load_variable(ckpt_path, var_name)
            # Hook for renaming scopes; currently names are kept unchanged.
            new_name = var_name
            print('Renaming %s to %s.' % (var_name, new_name))
            if any(part in var_name for part in _CLS_SCOPE_PARTS):
                var = _trim_cls_channels(var, old_num_classes, new_num_classes)
            tf.Variable(var, name=new_name)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess, output_path)


if __name__ == "__main__":
    savecheckpoint()
    # To inspect a checkpoint's variable names and shapes instead, call
    # readcheckpoint() here.