pytorch物体检测(视频检测)

今天参考github上的一个项目完成了视频检测
https://github.com/ayooshkathuria/pytorch-yolo-v3

准备工作

首先把相关文件复制到自己项目的根目录下
还需下载pandas,这里直接用匹配命令
pip install pandas -i http://pypi.douban.com/simple --trusted-host pypi.douban.com

vedio_demo代码修改

 parser = argparse.ArgumentParser(description='YOLO v3 Video Detection Module')
 #default是视频所在目录,视频格式一定要为avi
    parser.add_argument("--video", dest = 'video', help = "Video to run detection upon",default = "video1.avi", type = str)
    parser.add_argument("--dataset", dest = "dataset", help = "Dataset on which the network has been trained", default = "pascal")
    parser.add_argument("--confidence", dest = "confidence", help = "Object Confidence to filter predictions", default = 0.5)
    parser.add_argument("--nms_thresh", dest = "nms_thresh", help = "NMS Threshhold", default = 0.4)
    parser.add_argument("--cfg", dest = 'cfgfile', help = "Config file", default="config/yolov3-custom.cfg",  type = str)
    # parser.add_argument("--weights", dest = 'weightsfile', help = "weightsfile", default="1.weights", type = str)
    #此处我自己训练的模型为.pth文件,所以对后面代码需要修改
    parser.add_argument("--weights_path", dest = 'weights_path',type=str, default="checkpoints/yolov3_ckpt_3.pth", help="path to weights file")
    parser.add_argument("--reso", dest = 'reso', help = "Input resolution of the network. Increase to increase accuracy. Decrease to increase speed", default = "416", type = str)
    return parser.parse_args()

从加载模型开始有部分修改

 print("Loading network.....")
    model = Darknet(args.cfgfile)
    if args.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(args.weights_path)
    else:
        # Load checkpoint weights
        model.load_state_dict(torch.load(args.weights_path))

    model.eval()  # Set in evaluation mode
    print("Network successfully loaded")

最后 166行需要修改一下 classes = load_classes(‘data/classes.names’)

发布了8 篇原创文章 · 获赞 3 · 访问量 391

猜你喜欢

转载自blog.csdn.net/weixin_43570254/article/details/104210782