基于web端的深度学习模型部署

1.1 web服务与技术框架

     下面以ResNet50预训练模型为例,旨在展示一个轻量级的深度学习模型部署,写一个较为简单的图像分类的REST API。主要技术框架为Keras+Flask+Redis。其中Keras作为模型框架、Flask作为后端Web框架、Redis则是方便以键值形式存储图像的数据库。各主要package版本:

tensorflow 1.14

keras 2.2.4

flask 1.1.1

redis 3.3.8

     先简单说一下Web服务,一个Web应用的本质无非就是客户端发送一个HTTP请求,然后服务器收到请求后生成一个HTML文档作为响应返回给客户端的过程。在部署深度学习模型时,大多时候我们不需要搞一个前端页面出来,一般是以REST API的形式提供给开发调用。那么什么是API呢?很简单,如果一个URL返回的不是HTML,而是机器能直接解析的数据,这样的一个URL就可以看作是一个API。

先开启Redis服务:

redis-server

1.2 服务配置

     定义一些配置参数:

IMAGE_WIDTH = 224

IMAGE_HEIGHT = 224

IMAGE_CHANS = 3

IMAGE_DTYPE = "float32"

IMAGE_QUEUE = "image_queue"

BATCH_SIZE = 32

SERVER_SLEEP = 0.25

CLIENT_SLEEP = 0.25

     指定输入图像大小、类型、batch_size大小以及Redis图像队列名称。

     然后创建Flask对象实例,建立Redis数据库连接:

app = flask.Flask(__name__)

db = redis.StrictRedis(host="localhost", port=6379, db=0)

model = None

     因为图像数据作为numpy数组不能直接存储到Redis中,所以图像存入到数据库之前需要将其序列化编码,从数据库取出时再将其反序列化解码即可。分别定义编码和解码函数:

def base64_encode_image(img):

    return base64.b64encode(img).decode("utf-8")


def base64_decode_image(img, dtype, shape):

    if sys.version_info.major == 3:

        img = bytes(img, encoding="utf-8")

    img = np.frombuffer(base64.decodebytes(img), dtype=dtype)

    img = img.reshape(shape)

    return img

另外待预测图像还需要进行简单的预处理,定义预处理函数如下:

def prepare_image(image, target):

# if the image mode is not RGB, convert it

if image.mode != "RGB":

image = image.convert("RGB")

# resize the input image and preprocess it

image = image.resize(target)

image = img_to_array(image)

# expand image as one batch like shape (1, c, w, h)

image = np.expand_dims(image, axis=0)

image = imagenet_utils.preprocess_input(image)

# return the processed image

return image

1.3 预测接口定义

    准备工作完毕之后,接下来就是主要的两大部分:模型预测部分和app后端响应部分。先定义模型预测函数如下:

def classify_process():
    # 导入模型
    print("* Loading model...")
    model = ResNet50(weights="imagenet")
    print("* Model loaded")
    while True:
        # 从数据库中创建预测图像队列
        queue = db.lrange(IMAGE_QUEUE, 0, BATCH_SIZE - 1)
        imageIDs = []
        batch = None
        # 遍历队列
        for q in queue:
            # 获取队列中的图像并反序列化解码
            q = json.loads(q.decode("utf-8"))
            image = base64_decode_image(q["image"], IMAGE_DTYPE,
                                        (1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANS))
            # 检查batch列表是否为空
            if batch is None:
                batch = image
            # 合并batch
            else:
                batch = np.vstack([batch, image])
            # 更新图像ID
            imageIDs.append(q["id"])
         if len(imageIDs) > 0:
            print("* Batch size: {}".format(batch.shape))
            preds = model.predict(batch)
            results = imagenet_utils.decode_predictions(preds)
            # 遍历图像ID和预测结果并打印
            for (imageID, resultSet) in zip(imageIDs, results):
                # initialize the list of output predictions
                output = []
                # loop over the results and add them to the list of
                # output predictions
                for (imagenetID, label, prob) in resultSet:
                    r = {"label": label, "probability": float(prob)}
                    output.append(r)
                # 保存结果到数据库
                db.set(imageID, json.dumps(output))
            # 从队列中删除已预测过的图像
            db.ltrim(IMAGE_QUEUE, len(imageIDs), -1)
        time.sleep(SERVER_SLEEP)

然后定义app服务:

@app.route("/predict", methods=["POST"])
def predict():
    # 初始化数据字典
    data = {"success": False}
    # 确保图像上传方式正确
    if flask.request.method == "POST":
        if flask.request.files.get("image"):
            # 读取图像数据
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image))
            image = prepare_image(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
            # 将数组以C语言存储顺序存储
            image = image.copy(order="C")
            # 生成图像ID
            k = str(uuid.uuid4())
            d = {"id": k, "image": base64_encode_image(image)}
            db.rpush(IMAGE_QUEUE, json.dumps(d))
            # 运行服务
            while True:
                # 获取输出结果
                output = db.get(k)
                if output is not None:
                    output = output.decode("utf-8")
                    data["predictions"] = json.loads(output)
                    db.delete(k)
                    break
                time.sleep(CLIENT_SLEEP)
            data["success"] = True
        return flask.jsonify(data)

  Flask使用Python装饰器在内部自动将请求的URL和目标函数关联了起来,这样方便我们快速搭建一个Web服务。

1.4 接口测试

     服务搭建好了之后我们可以用一张图片来测试一下效果:

curl -X POST -F [email protected] 'http://127.0.0.1:5000/predict'

模型端的返回:

预测结果返回:

  最后我们可以给搭建好的服务进行一个压力测试,看看服务的并发等性能如何,定义一个压测文件stress_test.py 如下:

from threading import Thread
import requests
import time
# 请求的URL
KERAS_REST_API_URL = "http://127.0.0.1:5000/predict"
# 测试图片
IMAGE_PATH = "test.jpg"
# 并发数
NUM_REQUESTS = 500
# 请求间隔
SLEEP_COUNT = 0.05
def call_predict_endpoint(n):
    # 上传图像
    image = open(IMAGE_PATH, "rb").read()
    payload = {"image": image}
    # 提交请求
    r = requests.post(KERAS_REST_API_URL, files=payload).json()
    # 确认请求是否成功
    if r["success"]:
        print("[INFO] thread {} OK".format(n))
    else:
        print("[INFO] thread {} FAILED".format(n))
# 多线程进行
for i in range(0, NUM_REQUESTS):
    # 创建线程来调用api
    t = Thread(target=call_predict_endpoint, args=(i,))
    t.daemon = True
    t.start()
    time.sleep(SLEEP_COUNT)
time.sleep(300)

测试效果如下:

猜你喜欢

转载自blog.csdn.net/wzhrsh/article/details/109550308