import sys
sys.path.append('./')
from yolo.net.yolo_tiny_net import YoloTinyNet
import tensorflow as tf
import cv2
import numpy as np
# The 20 object class names (the PASCAL VOC class set), ordered to match the
# network's class channels: the class index returned by process_predicts is
# used directly as an index into this list.
classes_name = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
def process_predicts(predicts, cell_size=7, num_classes=20,
                     boxes_per_cell=2, image_size=448):
    """Decode the single highest-scoring box from a YOLO output tensor.

    For each grid cell the channels are laid out as
    ``[class probabilities | box confidences | box coordinates]``.
    The class-specific score of every (cell, box, class) triple is the
    product of the box confidence and the cell's class probability; the
    triple with the largest score is decoded into pixel coordinates.

    Args:
        predicts: array of shape
            ``(1, cell_size, cell_size, num_classes + boxes_per_cell * 5)``,
            i.e. ``(1, 7, 7, 30)`` with the defaults. Only the first batch
            element is used.
        cell_size: number of grid cells along each side of the image.
        num_classes: number of object classes.
        boxes_per_cell: number of boxes predicted per grid cell.
        image_size: side length, in pixels, of the square network input.

    Returns:
        ``(xmin, ymin, xmax, ymax, class_num)``: the box corners in pixels
        of the ``image_size x image_size`` input, and the index of the
        winning class.
    """
    preds = np.asarray(predicts)[0]
    conf_end = num_classes + boxes_per_cell

    # Class probabilities, one set per cell (shared by all of its boxes):
    # (cell_size, cell_size, 1, num_classes) so it broadcasts over boxes.
    p_classes = np.reshape(preds[:, :, :num_classes],
                           (cell_size, cell_size, 1, num_classes))
    # Confidence of each box: (cell_size, cell_size, boxes_per_cell, 1).
    confidence = np.reshape(preds[:, :, num_classes:conf_end],
                            (cell_size, cell_size, boxes_per_cell, 1))
    # (x, y, w, h) of each box: (cell_size, cell_size, boxes_per_cell, 4).
    coordinate = np.reshape(preds[:, :, conf_end:],
                            (cell_size, cell_size, boxes_per_cell, 4))

    # Class-specific scores, (cell_size, cell_size, boxes_per_cell, num_classes).
    scores = confidence * p_classes
    # Flat argmax over all scores, converted back to a 4-D index.
    row, col, box, class_num = np.unravel_index(np.argmax(scores), scores.shape)

    xcenter, ycenter, w, h = coordinate[row, col, box, :]
    cell_px = image_size / float(cell_size)
    # (x, y) is the box centre's offset inside its cell, in cell units;
    # adding the cell's column/row index gives its position on the grid,
    # which is then scaled to pixels.
    xcenter = (col + xcenter) * cell_px
    ycenter = (row + ycenter) * cell_px
    # (w, h) are relative to the whole image.
    w = w * image_size
    h = h * image_size

    # Top-left and bottom-right corners of the box.
    xmin = xcenter - w / 2.0
    ymin = ycenter - h / 2.0
    xmax = xmin + w
    ymax = ymin + h
    return xmin, ymin, xmax, ymax, class_num
# Network parameters. weight_decay was missing its value (a syntax error);
# 0.0005 is restored. It presumably only affects the training loss, not
# inference — confirm against YoloTinyNet.
common_params = {'image_size': 448, 'num_classes': 20,
                 'batch_size': 1}
net_params = {'cell_size': 7, 'boxes_per_cell': 2, 'weight_decay': 0.0005}
image_size = common_params['image_size']

# Build YoloTinyNet in test mode.
net = YoloTinyNet(common_params, net_params, test=True)
# Input: one image, (1, 448, 448, 3).
image = tf.placeholder(tf.float32, (1, image_size, image_size, 3))
# Output: (1, 7, 7, 30).
predicts = net.inference(image)

# Read the image and resize it to the network input size.
np_img = cv2.imread('cat.jpg')
if np_img is None:
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail here with a clear message rather than inside resize.
    raise IOError("could not read image 'cat.jpg'")
resized_img = cv2.resize(np_img, (image_size, image_size))
# OpenCV loads images in BGR channel order; convert to RGB for the network
# (presumably the weights were trained on RGB input — confirm).
# resized_img itself stays BGR so OpenCV can draw on it and save it.
np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
np_img = np_img.astype(np.float32)
# Map pixel values from [0, 255] to [-1, 1]: /255 gives [0, 1], *2 - 1
# gives [-1, 1] (presumably the normalization used in training).
np_img = np_img / 255.0 * 2 - 1
np_img = np.reshape(np_img, (1, image_size, image_size, 3))

# The context manager closes the session even if restore/run raises.
with tf.Session() as sess:
    saver = tf.train.Saver(net.trainable_collection)
    saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')
    np_predict = sess.run(predicts, feed_dict={image: np_img})

# Decode the best box and its class, then draw them on the BGR image
# ((0, 0, 255) is red in BGR).
xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict)
class_name = classes_name[class_num]
cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)),
              (0, 0, 255))
cv2.putText(resized_img, class_name, (int(xmin), int(ymin)),
            cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 255))
cv2.imwrite('cat_out.jpg', resized_img)
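# Source: nilboy/yolo code walkthrough, part 2 (yolo代码解读2): /demo.py
# 猜你喜欢 ("You may also like" — blog page navigation)
# Reposted from (转载自) blog.csdn.net/weixin_38900691/article/details/79586269
# 今日推荐 ("Today's picks" — blog page navigation)
# 周排行 ("Weekly ranking" — blog page navigation)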