github地址:https://github.com/18150167970/pytorch-yolov3-modifiy
预测时,只需加载训练好的模型,然后对数据进行预测即可.
1.读模型
# Set up model: build the Darknet network from its config file, load the
# pretrained weights, move it to the GPU and switch to evaluation mode.
model = Darknet(opt.model_config_path)
model.load_weights(opt.weight_path)
model = model.cuda()  # nn.Module.cuda() moves parameters in place and returns self
model.eval()
2.读数据
# Validation images, iterated in a fixed order (no shuffling) so that each
# batch's img_paths line up with its predictions.
dataset = Datasets(opt.valid)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=opt.batch_size,
    shuffle=False,
)
3.预测并绘图
for batch_i, (img_paths, input_imgs, targets) in enumerate(dataloader):
    # Convert the batch to the tensor type used by the model (`Tensor` is
    # defined elsewhere -- presumably a CUDA float tensor type).
    input_imgs = input_imgs.type(Tensor)

    # Time only the forward pass + NMS, so data loading and drawing are not
    # counted as "inference time".
    start_time = time.time()
    with torch.no_grad():
        detections = model(input_imgs)
        # NOTE(review): the positional arguments are presumably
        # (num_classes=80, conf_thres=0.8, nms_thres=0.4) -- confirm against
        # non_max_suppression's signature.
        detections = non_max_suppression(detections, 80, 0.8, 0.4)
    if input_imgs.is_cuda:
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
    inference_time = datetime.timedelta(seconds=time.time() - start_time)
    print('\t+ Batch %d, Inference Time: %s' % (batch_i, inference_time))

    # Draw and save the predictions and the ground-truth boxes for this batch.
    draw_predict(input_imgs, detections, img_paths)
    draw_predict(input_imgs, targets, img_paths, 'gt')
4.绘图(本文保存的是带填充(padding)的图片;如果需要保存去除填充后的图片,代码中也有实现,具体请看github源码)
读图像
for ii, detection in enumerate(detections):
    # Entries may be None (no boxes for that image); skip them. enumerate keeps
    # `ii` aligned with the batch index even when an entry is skipped.
    if detection is None:
        continue
    img_path = img_paths[ii]
    # The batch tensor may live on the GPU, so copy it to the CPU first. It is
    # assumed to be CHW with values in [0, 1] -- TODO confirm against the Dataset.
    image = imgs[ii].cpu().numpy() * 255
    # CHW -> HWC, the layout OpenCV expects.
    image = np.transpose(image, (1, 2, 0))
    # np.transpose returns a non-contiguous view, which OpenCV drawing functions
    # reject. Build a contiguous uint8 copy in memory (rounded and saturated to
    # [0, 255]) instead of round-tripping through a temporary, lossy JPEG on disk.
    image = np.ascontiguousarray(np.clip(np.rint(image), 0, 255).astype(np.uint8))
    # ndarray shape is (rows, cols, channels), i.e. (height, width, channels).
    H, W, _ = image.shape
读取bounding box坐标:由于真值和预测结果的数据格式不同,所以需要分别处理.
if types == 'pred':
    # Predicted rows are (x1, y1, x2, y2, c, score, label); the corners are used
    # directly as drawing coordinates below.
    bboxs = detection[:, 0:4]
    name = detection[:, -1].cpu().numpy()
    score = detection[:, 5]
    x1 = detection[:, 0]
    y1 = detection[:, 1]
    x2 = detection[:, 2]
    y2 = detection[:, 3]
else:
    # Ground-truth rows are (label, x, y, w, h), where (x, y) is the box centre.
    # The coordinates are scaled by the image width/height here, so they are
    # assumed to be normalised -- convert them to pixel corners (x1, y1, x2, y2).
    # Size the buffer from the input instead of assuming exactly 50 rows.
    target = np.zeros((len(detection), 5))
    target[:, 0] = detection[:, 0]
    target[:, 1] = (detection[:, 1] - detection[:, 3] / 2.0) * W
    target[:, 2] = (detection[:, 2] - detection[:, 4] / 2.0) * H
    target[:, 3] = (detection[:, 1] + detection[:, 3] / 2.0) * W
    target[:, 4] = (detection[:, 2] + detection[:, 4] / 2.0) * H
    # The ground truth is zero-padded; drop the all-zero rows. The result is
    # always an ndarray (possibly empty), never None.
    detection = target[np.sum(target, axis=1) > 0]
    bboxs = detection[:, 1:5]
    name = detection[:, 0]
    x1 = detection[:, 1]
    y1 = detection[:, 2]
    x2 = detection[:, 3]
    y2 = detection[:, 4]
绘图和保存
# Draw every box with its class name onto `image`.
for bx1, by1, bx2, by2, label in zip(x1, y1, x2, y2, name):
    xmin = int(round(float(bx1)))
    ymin = int(round(float(by1)))
    xmax = int(round(float(bx2)))
    ymax = int(round(float(by2)))
    # Skip degenerate boxes whose corners are inverted.
    if xmax < xmin or ymax < ymin:
        continue
    # Box outline: red in OpenCV's BGR colour order, 1 px thick.
    cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), 1)
    # Class name 10 px above the top-left corner; font scale follows image height.
    cv2.putText(image, classname[int(label)], (xmin, ymin - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 1e-3 * image.shape[0], (0, 0, 255), 1)
# Save once per image, after all boxes are drawn. img_path[-29:-4] takes a
# fixed-length tail of the path without its 4-char extension -- this assumes a
# specific path layout; NOTE(review): consider os.path.basename instead.
cv2.imwrite(opt.output_path +
            img_path[-29:-4] + '_' + types + '.jpg', image)