使用 RealSense 相机在 TensorFlow-Slim 下调用 pb 模型文件进行实时图像分类预测

新建 demo_cam.py 文件,Python 代码如下。
代码中使用的相机为 Intel RealSense D435i。

import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
#from IPython import display
#import pylab
#import PIL
import time
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import matplotlib.font_manager as fm
import pyrealsense2 as rs

# Directory handed to dataset_utils.read_label_file (label map source) and
# the frozen inference graph loaded by the classifier.
dataset_dir = './data/flower_photos'
model_dir = './output_model_pb/frozen_graph.pb'

# RealSense camera: depth (z16) and color (bgr8) streams, both 640x480 @ 30 fps.
# Streaming starts as soon as this module is imported.
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
profile = pipeline.start(config)

# Aligner that maps depth frames onto the color stream's viewpoint.
align_to = rs.stream.color
align = rs.align(align_to)


def get_aligned_images():
    """Read one depth/color frame pair aligned to the color stream.

    Uses the module-level ``pipeline`` and ``align`` objects. A frameset in
    which the aligned depth frame or the color frame is missing is skipped,
    and the next frameset is awaited, so the caller always gets both images.

    Returns:
        tuple: ``(color_image, depth_median_blur, depth_image_3d_org)``:

        * ``color_image`` -- the color frame as a numpy array, in the channel
          order configured for the color stream (``bgr8`` in this script).
        * ``depth_median_blur`` -- 8-bit depth (``depth * 0.03``, saturated),
          inverted so that nearer pixels are brighter. Pixels whose scaled
          depth was 0 stay at 0 instead of becoming 255. The result is
          smoothed with a 5x5 median filter.
        * ``depth_image_3d_org`` -- the non-inverted 8-bit depth image,
          stacked into three identical channels.
    """
    while True:
        frames = pipeline.wait_for_frames()
        aligned_frames = align.process(frames)
        aligned_depth_frame = aligned_frames.get_depth_frame()
        color_frame = aligned_frames.get_color_frame()
        # Either frame may be empty (e.g. right after start-up), and reading
        # data from an empty frame may fail, so wait for a complete pair.
        if aligned_depth_frame and color_frame:
            break

    depth_image = np.asanyarray(aligned_depth_frame.get_data())
    depth_org = cv2.convertScaleAbs(depth_image, alpha=0.03)
    # Invert so near objects are bright. Pixels with no depth (0) would
    # become 255 (the brightest value), so force them back to 0.
    depth_inverted = 255 - depth_org
    depth_inverted[depth_inverted == 255] = 0
    depth_median_blur = cv2.medianBlur(depth_inverted, 5)  # median filter

    color_image = np.asanyarray(color_frame.get_data())
    # Depth is single-channel; replicate it so it matches the 3-channel color image.
    depth_image_3d_org = np.dstack((depth_org, depth_org, depth_org))
    return color_image, depth_median_blur, depth_image_3d_org


#opencv
class TOD(object):
    """Image classifier around a frozen TF-Slim InceptionResnetV2 graph.

    Loads the frozen graph from ``model_dir`` and the label map from
    ``dataset_dir`` (both module-level constants). It keeps a single
    ``tf.Session`` open so that ``classify`` can be called once per camera
    frame.
    """

    def __init__(self):
        self.PATH_TO_CKPT = model_dir
        self.NUM_CLASSES = 5
        self.detection_graph = tf.Graph()
        self.label_map = dataset_utils.read_label_file(dataset_dir)
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            tf.import_graph_def(od_graph_def, name='')
            # Resolve the input/output tensors once instead of on every frame.
            self.input_tensor = self.detection_graph.get_tensor_by_name('input:0')
            self.predictions_tensor = self.detection_graph.get_tensor_by_name(
                'InceptionResnetV2/Logits/Predictions:0')
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # don't grab all GPU memory up front
            self.sess = tf.Session(graph=self.detection_graph, config=config)
        # The display window is created lazily on the first classify() call.
        self.windowNotSet = True
        # Font for the overlay text; loaded on first use by _get_font().
        self._font = None

    def _get_font(self):
        """Return the 15 px DejaVu Sans font, loading it only on first use."""
        font = getattr(self, '_font', None)
        if font is None:
            font = ImageFont.truetype(
                fm.findfont(fm.FontProperties(family='DejaVu Sans')), 15)
            self._font = font
        return font

    def visualization(self, image, str, fill=(0, 0, 255)):
        """Draw the text ``str`` at (10, 10) onto ``image`` in place and return it.

        Args:
            image: uint8 H x W x 3 array. It is overwritten with the annotated
                pixels via ``np.copyto``.
            str: text to draw. (The name shadows the builtin ``str``; it is
                kept for compatibility with keyword callers.)
            fill: text colour as a 3-tuple written straight into the array's
                channels. The default ``(0, 0, 255)`` is red for a BGR frame
                (the camera stream is ``bgr8`` and is shown with cv2.imshow).
                PIL's ``'red'`` would land in the blue channel of a BGR frame.

        Returns:
            The same ``image`` object, now annotated.
        """
        image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
        draw = ImageDraw.Draw(image_pil)
        draw.text((10, 10), str, fill, self._get_font())
        np.copyto(image, np.array(image_pil))
        return image

    def classify(self, image, resized):
        """Run the network on ``resized``, then overlay the top-1 result on ``image`` and show it.

        Args:
            image: frame to annotate and display; it is modified in place.
            resized: preprocessed model input of shape (H, W, 3). A batch
                axis is added here, since the model expects [1, H, W, 3].

        Returns:
            The raw prediction array from the ``Predictions`` tensor.
        """
        image_np_expanded = np.expand_dims(resized, axis=0)
        start_time = time.perf_counter()
        pred = self.sess.run(
            self.predictions_tensor,
            feed_dict={self.input_tensor: image_np_expanded})
        elapsed_time = time.perf_counter() - start_time
        print('inference time cost: {}'.format(elapsed_time))
        # With a batch of 1, the flat argmax/max are the top-1 class and its score.
        text = str(self.label_map[pred.argmax()]) + ":" + str(pred.max())
        img = self.visualization(image, text)
        if self.windowNotSet:
            cv2.namedWindow("classification", cv2.WINDOW_NORMAL)
            self.windowNotSet = False
        cv2.imshow("classification", img)
        return pred



if __name__ == '__main__':
    width = 299
    height = 299
    dim = (width, height)
    try:
        classifier = TOD()
        while True:
            # NOTE: the color stream is configured as bgr8, so `rgb` actually
            # holds BGR pixels.
            rgb, depth, depcol = get_aligned_images()
            image = rgb
            # The Slim model is fed RGB, so convert the BGR camera frame first.
            model_input = cv2.cvtColor(cv2.resize(image, dim), cv2.COLOR_BGR2RGB)
            # Map pixel values from [0, 255] to [-1, 1] (Inception preprocessing).
            # np.float was removed in NumPy 1.24; use an explicit float32.
            resized = model_input.astype(np.float32) / 127.5 - 1.0
            classifier.classify(image, resized)
            k = cv2.waitKey(1) & 0xff
            if k == ord('q') or k == 27:
                break
    finally:
        # Release the camera and windows even on errors or Ctrl-C.
        pipeline.stop()
        cv2.destroyAllWindows()

其中用到的 labels.txt 文件每行格式为"类别编号:类别名",需放在 dataset_dir(即 ./data/flower_photos)目录下,内容如下:

0:daisy
1:dandelion
2:roses
3:sunflowers
4:tulips

运行

python demo_cam.py

猜你喜欢

转载自blog.csdn.net/gaoqing_dream163/article/details/115205714