使用Flask部署YoloV3-PyTorch

使用Flask部署YoloV3-PyTorch

一、项目简介

这个项目是一个web对象检测的小demo,使用Yolov3(PyTorch) 和 Flask 在 Web 端进行对象检测,涉及目标检测、Flask和Html
Yolov3 来自 Ultralytics,你可以可以使用他们的项目来训练一个满足自己的模型

二. 项目整体框架与代码

项目地址:https://github.com/BonesCat/Yolov3_flask
在这里插入图片描述
主要是在Yolov3-Ultralytics的代码上进行修改,具体如下:

  • 1.将原detect.py修改为detect_for_flask.py,为Flask提供一个接
  • 2.所有上传的文件将被时间重命名并保存到“upload_files”文件夹
  • 3.检测到的图像将被保存到“输出”文件夹中

三、快速开始

  • 按照 ult-yolov3 中requirement要求配置环境,自行安装Flask,注意都需要在一个evn环境中进行安装与配置
  • 下载或训练一个模型,将“.weights/.pt”文件放到weights文件夹,配置正确的cfg,其他配置可以在opt上设置.本项目可以使用原始yolov3提供的官方权重,只需设置对应cfg即可。
  • 启动serve.py,然后在网站上输入“http://127.0.0.1:2222/upload”,上传图片,即可得到结果和检测信息。

四、 核心部分代码与简单讲解

  • Server.py
import time
import os
# 导入flask库中的Flask类与request对象
from flask import Flask, request, flash, redirect, render_template, jsonify
from datetime import timedelta

# 导入模型相关函数
from detect_for_flask import *


app = Flask(__name__)

# 设置上传文件的保存位置
UPLOAD_FOLDER = 'upload_files'
ALLOWED_EXTENSIONS = {
    
    'pdf', 'png', 'jpg', 'jpeg', 'gif'}

# 配置路径到app
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# 设置静态文件缓存过期时间
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=5) # timedalte 是datetime中的一个对象,该对象表示两个时间的差值

print("SEND_FILE_MAX_AGE_DEFAULT:", app.config['SEND_FILE_MAX_AGE_DEFAULT'])

# 预先初始化模型
model_inited, opt = init_model()

# 处理文件名的有效性
def allow_filename(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/upload', methods=['GET', 'POST']) # 添加路由

def upload():
    if request.method == 'POST':
        # 如果上传的file不是在files
        if 'file' not in request.files:
            # Flask 消息闪现
            flash('not file part!')
            # 重新显示当前url页面
            return  redirect(request.url)

        '''
        Flask 框架中的 request 对象保存了一次HTTP请求的一切信息。
        files 记录了请求上传的文件
        '''
        f = request.files['file']

        # 处理空文件
        if f.filename == '':
            flash("Nothing file upload")
            return redirect(request.url)

        # 文件非空,且格式满足
        if f and allow_filename(f.filename):
            # 保存上传文件至本地
            # 按照格式获取当前时间,从命名文件
            now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
            file_extension = f.filename.split('.')[-1]
            new_filename = now + '.' + file_extension
            file_path = './' + app.config['UPLOAD_FOLDER'] + '/' + new_filename
            f.save(file_path)

            # 进行预测,并显示图片
            img, obj_infos = detect(model_inited, opt, file_path)
            return render_template('upload_ok.html', det_result = obj_infos)
    return render_template('upload.html')

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=2222)

detect_for_flask.py

import argparse
from sys import platform

from models import *  # set ONNX_EXPORT in models.py
from utils.datasets import *
from utils.utils import *

'''
根据原始YoloV3中的detect.py,重写了检测函数,来适配flask
'''


def init_model():
    '''
    模型参数初始化
    :无输入参数
    :return: 完成初始的模型 和 opt设置
    '''
    # paraments config
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='*.cfg path')
    parser.add_argument('--names', type=str, default='data/coco.names', help='*.names path')
    parser.add_argument('--weights', type=str, default='weights/yolov3.weights', help='weights path')
    parser.add_argument('--output', type=str, default='output', help='output folder')  # detect result will be saved here
    parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
    parser.add_argument('--device', default='cpu', help='device id (i.e. 0 or 0,1) or cpu')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    opt = parser.parse_args()
    print(opt)

    # init paraments
    out, weights, save_txt = opt.output, opt.weights, opt.save_txt

    # Initialize
    device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device)
    if not os.path.exists(out):
        os.makedirs(out)  # make new output folder

    # Initialize model
    model = Darknet(opt.cfg, opt.img_size)

    # Load weights
    attempt_download(weights)
    if weights.endswith('.pt'):  # pytorch format
        model.load_state_dict(torch.load(weights, map_location=device)['model'])
    else:  # darknet format
        load_darknet_weights(model, weights)

    return model, opt

def detect(model, opt, image_path):
    '''
    :param model: 完成初始化的模型
    :param opt: opt参数
    :param image_path:传入的图片地址 
    :param save_img: 是否保存图片
    :return: 完成定位后的结果
    '''
    # Eval mode
    model.to(opt.device).eval()
    # Save img?
    save_img = True

    # Process the upload image

    # read img
    img0 = cv2.imread(image_path)  # BGR
    assert img0 is not None, 'Image Not Found ' + image_path

    # Padded resize
    img = letterbox(img0, new_shape=opt.img_size)[0]

    # Convert
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
    img = np.ascontiguousarray(img)

    # Get names and colors
    names = load_classes(opt.names)
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()

    img = torch.from_numpy(img).to(opt.device)
    img = img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    with torch.no_grad():
        # Inference
        t1 = torch_utils.time_synchronized()
        pred = model(img)[0]
        t2 = torch_utils.time_synchronized()
        # print("pred:", pred)

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            # 这是检测出来的所有object的,检测结果是一个二维list
            # 每一行存放的是一个obj的左上,右下四个坐标,置信度,类别
            # print("det", det)

            p, s = image_path, ''

            save_path = str(Path(opt.output) / Path(p).name)
            s += '%gx%g ' % img.shape[2:]  # print string
            # 若检测出了对象,则list不为空
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string
                # 设置字典,写入每个目标数据
                obj_info_list = []
                # 遍历二维det中的每行,从而对每一个obj进行处理
                # Write results
                for *xyxy, conf, cls in det:
                    if opt.save_txt:  # Write to file
                        with open(save_path + '.txt', 'a') as file:
                            file.write(('%g ' * 6 + '\n') % (*xyxy, cls, conf))

                    if save_img:  # Add bbox to image
                        label = '%s %.2f' % (names[int(cls)], conf)
                        plot_one_box(xyxy, img0, label=label, color=colors[int(cls)]) # 参数xyxy中包含着bbox的坐标
                    # 记录单个目标的坐标,类别,置信度
                    sig_obj_info =('%s %g %g %g %g %g' ) % (names[int(cls)], *xyxy, conf)
                    print("sig_obj_info:", sig_obj_info)
                    obj_info_list.append(sig_obj_info)

            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))


            # Save results (image with detections)
            if save_img:
                # 两次保存
                # 1.永久保存检测结果,存入output文件夹
                cv2.imwrite(save_path, img0)
                # 2.暂存文件,用于显示
                cv2.imwrite('./static/temp.jpg', img0)

    print('Done. (%.3fs)' % (time.time() - t0))
    return img0, obj_info_list


if __name__ == '__main__':
    img_path = './data/samples/timg1.jpg'
    model_inited, opt = init_model()
    result,obj_infos = detect(model = model_inited, opt = opt, image_path=img_path)
    print(obj_infos)

五、项目截图

在这里插入图片描述
在这里插入图片描述

六、 参考与致谢

https://github.com/ultralytics/yolov3
https://blog.csdn.net/rain2211/article/details/105965313/

注:只是简单demo,没有写检测不到时候的处理,自己处理一下报错。

猜你喜欢

转载自blog.csdn.net/wrh975373911/article/details/118419365
今日推荐