通过类来实现多 Session 运行

#xilerihua
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys

#objectlocation
import six.moves.urllib as urllib
import tarfile
import matplotlib
matplotlib.use('Agg')
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
import time

class multi():
    """Hold two frozen TensorFlow graphs, each with its own long-lived Session.

    * Faster R-CNN (``faster_rcnn_inception_v2_coco``): object localisation,
      used by ``get_result('1', image_path)``.
    * Inception-v3 classifier (``output_graph.pb``): product recognition,
      used by ``get_result('2', image_path)``.

    Each graph is parsed once in ``__init__`` and kept open in its own
    ``tf.Session``, so repeated calls reuse the loaded graphs instead of
    rebuilding them. Call :meth:`close` to release both sessions when done.
    """

    # Default locations of the exported models; override via __init__.
    FASTER_RCNN_PB = r'E:/Project/TaoBaoLocation_new/research/object_detection/faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb'
    INCEPTION_PB = r'E:/Project/XiLeRiHuaReg/inception_v3_model/output_graph.pb'
    INCEPTION_LABELS = 'E:/Project/XiLeRiHuaReg/inception_v3_model/output_labels.txt'

    # A classification is rejected when its probability is below this value.
    CLASSIFY_THRESHOLD = 0.6
    # A detection is kept only when its score is strictly above this value.
    DETECT_THRESHOLD = 0.55

    # Classifier label -> Chinese product name. 'liquor' is a label the model
    # can output, but it is reported as "not one of the 17 products".
    HUMAN_KANJI = {
        'baby wipes': '婴儿湿巾',
        'bath towel': '洗澡巾',
        'convenient toothpick box': '便捷牙具盒',
        'dish rack': '沥水架',
        'hooks4': '挂钩粘钩4个装',
        'kitchen towel': '厨房方巾',
        'towel': '毛巾',
        'macaron basin': '马卡龙家用多用盆',
        'multi functional dental box': '多功能牙具盒',
        'paring knife': '削皮刀',
        'pineapple towel set': '菠萝纹毛巾浴巾套装',
        'rubbish bag': '垃圾袋',
        'sponge': '清洁海绵',
        'stainless hook': '不锈钢多用挂钩',
        'storage boxes': '三格储物盒',
        'towel set': '毛巾浴巾套装',
        'usb cable': '数据线',
        'liquor': '劲酒',
    }

    def __init__(self, faster_model_path=FASTER_RCNN_PB,
                 inception_model_path=INCEPTION_PB,
                 labels_path=INCEPTION_LABELS):
        """Load both frozen graphs and open one Session per graph.

        :param faster_model_path: frozen Faster R-CNN ``.pb`` file.
        :param inception_model_path: frozen Inception-v3 ``.pb`` file.
        :param labels_path: text file with one classifier label per line; it
            is read lazily on the first classification.
        """
        self.faster_graph = self._load_frozen_graph(faster_model_path)
        self.faster_sess = tf.Session(graph=self.faster_graph)

        self.inception_graph = self._load_frozen_graph(inception_model_path)
        self.inception_sess = tf.Session(graph=self.inception_graph)

        self.labels_path = labels_path
        # Filled on first use by _get_labels: line index -> label text.
        self._uid_to_human = None

    @staticmethod
    def _load_frozen_graph(pb_path):
        """Parse a frozen GraphDef from ``pb_path`` into a fresh ``tf.Graph``."""
        graph = tf.Graph()
        with graph.as_default():
            graph_def = tf.GraphDef()
            with tf.gfile.GFile(pb_path, 'rb') as fid:
                graph_def.ParseFromString(fid.read())
            # name='' keeps the original tensor names (e.g. 'final_result:0').
            tf.import_graph_def(graph_def, name='')
        return graph

    def close(self):
        """Release both TensorFlow sessions."""
        self.faster_sess.close()
        self.inception_sess.close()

    def get_result(self, type, image_path):  # `type` shadows the builtin; kept for callers
        """Run one model on ``image_path`` and print the result.

        :param type: ``'1'`` for object localisation, ``'2'`` for product
            classification. Any other value does nothing.
        :param image_path: path of the image file to analyse.
        :returns: for ``'1'``, the printed ``(x1, y1, x2, y2)`` pixel region.
            For ``'2'``, the Chinese product name, or ``None`` when the image
            is judged to be outside the 17 products. For any other ``type``,
            ``None``.
        """
        if type == '2':
            return self._classify(image_path)
        if type == '1':
            return self._detect(image_path)
        return None

    # ----------------------------------------------------------- classifier
    def _get_labels(self):
        """Return (and cache) the mapping from class index to label text."""
        if self._uid_to_human is None:
            with tf.gfile.GFile(self.labels_path) as f:
                lines = f.readlines()
            # rstrip('\r\n') also copes with Windows line endings.
            self._uid_to_human = {uid: line.rstrip('\r\n')
                                  for uid, line in enumerate(lines)}
        return self._uid_to_human

    @classmethod
    def _label_to_product(cls, label, score):
        """Map a top-1 label and its probability to a product name.

        Returns ``None`` (meaning "not one of the 17 products") when the score
        is below CLASSIFY_THRESHOLD, when the label is unknown, or when the
        label maps to '劲酒'.
        """
        if score < cls.CLASSIFY_THRESHOLD:
            return None
        name = cls.HUMAN_KANJI.get(label)
        if name is None or name == '劲酒':
            return None
        return name

    def _classify(self, image_path):
        """Classify the image with Inception-v3, print and return the result."""
        uid_to_human = self._get_labels()
        softmax_tensor = self.inception_sess.graph.get_tensor_by_name('final_result:0')

        with tf.gfile.GFile(image_path, 'rb') as f:
            image_data = f.read()
        # The graph decodes the JPEG bytes itself via 'DecodeJpeg/contents:0'.
        predictions = self.inception_sess.run(
            softmax_tensor, {'DecodeJpeg/contents:0': image_data})
        # atleast_1d keeps indexing valid even if squeeze yields a scalar.
        predictions = np.atleast_1d(np.squeeze(predictions))

        node_id = int(np.argmax(predictions))  # top-1 class
        product = self._label_to_product(uid_to_human.get(node_id, ''),
                                         predictions[node_id])
        if product is None:
            print('不在17个范围之内')
        else:
            print(product)
        return product

    # ------------------------------------------------------------- detector
    @staticmethod
    def _to_pixel_boxes(boxes, scores, width, height, thres):
        """Keep boxes whose score is strictly above ``thres``, scaled to pixels.

        ``boxes`` rows are read as normalised ``(ymin, xmin, ymax, xmax)``, and
        any leading batch dimensions are flattened away. ``scores`` holds the
        matching confidences. Returns a list of integer ``[x1, y1, x2, y2]``
        pixel boxes in input order. The number of detections is not fixed.
        """
        boxes = np.asarray(boxes).reshape(-1, 4)
        scores = np.asarray(scores).reshape(-1)
        pixel_boxes = []
        for (ymin, xmin, ymax, xmax), score in zip(boxes.tolist(), scores.tolist()):
            if score > thres:
                pixel_boxes.append([int(xmin * width), int(ymin * height),
                                    int(xmax * width), int(ymax * height)])
        return pixel_boxes

    @staticmethod
    def _choose_region(pixel_boxes, width, height):
        """Pick the ``(x1, y1, x2, y2)`` region to report for an image.

        With no boxes, the region is the image with 1/10 of its width and
        height trimmed from each side. Otherwise it is the box whose
        (integer-truncated) centre lies closest to the image centre. Ties go
        to the earliest box.
        """
        if not pixel_boxes:
            dx, dy = width // 10, height // 10
            return (dx, dy, width - dx, height - dy)

        cx, cy = width / 2, height / 2

        def squared_distance(box):
            # Squared distance orders boxes the same way as Euclidean distance.
            bx = int((box[0] + box[2]) / 2)
            by = int((box[1] + box[3]) / 2)
            return (cx - bx) ** 2 + (cy - by) ** 2

        x1, y1, x2, y2 = min(pixel_boxes, key=squared_distance)
        return (x1, y1, x2, y2)

    def _detect(self, image_path):
        """Locate the most central confident object, then print and return it."""
        with Image.open(image_path) as image:
            # convert('RGB') also accepts grayscale / RGBA / palette images.
            image_np = np.asarray(image.convert('RGB'), dtype=np.uint8)
        height, width = image_np.shape[:2]

        graph = self.faster_graph
        image_tensor = graph.get_tensor_by_name('image_tensor:0')
        boxes_t = graph.get_tensor_by_name('detection_boxes:0')
        scores_t = graph.get_tensor_by_name('detection_scores:0')

        boxes, scores = self.faster_sess.run(
            [boxes_t, scores_t],
            feed_dict={image_tensor: image_np[np.newaxis, ...]})  # add batch dim

        pixel_boxes = self._to_pixel_boxes(boxes, scores, width, height,
                                           self.DETECT_THRESHOLD)
        region = self._choose_region(pixel_boxes, width, height)
        print(*region)
        return region



# Demo: build the models once, then time repeated inference on one image.
if __name__ == '__main__':
    # Use a separate name so the class `multi` is not shadowed by its instance.
    runner = multi()

    for _ in range(5):
        start_t = time.perf_counter()
        runner.get_result("1", "1.jpg")  # Faster R-CNN localisation
        end_t = time.perf_counter()
        print('t1:', end_t - start_t)
        runner.get_result("2", "1.jpg")  # Inception-v3 classification
        start_t3 = time.perf_counter()
        print('t2:', start_t3 - end_t)

  

猜你喜欢

转载自www.cnblogs.com/liutianrui1/p/10914121.html
今日推荐