前言
本篇文章将会带领大家用tornado搭建AI服务,tornado和flask相比,个人更倾向于tornado,详情可以参考如下文章进行了解:
文章1
文章2
一、安装
安装很简单,和flask一样,直接pip即可,前提是你电脑上有python环境
pip install tornado
二、部署服务
这里我们就直接拿YOLOV5为例,官方链接为:YOLOV5,我用的版本可能有点老,但是基本不会影响,只要你下载的模型和你的代码版本对应即可。然后,重新建个py文件(建议)或者将yolov5中的detect.py中的代码全部注释,然后添加以下代码,:
import torch
import cv2,base64
import numpy as np
from utils.general import check_img_size, non_max_suppression,scale_coords
from utils.augmentations import letterbox
import os,time
from pathlib import Path
from utils.plots import Annotator,colors
from models.experimental import attempt_load
import sys,json
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
import tornado.web
import tornado.ioloop
#版面分析
def run(img,img_path,models,save_path,device,model_path,dnn=False):
'''
img: 图片数据
img_path: 图片路径(便于储存结果)
weights: 模型
save_path: 结果保存路径
device: 运行设备
model_path: 模型路径
dnn: 模型加载方式
'''
imgsz=640 #图片大小
#阈值设置
conf_thres=0.25
iou_thres=0.45
old_img=img.copy()
half =device.type != "cpu" #当设备为cuda时,半精度推理
w = str(models[0] if isinstance(models, list) else models) #onnx时会用
model_suffixes=model_path.split(".")[-1] #当前模型的后缀名
stride,names=64,[f'class{
i}'for i in range(1000)]
flag=True #为True表示pt模型,为Falase表示为onnx模型
#用于控制模型后缀,模型后缀不同决定了输入尺寸的不同(onnx为(640,640))
if model_suffixes=="pt":
model=models
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16
flag=False
elif model_suffixes=="onnx":
if dnn:
# check_requirements(('opencv-python>=4.5.4',))
net = cv2.dnn.readNetFromONNX(w)
else:
# check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
flag=True
imgsz=check_img_size(imgsz,s=stride)
#resize_padding
img=letterbox(img,imgsz,stride,auto=flag)[0]
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)
if model_suffixes=="onnx":
img=img.astype("float32")
else:
img=torch.from_numpy(img).to(device)
img=img.half() if half else img.float()
img /=255.0
if len(img.shape)==3:
img=img[None]
if model_suffixes=="pt":
pred=model(img)[0]
elif model_suffixes=="onnx":
if dnn:
net.setInput(img)
pred=torch.tensor(net.forward())
else:
pred=torch.tensor(session.run([session.get_outputs()[0].name],{
session.get_inputs()[0].name:img}))
pred=non_max_suppression(pred,conf_thres,iou_thres,classes=None,agnostic=False,max_det=1000)
for i,det in enumerate(pred):
im0=old_img.copy()
annotator = Annotator(im0, line_width=3, example=str(names))
if len(det):
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
for *xyxy,conf,cls in reversed(det):
c=int(cls)
label = f'{
names[c]} {
conf:.2f}'
annotator.box_label(xyxy,label,color=colors(c,True))
im0 = annotator.result()
# cv2.imwrite(save_path+os.sep+Path(img_path).name,im0)
return im0
class DetectImage(object):
def __init__(self,model_path):
'''
:param model_path: 模型路径
'''
self.model_path=model_path
self.device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = attempt_load(self.model_path, map_location=self.device)
self.save_path=root_path+"/runs/layout_results"
os.makedirs(self.save_path,exist_ok=True)
def det(self,img_path,img_data):
res=run(img_data,img_path,self.model,self.save_path,self.device,self.model_path)
return res
def parse_parameters(byte_data):
try:
json_str = byte_data.decode('utf8')
data = json.loads(json_str)
except Exception as e:
print("[parse_parameters]:error:", e)
return None
return data
def base64_to_cvimage(imgdata):
try:
imgstr = base64.b64decode(imgdata)
nparr = np.fromstring(imgstr,np.uint8)
image = cv2.imdecode(nparr,cv2.IMREAD_COLOR)
except Exception as e:
print("[base64_to_cvimage]:",e)
return None
return image
def cvimage_to_base64(image):
try:
image = cv2.imencode('.jpg', image)[1]
image_base64 = str(base64.b64encode(image))[2:-1]
except Exception as e:
return None
return image_base64
class MainHandler(tornado.web.RequestHandler):
def post(self):
s0=time.time()
byte_data = self.request.body
data = parse_parameters(byte_data)
img_base64,imgpath,message=None,None,"success"
if data==None:
message="Data receiving error!"
else:
img_base64 = data["image"]
imgpath=data["imgpath"]
image=base64_to_cvimage(img_base64)
if image.all()==None:
message="base64_to_cvimage error!"
try :
result = detect.det(imgpath,image)
except Exception as e:
message="Detect error!"
image_base64_=cvimage_to_base64(result)
if image_base64_==None:
message="cvimage2base64 error!"
ocr_status={
"imgpath":imgpath,"status":message,"result":image_base64_}
json_res = json.dumps(ocr_status, ensure_ascii=False)
self.write(json_res)
def main(port):
app=tornado.web.Application([(r'/',MainHandler)],)
app.listen(port)
print("server start!")
tornado.ioloop.IOLoop.instance().start()
if __name__ == '__main__':
root_path=os.path.abspath(os.path.join(os.path.dirname(__file__),"."))
detect=DetectImage(r"")#给定模型路径即可
main(3080)
该有的依赖项装好之后,直接运行该文件即可,若出现server start,则说明服务启动成功
三、调用服务
新建test.py,添加如下代码:
import cv2,base64
from pathlib import Path
import json,requests,glob,os
import numpy as np
def img_base64(img_path):
with open(img_path, "rb") as f:
base64_str = base64.b64encode(f.read())
string = bytes.decode(base64_str)
return string
def base64_to_cv(img_base64):
try:
imgstr = base64.b64decode(img_base64)
nparr = np.fromstring(imgstr, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
except Exception as e:
return None
return image
def test_once(path):
img_data = img_base64(path)
data = {
"image": img_data, "imgpath": path}
url = "http://ip:port" #指定ip:port
res = requests.post(url, data=json.dumps(data))
result=eval(res.text)
if result["status"]=="success":
image_base=result["result"]
img=base64_to_cv(image_base)
if img.all()==None:
print("base64tocvimage Error !")
else:
return img
else:
print(result["status"])
if __name__ == '__main__':
img_dir=r"" #图片文件夹
save_dir="./runs/det_res" #存储路径
os.makedirs(save_dir,exist_ok=True)
imglist=glob.glob(img_dir+os.sep+"*.jpg")
for imgpath in imglist:
res=test_once(imgpath)
imgname=Path(imgpath).name
savepath=os.path.join(save_dir,imgname)
cv2.imwrite(savepath,res)
print("Results saved to ",savepath)
注意:在该文件24行处,test_once函数中的url给定你的ip+port,然后运行该文件即可调用服务
总结
以上就是本篇文章的全部内容,部分内容还不完善,需要大家自己添加自己需要的功能,这也是提升自己的一种方法。顺带告诉小伙伴们,不要因为能跑通开源代码而骄傲(如果跑不通,人家挂上去干嘛),这应该是每个算法工程师的基础,真正值得骄傲的是你自己开源的工程被很多人使用并star。试想下你面试时,面试官问你:“你这些东西都是从网上拉下来,然后train一下就好了,那你的实力提现在哪里?”,难道你会说:“我能快速跑通开源代码???”,个人认为应该始终保持谦虚的态度求上进。 以上均有感而发,如有问题,欢迎评论区交流。