nillboy/yolo代码解读2:/demo.py

import sys

sys.path.append('./')

from yolo.net.yolo_tiny_net import YoloTinyNet 
import tensorflow as tf 
import cv2
import numpy as np

classes_name =  ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train","tvmonitor"]


def process_predicts(predicts):
  # predicts (1,7,7,30) 根据网络出来的结果输出box的坐标及所属类别
  # 类别 (1,7,7,20)
  p_classes = predicts[0, :, :, 0:20]
  # 7×7×2个boxes的confidence (1,7,7,2)
  C = predicts[0, :, :, 20:22]
  # 7×7×2个boxes的位置表示 (1,7,7,8)
  coordinate = predicts[0, :, :, 22:]
  p_classes = np.reshape(p_classes, (7, 7, 1, 20))
  C = np.reshape(C, (7, 7, 2, 1))
  # (7,7,2,20)
  P = C * p_classes
  # 平铺P,在7×7×2×20这些数中找到最大的数的索引
  index = np.argmax(P)
  # 将一维数组的下标转换为多维数组的下标
  index = np.unravel_index(index, P.shape)
  # 最大值对应的类别
  class_num = index[3]
  coordinate = np.reshape(coordinate, (7, 7, 2, 4))
  # 最大值对应的box位置表示(4,)
  max_coordinate = coordinate[index[0], index[1], index[2], :]
  # (x,y)是box的中心和网格单元左上角那个点也就是坐标相关联,即在7×7的网格上,这个中心在几又几分之几的位置上
  xcenter = max_coordinate[0]
  ycenter = max_coordinate[1]
  w = max_coordinate[2]
  h = max_coordinate[3]
  # box在原图像上的位置
  xcenter = (index[1] + xcenter) * (448/7.0)
  ycenter = (index[0] + ycenter) * (448/7.0)
  w = w * 448
  h = h * 448
  # box的左上角右下角的坐标
  xmin = xcenter - w/2.0
  ymin = ycenter - h/2.0
  xmax = xmin + w
  ymax = ymin + h
  return xmin, ymin, xmax, ymax, class_num
# 参数
common_params = {'image_size': 448, 'num_classes': 20, 
                'batch_size':1}
net_params = {'cell_size': 7, 'boxes_per_cell':2, 'weight_decay':  }
# YoloTinyNet类,test
net = YoloTinyNet(common_params, net_params, test=True)
# 输入图像 (1,448,448,3)
image = tf.placeholder(tf.float32, (1, 448, 448, 3))
# 输出 (1,7,7,30)
predicts = net.inference(image)
# 会话
sess = tf.Session()
# 读取图像,输入数据,resize
np_img = cv2.imread('cat.jpg')
resized_img = cv2.resize(np_img, (448, 448))
# 交换B通道R通道???目的???
np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
np_img = np_img.astype(np.float32)
# ???为什么乘以2???
np_img = np_img / 255.0 * 2 - 1
np_img = np.reshape(np_img, (1, 448, 448, 3))

saver = tf.train.Saver(net.trainable_collection)
saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')
np_predict = sess.run(predicts, feed_dict={image: np_img})
# box的坐标及所属类别
xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict)
class_name = classes_name[class_num]
cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255))
cv2.putText(resized_img, class_name, (int(xmin), int(ymin)), 2, 1.5, (0, 0, 255))
cv2.imwrite('cat_out.jpg', resized_img)
sess.close()

猜你喜欢

转载自blog.csdn.net/weixin_38900691/article/details/79586269