TensorFlow Hub

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/baidu_27643275/article/details/83826949

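TensorFlow Hub is a repository of machine learning models. We can use an existing model directly for testing (inference), or we can fine-tune an existing model, which lets us train with only a small amount of data and speeds up training.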

Let's learn TensorFlow Hub by testing the faster_rcnn/openimages_v4/inception_resnet_v2 module as an example.

Inputs

dtype: tf.float32
shape: [1, height, width, 3]
value range: [0.0, 1.0]
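Turning an ordinary 8-bit RGB image into a tensor that matches this input contract only takes a rescale and a batch axis. The sketch below assumes the image is already an RGB `uint8` NumPy array of shape [H, W, 3]; the helper name `to_model_input` is hypothetical and not part of TensorFlow Hub.

```python
import numpy as np


def to_model_input(rgb_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert an RGB uint8 image of shape [H, W, 3]
    into a float32 batch of shape [1, H, W, 3] with values in [0.0, 1.0]."""
    if rgb_uint8.ndim != 3 or rgb_uint8.shape[-1] != 3:
        raise ValueError(f"expected shape [H, W, 3], got {rgb_uint8.shape}")
    scaled = rgb_uint8.astype(np.float32) / 255.0  # [0, 255] -> [0.0, 1.0]
    return scaled[np.newaxis, ...]                  # add a batch dimension of size 1


# Usage with a dummy all-white 4x6 image
dummy = np.full((4, 6, 3), 255, dtype=np.uint8)
batch = to_model_input(dummy)
print(batch.shape, batch.dtype, batch.max())
```

Because the batch axis is fixed at 1, each image is fed to the module separately.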

Outputs

The output dictionary contains:

detection_boxes: a tf.float32 tensor of shape [N, 4] containing bounding box coordinates in the following order: [ymin, xmin, ymax, xmax].
detection_class_entities: a tf.string tensor of shape [N] containing detection class names as Freebase MIDs.
detection_class_names: a tf.string tensor of shape [N] containing human-readable detection class names.
detection_class_labels: a tf.int64 tensor of shape [N] with class indices.
detection_scores: a tf.float32 tensor of shape [N] containing detection scores.
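To draw the boxes on an image, each [ymin, xmin, ymax, xmax] row has to be mapped to pixel coordinates. The sketch below assumes the box coordinates are normalized to [0, 1] relative to the input image size (this is an assumption, not stated above), and returns (left, top, right, bottom) tuples, the corner order that `cv2.rectangle` takes as its two points. `boxes_to_pixels` is a hypothetical helper name.

```python
from typing import List, Sequence, Tuple


def boxes_to_pixels(
    boxes: Sequence[Sequence[float]], height: int, width: int
) -> List[Tuple[int, int, int, int]]:
    """Hypothetical helper: map [ymin, xmin, ymax, xmax] boxes, ASSUMED to be
    normalized to [0, 1], to integer pixel boxes (left, top, right, bottom)."""
    pixel_boxes = []
    for ymin, xmin, ymax, xmax in boxes:
        # y values scale with the image height, x values with the width
        left = round(xmin * width)
        top = round(ymin * height)
        right = round(xmax * width)
        bottom = round(ymax * height)
        pixel_boxes.append((left, top, right, bottom))
    return pixel_boxes


# Usage on one made-up box for a 256x256 input
print(boxes_to_pixels([[0.25, 0.5, 0.75, 1.0]], height=256, width=256))
```

Note that the first coordinate is a y value, so swapping the order (or repeating ymax, as in a common typo) silently produces wrong rectangles.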
import cv2
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

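# Uses the TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session, hub.Module)
with tf.Graph().as_default():
    # Preprocess the input: BGR -> RGB, float32 in [0.0, 1.0], batch dimension of 1
    image_file = './test/car.jpg'
    img = cv2.imread(image_file)  # BGR uint8; returns None if the file is missing
    img = cv2.resize(img, (256, 256))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # feed_dict values must be NumPy arrays, not tf.Tensor objects, so use np.expand_dims
    input_image = np.expand_dims(img, axis=0)  # shape [1, 256, 256, 3]

    # Build the model
    input_placeholder = tf.placeholder(tf.float32, shape=[1, 256, 256, 3])
    # The tfhub.dev handle from the reference link below should also work here
    detector = hub.Module('https://tensorflow.google.cn/hub/modules/google/faster_rcnn/openimages_v4/inception_resnet_v2/1')
    output = detector(input_placeholder, as_dict=True)
    detection_class_names = output['detection_class_names']

    # Create the session inside the graph context; otherwise it would run a
    # different (empty) default graph and the fetch would fail
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())  # the module uses lookup tables for class names

        class_names = sess.run(detection_class_names, feed_dict={input_placeholder: input_image})
        print(class_names)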
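The class names come back from `sess.run` as `bytes` (the output is a tf.string tensor), so printing them shows values like `b'Car'`. If `output['detection_scores']` is fetched as well (for example `sess.run([output['detection_class_names'], output['detection_scores']], feed_dict=...)`), the two can be paired, decoded and filtered. The sketch below uses made-up names and scores in place of real model output; `top_detections` and the 0.5 threshold are hypothetical choices.

```python
from typing import List, Sequence, Tuple


def top_detections(
    class_names: Sequence[bytes], scores: Sequence[float], min_score: float = 0.5
) -> List[Tuple[str, float]]:
    """Hypothetical helper: decode class names, keep detections whose score
    is at least min_score, and sort them by descending score."""
    kept = [
        (name.decode("utf-8"), float(score))
        for name, score in zip(class_names, scores)
        if score >= min_score
    ]
    return sorted(kept, key=lambda item: item[1], reverse=True)


# Synthetic values standing in for the sess.run results
names = [b"Car", b"Wheel", b"Tree"]
scores = [0.91, 0.62, 0.12]
print(top_detections(names, scores))
```

Sorting is cheap and keeps the result well ordered regardless of the order in which the module returns its detections.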

Reference: https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1

