Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/baidu_27643275/article/details/83826949
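TensorFlow Hub is a repository of machine learning models. We can use an existing model as-is for testing, or fine-tune it. Fine-tuning lets us train with only a small amount of data and makes training faster.
Below, we learn TensorFlow Hub by running inference with the faster_rcnn/openimages_v4/inception_resnet_v2 module as an example.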
Inputs
dtype: tf.float32
shape: [1, height, width, 3]
values: in the range [0.0, 1.0]
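The input requirements above (float32, a leading batch dimension of 1, RGB values in [0.0, 1.0]) can be met with a few NumPy operations. This is a minimal sketch that assumes the image was loaded by OpenCV, i.e. a uint8 array in BGR channel order; the helper name `to_model_input` is hypothetical and not part of TensorFlow Hub.

```python
import numpy as np

def to_model_input(img_bgr):
    """Hypothetical helper: HxWx3 uint8 BGR image -> [1, H, W, 3] float32 RGB batch in [0, 1]."""
    img = np.asarray(img_bgr)
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError("expected an HxWx3 image, got shape %s" % (img.shape,))
    # Reverse the channel axis: BGR (OpenCV order) -> RGB.
    rgb = img[..., ::-1]
    # Scale pixel values from 0..255 to 0.0..1.0 as float32.
    scaled = rgb.astype(np.float32) / 255.0
    # Add the leading batch dimension of size 1.
    return scaled[np.newaxis, ...]

# Usage with a synthetic 2x3 image instead of a real photo.
dummy = np.zeros((2, 3, 3), dtype=np.uint8)
dummy[..., 0] = 255  # fill the blue channel (index 0 in BGR order)
batch = to_model_input(dummy)
print(batch.shape, batch.dtype)
```

After the conversion the blue values sit in the last channel, as the model's RGB layout expects.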
Outputs
The output dictionary contains:
detection_boxes: a tf.float32 tensor of shape [N, 4] containing bounding box coordinates in the following order: [ymin, xmin, ymax, xmax].
detection_class_entities: a tf.string tensor of shape [N] containing detection class names as Freebase MIDs.
detection_class_names: a tf.string tensor of shape [N] containing human-readable detection class names.
detection_class_labels: a tf.int64 tensor of shape [N] with class indices.
detection_scores: a tf.float32 tensor of shape [N] containing detection scores.
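Once `sess.run` has fetched this dictionary as NumPy arrays, it usually needs some post-processing. The following is a minimal sketch under stated assumptions: box coordinates are taken to be normalized to [0, 1] relative to the input image (so they are scaled by its height and width to get pixels), string outputs arrive as Python `bytes`, and the function name `filter_detections` and the 0.5 score threshold are illustrative choices, not part of the module.

```python
import numpy as np

def filter_detections(result, img_height, img_width, min_score=0.5):
    """Hypothetical helper: keep detections scoring >= min_score as (name, score, pixel box)."""
    boxes = np.asarray(result["detection_boxes"], dtype=np.float32)   # [N, 4]: ymin, xmin, ymax, xmax
    scores = np.asarray(result["detection_scores"], dtype=np.float32)  # [N]
    names = result["detection_class_names"]                            # [N], bytes values
    # Scale factors ordered to match [ymin, xmin, ymax, xmax] (assumes normalized boxes).
    scale = np.array([img_height, img_width, img_height, img_width], dtype=np.float32)
    kept = []
    for box, score, name in zip(boxes, scores, names):
        if score < min_score:
            continue
        # tf.string values fetched by sess.run are Python bytes; decode for display.
        label = name.decode("utf-8") if isinstance(name, bytes) else str(name)
        ymin, xmin, ymax, xmax = (box * scale).round().astype(int)
        kept.append((label, float(score), (int(xmin), int(ymin), int(xmax), int(ymax))))
    # Highest-scoring detections first.
    kept.sort(key=lambda d: d[1], reverse=True)
    return kept

# Fake output with the same keys, dtypes and shapes as the module's dictionary.
fake = {
    "detection_boxes": np.array([[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]], dtype=np.float32),
    "detection_scores": np.array([0.3, 0.9], dtype=np.float32),
    "detection_class_names": np.array([b"Wheel", b"Car"], dtype=object),
}
print(filter_detections(fake, img_height=256, img_width=256))
```

The returned pixel boxes are reordered to (xmin, ymin, xmax, ymax) because that is the corner order drawing functions such as `cv2.rectangle` take.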
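import cv2
import numpy as np
import tensorflow as tf  # TF 1.x: hub.Module, tf.placeholder and tf.Session are TF1 APIs
import tensorflow_hub as hub

with tf.Graph().as_default():
    # Input preprocessing: read, resize, convert BGR -> RGB, scale to [0, 1]
    image_file = './test/car.jpg'
    img = cv2.imread(image_file)
    if img is None:
        raise IOError('Cannot read image: ' + image_file)
    img = cv2.resize(img, (256, 256))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Add the batch dimension with NumPy: feed_dict accepts NumPy arrays, not tf.Tensor objects
    img_batch = np.expand_dims(img, axis=0)

    # Build the model
    input_placeholder = tf.placeholder(tf.float32, shape=[1, 256, 256, 3])
    detector = hub.Module('https://tensorflow.google.cn/hub/modules/google/faster_rcnn/openimages_v4/inception_resnet_v2/1')
    output = detector(input_placeholder, as_dict=True)
    detection_class_names = output['detection_class_names']

    with tf.Session() as sess:
        # Initialize the variables and lookup tables created by the module
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        class_names = sess.run(detection_class_names, feed_dict={input_placeholder: img_batch})
        print(class_names)  # NumPy array of bytes, one class name per detection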
Reference: https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1