图像语义分割——利用DeeplabV3+预测单张照片

当训练好DeeplabV3+模型后,生成了.ckpt文件,下一步希望利用模型进行真实的场景预测,通用的做法生成.pb文件。这样做的好处是:1. 将变量转换为常量,减小模型,2. 便于其他语言调用(.pb文件可直接被C/C++/Java/NodeJS等读取)。

运行 export_model.py 生成模型

利用官方代码文件export_model.py生成 frozen_inference_graph.pb 文件,利用该文件进行预测。这里需要注意的是:必须知道模型的input和output,这可以通过查看代码获得。

python export_model.py \
  --checkpoint_path="./checkpoint_1/model.ckpt-518495" \  # 训练得到的ckpt文件
  --export_path="./output_model/frozen_inference_graph.pb" \  # 需要导出的模型名称
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --num_classes=3 \
  --crop_size=1440 \  # 需要预测图片的大小,如果预测的图像大小比该值大,将报错
  --crop_size=1440 \
  --inference_scales=1.0

源码清楚显示:input_name是'ImageTensor',shape是[1, None, None, 3],数据类型是tf.uint8,你也可以在此处更改数据类型,output_name是 'SemanticPredictions'。知道了input和outp,就可以进行预测了。

# export_model.py部分代码

# Input name of the exported model.
_INPUT_NAME = 'ImageTensor'

# Output name of the exported model.
_OUTPUT_NAME = 'SemanticPredictions'


def _create_input_tensors():
  """Creates and prepares input tensors for DeepLab model.

  This method creates a 4-D uint8 image tensor 'ImageTensor' with shape
  [1, None, None, 3]. The actual input tensor name to use during inference is
  'ImageTensor:0'.
  """
  # input_preprocess takes 4-D image tensor as input.
  input_image = tf.placeholder(tf.uint8, [1, None, None, 3], name=_INPUT_NAME)

预测单张图片

利用生成的.pb文件预测新的图片:

1. 读取图片并转换为uint8,shape为[1, None, None, 3]格式;

2. 读取.pb文件,指明输入和输出;

3.求输出,输出的label为0, 1, 2…,所以看上出全黑;

4. 结果后处理,这一步就因人而异了

import tensorflow as tf
from keras.preprocessing.image import load_img, img_to_array

img = load_img(img_path)  # 输入预测图片的url
img = img_to_array(img)  
img = np.expand_dims(img, axis=0).astype(np.uint8)  # uint8是之前导出模型时定义的

# 加载模型
sess = tf.Session()
with open("frozen_inference_graph.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    output = tf.import_graph_def(graph_def, input_map={"ImageTensor:0": img},
                                     return_elements=["SemanticPredictions:0"])
    # input_map 就是指明 输入是什么;
    # return_elements 就是指明输出是什么;两者在前面已介绍

result = sess.run(output) 
print(result[0].shape)  # (1, height, width)

结果展示:

我的工况是一个三分类问题,输入图片1040X868,在个人笔记本上,预测比较慢:40s,部署在服务器上,0.4s.

Tips-注意TF版本:

在预测时,经常出现内存溢出的问题,但模型只有157MB,内存为16GB,一直不得解。原Tensorflow是1.8.0,从Github某地方下载,CUDA用9.1版本。后来下载官方Tensorflow1.8.0,CUDA支持9.0版本,不得不重新安装CUDA,内存溢出问题消失。所以:请从正规渠道下载软件。

猜你喜欢

转载自blog.csdn.net/weixin_41713230/article/details/84146087