"""Pointer detection SDK."""

import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from keras.models import model_from_json
import tensorflow as tf
import cv2
import numpy as np
import keras.backend as K

class Pointer:
    """Load a Keras pointer-detection model and compare its predicted degree.

    The model is expected to output ``[[degree, x, y]]`` for a single image;
    only the degree is used by ``compare``.

    NOTE(review): this module sets ``CUDA_VISIBLE_DEVICES='0'`` at import
    time, so TensorFlow sees at most one GPU, addressed as index 0 — confirm
    whether non-zero ``gpu_id`` values are ever expected to work.
    """

    # (width, height) passed to cv2.resize before inference; the original
    # code hard-coded 224x224. Override on a subclass/instance if the model
    # expects a different resolution.
    input_size = (224, 224)

    def __init__(self, gpu_id, gpu_memory_fraction):
        """Store the GPU settings used by ``load_model`` and ``compare``.

        :param gpu_id: TensorFlow GPU device index used by default.
        :param gpu_memory_fraction: value assigned to
            ``per_process_gpu_memory_fraction`` for the session created in
            ``load_model``.
        """
        self.gpu_id = gpu_id
        self.gpu_memory_fraction = gpu_memory_fraction

    def load_model(self, model_dir, model_file, weights_dir):
        """Build a Keras model from a JSON architecture file and load its weights.

        Side effect: creates a new TensorFlow session and installs it as the
        global Keras session.

        :param model_dir: directory containing the architecture JSON file.
        :param model_file: name of the architecture JSON file in ``model_dir``.
        :param weights_dir: passed unchanged to ``model.load_weights``, which
            takes a weights file path (despite this parameter's name).
        :return: the Keras model with weights loaded.
        """
        config = tf.ConfigProto()
        # Grow GPU memory on demand, capped at the configured fraction.
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_fraction
        K.tensorflow_backend.set_session(tf.Session(config=config))

        with tf.device('/GPU:{}'.format(self.gpu_id)):
            model_path = os.path.join(model_dir, model_file)
            # Keras architecture JSON is text; read it as UTF-8 regardless of locale.
            with open(model_path, encoding='utf-8') as f:
                model = model_from_json(f.read())
            model.load_weights(weights_dir)
        return model

    @staticmethod
    def _degree_from_output(output):
        """Extract the predicted degree from a single-image model output.

        Reads element ``[0, 0]`` (the degree of the first and only image).
        A negative degree is shifted by +180; this lands in [0, 180] only if
        the raw prediction lies within [-180, 180].

        :param output: array-like of shape (1, 3): ``[[degree, x, y]]``.
        :return: the (possibly shifted) degree as a Python float.
        """
        degree = float(np.asarray(output)[0, 0])
        if degree < 0:
            degree += 180.0
        return degree

    def compare(self, model, input, compare_value, gpu_id=None):
        """Return whether the model's predicted pointer degree is below a threshold.

        The image is resized to ``self.input_size`` and run through ``model``
        as a batch of one. The model output ``[[degree, x, y]]`` is reduced
        to its degree (negative values shifted by +180) and compared to
        ``compare_value``.

        :param model: Keras model, e.g. as returned by ``load_model``.
        :param input: image array of shape (H, W, 3). NOTE(review): the
            error message says RGB, but images from ``cv2.imread`` are BGR —
            confirm which channel order the model was trained on.
        :param compare_value: threshold in degrees.
        :param gpu_id: GPU index used for inference; ``None`` (the default)
            uses ``self.gpu_id``.
        :return: ``True`` if the degree is strictly less than
            ``compare_value``, otherwise ``False``.
        :raises ValueError: if ``input`` is not a 3-D array with 3 channels
            (this includes ``None``).
        """
        shape = np.shape(input)
        # Check the rank as well as the channel count: a (H, W) grayscale image
        # whose width happens to be 3 would otherwise slip through.
        if len(shape) != 3 or shape[-1] != 3:
            raise ValueError('The input needs to be an RGB image')
        if gpu_id is None:
            gpu_id = self.gpu_id

        resized = cv2.resize(input, self.input_size)
        # Add a batch axis: (1, height, width, 3).
        input_tensor = np.expand_dims(resized, 0)
        with tf.device('/gpu:{}'.format(gpu_id)):
            output = model.predict(input_tensor)

        degree = self._degree_from_output(output)
        return bool(degree < compare_value)

# Source: reposted from blog.csdn.net/lmwang1234/article/details/83500984
# (scraped page residue: "猜你喜欢" ["You may also like"] section header; tag "SDK")