ValueError: Tensor Tensor("mrcnn_detection/PyFunc:0", dtype=float32) is not an element of this graph

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Cyril__Li/article/details/79054596

参考:
1. https://github.com/keras-team/keras/issues/2397

今天在把Mask RCNN改成ROS Server来使用的时候遇到了这个错误。我是在这个ROS Subscriber版Mask RCNN的基础上修改的,原代码运行很正常,但改成ROS Service节点之后就出现了这个错误。看了一晚上才意识到,原因正如参考[1]里有人提到的:

I had this problem when doing inference in a different thread than where I loaded my model. Here’s how I fixed the problem:

我改成ROS Server之后,load model是在实例化调用Mask RCNN的那个类时进行的,而inference是在收到request时(在service回调中)才进行,显然二者不在同一个线程里。而在写成subscriber的那个版本中,二者是在同一个线程里进行的:subscriber回调只是把收到的图片不断写入一个类成员变量,并利用Python多线程中的互斥锁确保不会同时读写这个变量;然后run循环再取出当前的图片,让model对其进行inference。代码如下:

class MaskRCNNNode(object):
    """ROS node that runs Mask R-CNN on the newest image received on ``~input``.

    The subscriber callback only stores the latest message in
    ``self._last_msg`` (guarded by ``self._msg_lock``). ``run()`` takes that
    message, runs detection, and publishes a ``Result`` on ``~result``. If
    ``~visualization`` is enabled, it also publishes an annotated image.
    """

    def __init__(self):
        self._cv_bridge = CvBridge()

        config = InferenceConfig()
        config.display()

        self._visualization = rospy.get_param('~visualization', True)

        # Create model object in inference mode.
        self._model = modellib.MaskRCNN(mode="inference", model_dir="",
                                        config=config)
        # Load weights trained on MS-COCO
        model_path = rospy.get_param('~model_path', COCO_MODEL_PATH)
        # Download COCO trained weights from Releases if needed
        if model_path == COCO_MODEL_PATH and not os.path.exists(COCO_MODEL_PATH):
            utils.download_trained_weights(COCO_MODEL_PATH)

        self._model.load_weights(model_path, by_name=True)

        self._class_names = rospy.get_param('~class_names', CLASS_NAMES)

        # Latest not-yet-processed image message, shared between the
        # subscriber callback and run(); access is guarded by _msg_lock.
        self._last_msg = None
        self._msg_lock = threading.Lock()

        # Size the palette from the class list actually in use, not the
        # CLASS_NAMES default, so an overridden ~class_names gets a matching
        # number of colors.
        self._class_colors = visualize.random_colors(len(self._class_names))

        self._publish_rate = rospy.get_param('~publish_rate', 100)

    def run(self):
        """Process incoming images until ROS shuts down.

        Each cycle takes the newest stored message (if any), runs detection
        on it, publishes the result, and optionally publishes a visualization.
        """
        self._result_pub = rospy.Publisher('~result', Result, queue_size=1)
        vis_pub = rospy.Publisher('~visualization', Image, queue_size=1)
        # Held on the instance so the subscriber's lifetime is explicit.
        self._sub = rospy.Subscriber('~input', Image,
                                     self._image_callback, queue_size=1)

        rate = rospy.Rate(self._publish_rate)
        while not rospy.is_shutdown():
            # Non-blocking acquire: if the lock is busy, skip this cycle
            # instead of stalling the loop.
            if not self._msg_lock.acquire(False):
                rate.sleep()
                continue
            # Take the newest message and clear the slot so the same image
            # is not processed twice.
            msg = self._last_msg
            self._last_msg = None
            self._msg_lock.release()

            if msg is not None:
                np_image = self._cv_bridge.imgmsg_to_cv2(msg, 'bgr8')
                # Run detection
                results = self._model.detect([np_image], verbose=0)
                result = results[0]
                result_msg = self._build_result_msg(msg, result)
                self._result_pub.publish(result_msg)

                # Visualize results
                if self._visualization:
                    vis_image = self._visualize(result, np_image)
                    cv_result = np.zeros(shape=vis_image.shape, dtype=np.uint8)
                    cv2.convertScaleAbs(vis_image, cv_result)
                    image_msg = self._cv_bridge.cv2_to_imgmsg(cv_result, 'bgr8')
                    vis_pub.publish(image_msg)

            rate.sleep()

然而我改完之后二者不在同一个线程里了,所以按照参考[1]里的方式:

# Right after loading or constructing your model, save the TensorFlow graph
# (i.e. the graph the freshly built model lives in):

graph = tf.get_default_graph()

# In the other thread (or perhaps in an asynchronous event handler), do:
# re-enter the saved graph so inference runs against the model's graph.
# NOTE: `global graph` only has an effect inside a function body, and
# `(... do inference here ...)` is a placeholder, not runnable Python.

global graph
with graph.as_default():
    (... do inference here ...)

据此对我的代码进行了改动:加载完权重后把当前的graph保存到self.graph这个成员变量里,在service回调中用with self.graph.as_default()包住inference,如下所示:

class MaskRCNNNode(object):
    """ROS service node that runs Mask R-CNN instance segmentation on request.

    The model is built and its weights are loaded in ``__init__``, but
    inference happens in ``handle_instance_segmentation``, which runs in a
    different thread than the constructor. The TensorFlow graph captured
    after loading is therefore re-entered explicitly around ``detect``. This
    avoids "Tensor ... is not an element of this graph".
    """

    def __init__(self):
        self._cv_bridge = CvBridge()

        config = InferenceConfig()
        config.display()

        self._visualization = rospy.get_param('~visualization', True)

        # Create model object in inference mode.
        self._model = modellib.MaskRCNN(mode="inference", model_dir="",
                                        config=config)
        # Load weights trained on MS-COCO. The default is the copy inside the
        # mask_rcnn_ros package; the ~model_path parameter can override it.
        rospack = rospkg.RosPack()
        default_model_path = (rospack.get_path('mask_rcnn_ros')
                              + '/models/mask_rcnn_coco.h5')
        model_path = rospy.get_param('~model_path', default_model_path)

        self._model.load_weights(model_path, by_name=True)
        # Remember the graph the model was built in; the service callback
        # must re-enter it (see handle_instance_segmentation).
        self.graph = tf.get_default_graph()
        self._class_names = rospy.get_param('~class_names', CLASS_NAMES)
        # Size the palette from the class list actually in use, not the
        # CLASS_NAMES default, so an overridden ~class_names gets a matching
        # number of colors.
        self._class_colors = visualize.random_colors(len(self._class_names))

        self.vis_pub = rospy.Publisher('~visualization', Image, queue_size=1)
        self.server = rospy.Service('instance_segmentation', InstanceSegmentation, self.handle_instance_segmentation)
        rospy.loginfo("Waiting for request!")

    def handle_instance_segmentation(self, req):
        """Segment ``req.color_image`` and return the detections.

        Args:
            req: InstanceSegmentation request; ``req.color_image`` is an
                image message, converted here to a BGR8 numpy array.

        Returns:
            InstanceSegmentationResponse whose ``segmentation_result`` is
            produced by ``_build_result_msg``.
        """
        rospy.loginfo("Request received!")

        np_image = self._cv_bridge.imgmsg_to_cv2(req.color_image, "bgr8")
        # Run detection inside the graph saved in __init__. Without this, the
        # callback thread's default graph is used and Keras raises
        # "Tensor ... is not an element of this graph".
        # NOTE(review): if several requests can be served concurrently,
        # detect() may need a lock as well - confirm for this deployment.
        with self.graph.as_default():
            results = self._model.detect([np_image], verbose=0)
        # Was a Python-2-only print statement; log via rospy like the rest.
        rospy.loginfo("got result!")
        result = results[0]
        resp = InstanceSegmentationResponse()
        resp.segmentation_result = self._build_result_msg(req.color_image, result)

        # Visualize results
        if self._visualization:
            vis_image = self._visualize(result, np_image)
            cv_result = np.zeros(shape=vis_image.shape, dtype=np.uint8)
            cv2.convertScaleAbs(vis_image, cv_result)
            image_msg = self._cv_bridge.cv2_to_imgmsg(cv_result, 'bgr8')
            self.vis_pub.publish(image_msg)

        return resp

然后问题就顺利解决了。

猜你喜欢

转载自blog.csdn.net/Cyril__Li/article/details/79054596
今日推荐