Web-based deep learning model deployment

1.1 Web services and technical framework

Let's take the ResNet50 pre-trained model as an example and walk through a lightweight deep learning model deployment: a simple image classification REST API. The main technical stack is Keras + Flask + Redis. Keras hosts the model, Flask is the back-end web framework, and Redis is the database that stores images in key-value form. Major package versions:

tensorflow 1.14
keras 2.2.4
flask 1.1.1
redis 3.3.8
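
To reproduce the environment, the pinned versions can be installed with pip (assuming the standard PyPI package names):

pip install tensorflow==1.14.0 keras==2.2.4 flask==1.1.1 redis==3.3.8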

Let's briefly talk about web services. The essence of a web application is a process in which the client sends an HTTP request, and the server, after receiving the request, generates a response (traditionally an HTML document) and returns it to the client. When deploying a deep learning model, most of the time we do not need to build a front-end page; the model is generally exposed as a REST API for other developers to call. So what is an API? Very simply: if a URL returns not HTML but data that a machine can parse directly, such as JSON, that URL can be regarded as an API.

Start the Redis service first:

redis-server
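
If Redis started on its default port 6379, which is what the connection code below assumes, you can verify it with:

redis-cli ping

It should answer PONG.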

1.2 Service configuration

Define some configuration parameters:

IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
IMAGE_CHANS = 3
IMAGE_DTYPE = "float32"
IMAGE_QUEUE = "image_queue"
BATCH_SIZE = 32
SERVER_SLEEP = 0.25
CLIENT_SLEEP = 0.25

These specify the input image size and dtype, the batch size, the Redis image queue name, and the polling intervals for the server and client loops.

Then create an instance of the Flask object and establish a Redis database connection:

# create the Flask app and connect to the local Redis instance
app = flask.Flask(__name__)
db = redis.StrictRedis(host="localhost", port=6379, db=0)
model = None
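
For reference, the snippets in this post assume a set of imports along these lines at the top of the service script (a sketch; the original post does not show them):

import base64
import io
import json
import sys
import time
import uuid

import flask
import numpy as np
import redis
from PIL import Image
from keras.applications import ResNet50, imagenet_utils
from keras.preprocessing.image import img_to_array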

Because a NumPy array cannot be stored in Redis directly, the image needs to be serialized and encoded before it is written to the database, and deserialized and decoded when it is read back. Define the encoding and decoding functions separately:

def base64_encode_image(img):
    # serialize the raw bytes of the array as a base64 string
    return base64.b64encode(img).decode("utf-8")


def base64_decode_image(img, dtype, shape):
    # in Python 3 the string read back from Redis/JSON must be
    # converted to bytes before it can be base64-decoded
    if sys.version_info.major == 3:
        img = bytes(img, encoding="utf-8")
    # decode and restore the flat buffer to its original dtype and shape
    img = np.frombuffer(base64.decodebytes(img), dtype=dtype)
    img = img.reshape(shape)
    return img
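
A quick round trip shows how the pair is meant to be used (a sketch using a dummy array with the service's image layout):

img = np.zeros((1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANS), dtype=IMAGE_DTYPE)
encoded = base64_encode_image(img)  # JSON/Redis-safe string
decoded = base64_decode_image(encoded, IMAGE_DTYPE,
                              (1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANS))
assert np.array_equal(img, decoded)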

In addition, the image to be predicted needs simple preprocessing. The preprocessing function is defined as follows:

def prepare_image(image, target):
    # if the image mode is not RGB, convert it
    if image.mode != "RGB":
        image = image.convert("RGB")
    # resize the input image and preprocess it
    image = image.resize(target)
    image = img_to_array(image)
    # expand image to a batch of one, shape (1, h, w, c)
    image = np.expand_dims(image, axis=0)
    image = imagenet_utils.preprocess_input(image)
    # return the processed image
    return image
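
For example, a local file can be prepared like this (test.jpg is the sample image used for testing later in the post):

image = Image.open("test.jpg")
image = prepare_image(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
print(image.shape)  # (1, 224, 224, 3)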

1.3 Definition of prediction interface

With the preparation done, the two main parts remain: the model prediction loop and the back-end response of the app. First, define the model prediction function as follows:

def classify_process():
    # load the model
    print("* Loading model...")
    model = ResNet50(weights="imagenet")
    print("* Model loaded")
    while True:
        # grab a batch of images to predict from the database queue
        queue = db.lrange(IMAGE_QUEUE, 0, BATCH_SIZE - 1)
        imageIDs = []
        batch = None
        # loop over the queue
        for q in queue:
            # deserialize and decode each image in the queue
            q = json.loads(q.decode("utf-8"))
            image = base64_decode_image(q["image"], IMAGE_DTYPE,
                                        (1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANS))
            # check whether the batch is still empty
            if batch is None:
                batch = image
            # otherwise stack the image onto the batch
            else:
                batch = np.vstack([batch, image])
            # record the image ID
            imageIDs.append(q["id"])
        if len(imageIDs) > 0:
            print("* Batch size: {}".format(batch.shape))
            preds = model.predict(batch)
            results = imagenet_utils.decode_predictions(preds)
            # loop over the image IDs and their prediction results
            for (imageID, resultSet) in zip(imageIDs, results):
                # initialize the list of output predictions
                output = []
                # loop over the results and add them to the list of
                # output predictions
                for (imagenetID, label, prob) in resultSet:
                    r = {"label": label, "probability": float(prob)}
                    output.append(r)
                # store the result in the database
                db.set(imageID, json.dumps(output))
            # remove the predicted images from the queue
            db.ltrim(IMAGE_QUEUE, len(imageIDs), -1)
        time.sleep(SERVER_SLEEP)

Then define the app service:

@app.route("/predict", methods=["POST"])
def predict():
    # initialize the data dictionary returned to the client
    data = {"success": False}
    # make sure the image was uploaded correctly
    if flask.request.method == "POST":
        if flask.request.files.get("image"):
            # read the image data
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image))
            image = prepare_image(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
            # ensure the array is stored in C-contiguous order
            image = image.copy(order="C")
            # generate an ID for the image
            k = str(uuid.uuid4())
            d = {"id": k, "image": base64_encode_image(image)}
            db.rpush(IMAGE_QUEUE, json.dumps(d))
            # poll the database for the result
            while True:
                # attempt to grab the output
                output = db.get(k)
                if output is not None:
                    output = output.decode("utf-8")
                    data["predictions"] = json.loads(output)
                    db.delete(k)
                    break
                time.sleep(CLIENT_SLEEP)
            data["success"] = True
        return flask.jsonify(data)

Flask uses Python decorators to internally associate the requested URL with the target function, which lets us build a web service quickly.
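
The post does not show the startup code, but a common way to wire the two parts together (the pattern used in similar Keras + Flask tutorials) is to run classify_process in a daemon thread and then start the Flask app; a minimal sketch:

if __name__ == "__main__":
    # start the model server in a daemon thread so it polls
    # the Redis queue while Flask handles requests
    from threading import Thread
    t = Thread(target=classify_process)
    t.daemon = True
    t.start()
    # start the web service
    app.run(host="127.0.0.1", port=5000)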

1.4 Interface test

After the service is up, we can test it with an image:

curl -X POST -F "image=@test.jpg" 'http://127.0.0.1:5000/predict'

The model returns the prediction result as JSON:
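
For a typical test image the body contains the top-5 ImageNet predictions and looks roughly like this (the labels and probabilities here are illustrative placeholders, not real output):

{
  "predictions": [
    {"label": "beagle", "probability": 0.94},
    {"label": "basset", "probability": 0.03},
    ...
  ],
  "success": true
}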

Finally, we can run a stress test against the service to see how it holds up under concurrent load. Define a stress test file stress_test.py as follows:

from threading import Thread
import requests
import time
# URL of the prediction endpoint
KERAS_REST_API_URL = "http://127.0.0.1:5000/predict"
# test image
IMAGE_PATH = "test.jpg"
# number of requests to fire
NUM_REQUESTS = 500
# interval between requests
SLEEP_COUNT = 0.05
def call_predict_endpoint(n):
    # load the image to upload
    image = open(IMAGE_PATH, "rb").read()
    payload = {"image": image}
    # submit the request
    r = requests.post(KERAS_REST_API_URL, files=payload).json()
    # check whether the request succeeded
    if r["success"]:
        print("[INFO] thread {} OK".format(n))
    else:
        print("[INFO] thread {} FAILED".format(n))
# fire off the requests from multiple threads
for i in range(0, NUM_REQUESTS):
    # spawn a thread to call the API
    t = Thread(target=call_predict_endpoint, args=(i,))
    t.daemon = True
    t.start()
    time.sleep(SLEEP_COUNT)
# keep the main thread alive long enough for all requests to finish
time.sleep(300)
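
Run python stress_test.py while the service is up; if the service keeps up with the load, every thread prints an OK line.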

