TensorFlow:将ckpt文件固化成pb文件

    本文是将yolo3目标检测框架训练出来的ckpt文件固化成pb文件,主要利用了GitHub上的该项目

    为什么要最终生成pb文件呢?简单来说就是直接通过tf.saver保存行程的ckpt文件其变量数据和图是分开的。我们知道TensorFlow是先画图,然后通过placeholde往图里面喂数据。这种解耦形式存在的方法对以后的迁移学习以及对程序进行微小的改动提供了极大的便利性。但是对于训练好,以后不再改变的话这种存在就不再需要。一方面,ckpt文件储存的数据都是变量,既然我们不再改动,就应当让其变成常量,直接‘烧’到图里面。另一方面,对于线上的模型,我们一般是通过C++或者C语言编写的程序进行调用。所以一般模型最终形式都是应该写成pb文件的形式。

     由于这次的程序直接从GitHub上下载后改动较小就能够运行,也就是自己写了很少一部分程序。因此进行调试的时候还出现了以前根本没有注意的一些小问题,同时发现自己对TensorFlow还需要更加详细的去研读。

     首先对程序进行保存的时候,利用 saver = tf.train.Saver(), saver.save(sess,checkpoint_path,global_step=global_step)对训练的数据进行保存,保存格式为ckpt。但是在恢复的时候一直提示有问题,(其恢复语句为:saver = tf.train.Saver(), saver.restore(sess,ckpt_path),其中,ckpt_path是保存ckpt的文件夹路径)。出现问题的原因我估计是因为我是按照每50个epoch进行保存,而不是让其进行固定次数的batch进行保存,这种固定batch次数的保存系统会自动保存最近5次的ckpt文件(该方法的ckpt_path=tf.train,latest_checkpoint('ckpt/')进行回复)。那么如何将利用epoch的次数进行保存呢(这种保存不是近5次的保存,而是每进行一次保存就会留下当时保存的ckpt,而那种按照batch的会在第n次保存,会将n-5次的删除,n>5)。

  我们可以利用:ckpt = tf.train.get_checkpoint_state(ckpt_path),获取最新的ckptpoint文件,然后利用saver.restore(sess,ckpt.checkpoint_path)进行恢复。当然为了安全起见,应该对ckpt和ckpt.checkpoint_path进行判断是否存在后,再进行恢复语句的调用。即:

    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(model_path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

         另外,想着把程序把每个功能更加细化,即先利用上述程序对模型进行恢复了以后,在saver.restore后进行placeholder以及输出tensor的定义,从而把这些功能写入一个函数进行调用,即:

 input_image_shape = tf.placeholder(dtype = tf.int32, shape = (2,))
    input_image = tf.placeholder(shape = [None, 416, 416, 3], dtype = tf.float32)
    predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
    boxes, scores, classes = predictor.predict(input_image, input_image_shape)

然后利用:

out_boxes, out_scores, out_classes = sess.run(
            [boxes, scores, classes],
            feed_dict={
                input_image: image_data,
                input_image_shape: [image.size[1], image.size[0]]
            })

喂给网络数据,以及输出数据的保留。但是这样网络根本恢复不了。因此还是需要在saver.restore前就把placeholder和输出的tensor定义好。由此也看出来,固化pb文件的方便性。

     对于固化网络,网上有很多的介绍。之所以再介绍,还是由于是用了别人的网络而不是自己的网络遇到的坑。在固化时候我们需要知道输出tensor的名字,而再恢复的时候我们需要知道placeholder的名字。但是,如果网络复杂或者别人的网络命名比较复杂,或者name=,根本就没有自己命名而用的系统自定义的,这样捋起来还是比较费劲的。当时在网上查找的一些方法,像打印整个网络变量的方法(先不管输出的网路名称,甚至随便起一个名字,先固化好pb文件,然后对pb文件进行读取,最后打印变量的名字:

  graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        ['cls_score/cls_score', 'cls_prob']  # We split on comma for convenience
    )
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print ('开始打印节点名字')
    for op in graph.get_operations():
        print(op.name)
    print("%d ops in the final graph." % len(output_graph_def.node))

这样尽然也能打印出来(尽管输出名字是随便命名的)。但是打印出来的根本对不上,其实不可能对的上,因为打印出来的是变量名(也就是训练的数据),不是输出结果。

     那么怎么办?答案简单的让我也很无语。其实,对ckpt进行数据恢复的时候,直接打印输出的tensor名字就可以。比如说在saver以及placeholder定义的时候:output = model.yolo_inference(images, config.num_anchors / 3, config.num_classes, is_training),我们在后面跟一句:print output,从打印出来的信息即可查看。placeholder的查看方法同样如此。

      对网络进行固化:

 代码:


    input_image_shape = tf.placeholder(dtype = tf.int32, shape = (2,))
    input_image = tf.placeholder(shape = [None, 416, 416, 3], dtype = tf.float32)
    predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
    boxes, scores, classes = predictor.predict(input_image, input_image_shape)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(model_path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        ['concat_11','concat_12','concat_13']  # We split on comma for convenience
    )
    # # Finally we serialize and dump the output graph to the filesystem
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())

       由于固化的时候是需要先恢复ckpt网络的,所以还是在restore前写了placeholder和输出tensor的定义(尽管没有用到,但是必须写,不然恢复的时候或报错)。或许可以利用别的方法去读网络,因为每一次保存生成的三个文件有一个是单独保存网络结构的。而这里在restore前对输出tensor进行定义,这样直接调用了已经写好的推理过程。如果只给了四个文件,肯定是能够通过单独保存的网络结构的那个文件对网络进行恢复,然后结合保存变量的那个文件(ckpt),喂给网络数据,即可得到输出tensor。

读取pb文件:

代码:

def pb_detect(image_path, pb_model_path):

    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_index
    image = Image.open(image_path)
    resize_image = letterbox_image(image, (416, 416))
    image_data = np.array(resize_image, dtype = np.float32)
    image_data /= 255.
    image_data = np.expand_dims(image_data, axis = 0)
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_model_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            input_image_tensor = sess.graph.get_tensor_by_name("Placeholder_1:0")
            input_image_tensor_shape = sess.graph.get_tensor_by_name("Placeholder:0")
            # 定义输出的张量名称
            #output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
            boxes = sess.graph.get_tensor_by_name("concat_11:0")
            scores = sess.graph.get_tensor_by_name("concat_12:0")
            classes = sess.graph.get_tensor_by_name("concat_13:0")
            # 读取测试图片
            # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字(需要在名字后面加:0),不是操作节点的名字
            out_boxes, out_scores, out_classes= sess.run([boxes,scores,classes],
                           feed_dict={
                               input_image_tensor: image_data,
                               input_image_tensor_shape: [image.size[1], image.size[0]]
            })

可以看到读取pb文件只需要比恢复ckpt文件容易的多,直接将placeholder的名字获取到,将数据输入恢复的网络,以及读取输出即可。

猜你喜欢

转载自blog.csdn.net/qq26983255/article/details/82846614