Pytorch从0开始实现YOLO V3指南 part5——设计输入和输出的流程

本节翻译自:https://blog.paperspace.com/how-to-implement-a-yolo-v3-object-detector-from-scratch-in-pytorch-part-5/

在前一节最后,我们实现了一个将网络输出转换为检测预测的函数。现在我们已经有了一个检测器了,剩下的就是创建输入和输出的流程。

必要条件:

1.此系列教程的Part1到Part4。

2.Pytorch的基本知识,包括如何使用nn.Module,nn.Sequential,torch.nn.parameter类构建常规的结构

3.OpenCV的基础知识

EDIT: 如果你在2018年3月30日之前访问过这篇文章,我们将任意大小的图片调整为Darknet的输入大小的方法就是resize。然而在原始的实现中,调整图像的大小时,需要保持长宽比不变,并填充遗漏的部分。例如,如果我们将1900 x 1280的图像调整为416 x 415,那么调整后的图像应该是这样的。

对于输入处理的差异导致早期实现的性能略低于原始实现。现在这篇文章已经进行了更新,遵循了原始实现中调整大小的方法。

在这一部分中,我们将构建检测器的输入和输出管道。这包括从磁盘读取图像,进行预测,使用预测结果在图像上绘制边界框,然后将它们保存到磁盘。我们还将介绍如何让检测器实时工作在一个摄像机或视频中。我们将介绍一些命令行标志,以允许对网络的各种超参数进行一些实验。那么让我们开始吧!

 注意:这部分需要安装opencv3。

 创建detector.py文件,在顶部添加必要的导入。

from __future__ import division
import time
import torch 
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import cv2 
from util import *
import argparse
import os 
import os.path as osp
from darknet import Darknet
import pickle as pkl
import pandas as pd
import random

创建命令行参数:

因为detector.py是我们要执行来运行检测器的文件,所以最好有可以传递给它的命令行参数。我使用了python的ArgParse模块来实现这一点。

def arg_parse():
    """
    Parse arguements to the detect module
    
    """
    
    parser = argparse.ArgumentParser(description='YOLO v3 Detection Module')
   
    parser.add_argument("--images", dest = 'images', help = 
                        "Image / Directory containing images to perform detection upon",
                        default = "imgs", type = str)
    parser.add_argument("--det", dest = 'det', help = 
                        "Image / Directory to store detections to",
                        default = "det", type = str)
    parser.add_argument("--bs", dest = "bs", help = "Batch size", default = 1)
    parser.add_argument("--confidence", dest = "confidence", help = "Object Confidence to filter predictions", default = 0.5)
    parser.add_argument("--nms_thresh", dest = "nms_thresh", help = "NMS Threshhold", default = 0.4)
    parser.add_argument("--cfg", dest = 'cfgfile', help = 
                        "Config file",
                        default = "cfg/yolov3.cfg", type = str)
    parser.add_argument("--weights", dest = 'weightsfile', help = 
                        "weightsfile",
                        default = "yolov3.weights", type = str)
    parser.add_argument("--reso", dest = 'reso', help = 
                        "Input resolution of the network. Increase to increase accuracy. Decrease to increase speed",
                        default = "416", type = str)
    
    return parser.parse_args()
    
args = arg_parse()
images = args.images
batch_size = int(args.bs)
confidence = float(args.confidence)
nms_thesh = float(args.nms_thresh)
start = 0
CUDA = torch.cuda.is_available()

其中,重要的标志是images(用于指定图像的输入图像或目录)、det(保存检测到的目录)、reso(输入图像的分辨率,可用于速度-精度权衡)、cfg(可更改的配置文件)和weightfile。

加载网络:

这里下载cocoa .names文件,该文件包含COCO数据集中对象的名称。在检测器目录中创建文件夹数据。同样如果你在linux上工作,可以输入。

mkdir data
cd data
wget https://raw.githubusercontent.com/ayooshkathuria/YOLO_v3_tutorial_from_scratch/master/data/coco.name

然后,我们在程序中加载该文件。

num_classes = 80    #For COCO
classes = load_classes("data/coco.names")

load_classes是在util.py中定义的一个函数,它返回一个字典,该字典将每个类的索引映射到它的名称字符串。

def load_classes(namesfile):
    fp = open(namesfile, "r")
    names = fp.read().split("\n")[:-1]
    return names

初始化网络并加载权重。

#Set up the neural network
print("Loading network.....")
model = Darknet(args.cfgfile)
model.load_weights(args.weightsfile)
print("Network successfully loaded")

model.net_info["height"] = args.reso
inp_dim = int(model.net_info["height"])
assert inp_dim % 32 == 0 
assert inp_dim > 32

#If there's a GPU availible, put the model on GPU
if CUDA:
    model.cuda()

#Set the model in evaluation mode
model.eval()

读入输入图片:

从磁盘或目录中读取图像。将图像的路径存储在一个名为imlist的列表中。

read_dir = time.time()
#Detection phase
try:
    imlist = [osp.join(osp.realpath('.'), images, img) for img in os.listdir(images)]
except NotADirectoryError:
    imlist = []
    imlist.append(osp.join(osp.realpath('.'), images))
except FileNotFoundError:
    print ("No file or directory with the name {}".format(images))
    exit()

read_dir是一个用于度量时间的检查点。(大概就是判断每步花了多长时间)

如果保存检测的目录(由det标志定义)不存在,则创建它。

if not os.path.exists(args.det):
    os.makedirs(args.det)

我们将使用OpenCV来加载图像

load_batch = time.time()
loaded_ims = [cv2.imread(x) for x in imlist]

load_batch也是一个时间检查点

OpenCV以numpy数组的形式加载图像,以BGR作为颜色通道的顺序。PyTorch的图像输入格式为(批量x通道x高x宽),通道顺序为RGB。因此,我们在util.py中编写函数prep_image来将numpy数组转换为PyTorch的输入格式。

在编写这个函数之前,我们必须编写一个函数letterbox_image来调整图像的大小,保持长宽比一致,并用(128,128,128)填充未填充的区域

def letterbox_image(img, inp_dim):
    '''resize image with unchanged aspect ratio using padding'''
    img_w, img_h = img.shape[1], img.shape[0]
    w, h = inp_dim
    new_w = int(img_w * min(w/img_w, h/img_h))
    new_h = int(img_h * min(w/img_w, h/img_h))
    resized_image = cv2.resize(img, (new_w,new_h), interpolation = cv2.INTER_CUBIC)
    
    canvas = np.full((inp_dim[1], inp_dim[0], 3), 128)

    canvas[(h-new_h)//2:(h-new_h)//2 + new_h,(w-new_w)//2:(w-new_w)//2 + new_w,  :] = resized_image
    
    return canvas

猜你喜欢

转载自www.cnblogs.com/Thinker-pcw/p/10902042.html