使用tflite预测数据

二话不说,直接上代码!

import numpy as np
import tensorflow as tf
import cv2 as cv

# Load TFLite model and allocate tensors.
tflite_model = tf.contrib.lite.Interpreter(model_path="/home/zhang/anaconda3/model_pb.tflite")
tflite_model.allocate_tensors()

# Get input and output tensors.
input_details = tflite_model.get_input_details()
output_details = tflite_model.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']

image=cv.imread("image.jpg")

input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)#输入随机数

tflite_model.set_tensor(input_details[0]['index'], input_data)

tflite_model.invoke()
output_data = tflite_model.get_tensor(output_details[0]['index'])
print("out_class")
print(output_data)

我这个tflite文件输入tensor大小是(1,28,28,3),故

input_shape = input_details[0]['shape']

input_shape的大小是(1,28,28,3)。

input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)#输入随机数

输出为(1,28,28,3)的ndarray格式。这句话的意思是虚构一副图像。

但是我们自己在做预测的时候,从输入图像大小是(28,28,3),直接输进函数(如下)运行程序

tflite_model.set_tensor(input_details[0]['index'], input_data)

会报错:ValueError: Failed to set tensor,这样的错误。

原因是差了一个维度,所以,三个维度(28,28,3)如何变成带有batch的四个维度(1,28,28,3)?即增加一个图像数量维度

解决方法如下:

假如你的图像数据类型是

<class 'numpy.ndarray'>

可以赋值让图像赋值给虚构图像,input_data[0]=image,

这样就可以构建一个(1,28,28,3)的数据,就可以通过程序运行。

扫描二维码关注公众号,回复: 2651724 查看本文章

猜你喜欢

转载自blog.csdn.net/weixin_36621501/article/details/81536653