Today I got video object detection working by following a project on GitHub:
https://github.com/ayooshkathuria/pytorch-yolo-v3
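Preparation
First, copy the relevant files from that repository into the root directory of your own project.
pandas is also required; I installed it directly through the Douban PyPI mirror: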
pip install pandas -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
Changes to video_demo.py
def arg_parse():
    parser = argparse.ArgumentParser(description='YOLO v3 Video Detection Module')
    # default is the path of the video to run detection on; in my setup the video had to be in .avi format
    parser.add_argument("--video", dest = 'video', help = "Video to run detection upon", default = "video1.avi", type = str)
    parser.add_argument("--dataset", dest = "dataset", help = "Dataset on which the network has been trained", default = "pascal")
    parser.add_argument("--confidence", dest = "confidence", help = "Object Confidence to filter predictions", default = 0.5)
    parser.add_argument("--nms_thresh", dest = "nms_thresh", help = "NMS Threshold", default = 0.4)
    parser.add_argument("--cfg", dest = 'cfgfile', help = "Config file", default = "config/yolov3-custom.cfg", type = str)
    # parser.add_argument("--weights", dest = 'weightsfile', help = "weightsfile", default = "1.weights", type = str)
    # The model I trained myself is saved as a .pth file, so the model-loading code further down also has to change
    parser.add_argument("--weights_path", dest = 'weights_path', type = str, default = "checkpoints/yolov3_ckpt_3.pth", help = "path to weights file")
    parser.add_argument("--reso", dest = 'reso', help = "Input resolution of the network. Increase to increase accuracy. Decrease to increase speed", default = "416", type = str)
    return parser.parse_args()
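A side note on the parser above: `--confidence` and `--nms_thresh` declare no `type`, so a value passed on the command line arrives as a string while the default stays a float, and any code that uses them has to convert with `float()`. Below is a minimal sketch of an equivalent parser that lets argparse do the conversion and accepts an explicit argument list, which makes it easy to test. `parse_video_args` is a hypothetical name, not a function from the repository, and `--dataset` is left out for brevity.

```python
import argparse
from typing import List, Optional


def parse_video_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Hypothetical helper mirroring the customised flags of video_demo.py
    parser = argparse.ArgumentParser(description="YOLO v3 Video Detection Module")
    parser.add_argument("--video", dest="video", type=str, default="video1.avi",
                        help="Video to run detection upon")
    # type=float converts command-line strings; the float defaults are kept as-is
    parser.add_argument("--confidence", dest="confidence", type=float, default=0.5,
                        help="Object confidence to filter predictions")
    parser.add_argument("--nms_thresh", dest="nms_thresh", type=float, default=0.4,
                        help="NMS threshold")
    parser.add_argument("--cfg", dest="cfgfile", type=str,
                        default="config/yolov3-custom.cfg", help="Config file")
    parser.add_argument("--weights_path", dest="weights_path", type=str,
                        default="checkpoints/yolov3_ckpt_3.pth",
                        help="Path to a .weights or .pth file")
    parser.add_argument("--reso", dest="reso", type=str, default="416",
                        help="Input resolution of the network")
    # argv=None makes argparse fall back to sys.argv[1:]
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Example: override the video and the confidence threshold
    args = parse_video_args(["--video", "test.avi", "--confidence", "0.6"])
    print(args.video, args.confidence, args.weights_path)
```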
Starting from the model-loading step, part of the code is changed as follows:
print("Loading network.....")
model = Darknet(args.cfgfile)
if args.weights_path.endswith(".weights"):
    # Load darknet weights
    model.load_darknet_weights(args.weights_path)
else:
    # Load checkpoint weights
    model.load_state_dict(torch.load(args.weights_path))
model.eval()  # Set in evaluation mode
print("Network successfully loaded")
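The `endswith(".weights")` check is what lets the same script accept both an original Darknet `.weights` file and a PyTorch `.pth` checkpoint. Below is a minimal sketch that wraps this dispatch in a function. `load_yolo_weights` and its `checkpoint_loader` parameter are hypothetical. It assumes the model exposes `load_darknet_weights` and `load_state_dict`, as the `Darknet` class above does, and that the `.pth` file holds a plain state dict, as the `load_state_dict(torch.load(...))` call implies. The `map_location="cpu"` argument is an addition not present in the original code. It lets a checkpoint saved on a GPU load on a CPU-only machine.

```python
from pathlib import Path
from typing import Any, Callable, Optional


def load_yolo_weights(model: Any, weights_path: str,
                      checkpoint_loader: Optional[Callable[[str], Any]] = None) -> str:
    """Hypothetical helper: choose the loading route from the file extension.

    Returns "darknet" or "checkpoint" so the caller can log which route was taken.
    """
    if Path(weights_path).suffix.lower() == ".weights":
        # Original Darknet binary format, parsed by the model itself
        model.load_darknet_weights(weights_path)
        return "darknet"

    if checkpoint_loader is None:
        import torch  # imported lazily so the dispatch logic works without torch installed

        def checkpoint_loader(path: str) -> Any:
            # map_location="cpu" remaps GPU tensors; load_state_dict copies them onto the model's device
            return torch.load(path, map_location="cpu")

    # Anything that is not a .weights file is treated as a saved state dict (.pth)
    model.load_state_dict(checkpoint_loader(weights_path))
    return "checkpoint"
```

Calling `load_yolo_weights(model, args.weights_path)` followed by `model.eval()` would take the place of the if/else block.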
Finally, line 166 needs to be changed to classes = load_classes('data/classes.names'), so that the class names are read from your own dataset's names file.
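The names file lists one class per line. The number of names should equal the `classes=` value in every `[yolo]` section of the cfg file (here `config/yolov3-custom.cfg`). Otherwise the labels drawn on the video can be wrong, or a predicted class index can fall outside the list. Below is a minimal sketch of a sanity check you could run before starting detection. The function names are hypothetical and not part of the repository. It assumes a plain-text cfg with lines such as `classes=2`.

```python
import re
from pathlib import Path
from typing import List


def load_class_names(names_path: str) -> List[str]:
    # One class name per line; blank lines (e.g. a trailing newline) are skipped
    lines = Path(names_path).read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def cfg_class_counts(cfg_path: str) -> List[int]:
    # Collect every "classes=N" entry; commented-out lines starting with '#' do not match
    counts = []
    for line in Path(cfg_path).read_text(encoding="utf-8").splitlines():
        match = re.match(r"\s*classes\s*=\s*(\d+)", line)
        if match:
            counts.append(int(match.group(1)))
    return counts


def check_names_match_cfg(names_path: str, cfg_path: str) -> List[str]:
    # Raise early if the names file and the network config disagree
    names = load_class_names(names_path)
    for n in cfg_class_counts(cfg_path):
        if n != len(names):
            raise ValueError(
                f"{cfg_path} declares classes={n}, but {names_path} lists {len(names)} names"
            )
    return names
```

For example, `check_names_match_cfg("data/classes.names", "config/yolov3-custom.cfg")` returns the class list when the two files agree.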