0009-flask调用pytorch模型

# -*- encoding: utf-8 -*-
"""
@File    : flask_torch.py
@Time    : 2020/07/12 11:59
@Author  : Johnson
@Email   : [email protected]
"""
import ast
import io
import json

import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms as T
from torchvision.models import resnet50

# Initialize the Flask application.
app = flask.Flask(__name__)
# Classifier instance; populated by load_model() and read by the request view.
model = None
# Use the GPU only when CUDA is actually available. The model and the input
# tensors are moved with .cuda() whenever this flag is set, so a hard-coded
# True would crash on CPU-only machines.
use_gpu = torch.cuda.is_available()

# Mapping from class index to human-readable label, stored in class.txt as a
# Python literal. ast.literal_eval only parses literals; eval() would execute
# any code placed in the file.
with open("class.txt", "r", encoding="utf-8") as f:
    idx2label = ast.literal_eval(f.read())


def load_model():
    """Build the pretrained ResNet-50 and publish it as the global ``model``.

    The network is put into evaluation mode and, when ``use_gpu`` is set,
    moved onto the GPU. Replace the constructor call to serve another model.
    """
    global model
    model = resnet50(pretrained=True)
    model.eval()
    if not use_gpu:
        return
    model.cuda()


def prepare_image(image, target_size):
    """Preprocess a PIL image into a normalized batch tensor for the model.

    Args:
        image: PIL image of any mode; non-RGB images are converted to RGB.
        target_size: size passed to ``T.Resize`` (e.g. ``(224, 224)``).

    Returns:
        A float tensor with a leading batch axis of size 1, moved to the GPU
        when ``use_gpu`` is set.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize, convert to a tensor, then normalize with the ImageNet channel
    # mean/std used by torchvision's pretrained models.
    preprocess = T.Compose([
        T.Resize(target_size),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    tensor = preprocess(image)

    # Add the batch axis: (C, H, W) -> (1, C, H, W).
    tensor = tensor.unsqueeze(0)
    if use_gpu:
        tensor = tensor.cuda()

    # Plain tensors replace the deprecated Variable(volatile=True) wrapper;
    # callers should run inference under torch.no_grad() instead.
    return tensor

@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the top-1 prediction as JSON.

    Expects a multipart POST carrying the image file in the form field
    ``image``. The JSON response always has a boolean ``success`` key; on
    success it also has ``predictions``, a list of
    ``{"label": ..., "probability": float}`` dicts.
    """
    # Payload returned to the client. It must be a dict: a set literal such
    # as {"success", False} does not support the item assignments below.
    data = {"success": False}

    # The route only accepts POST, so this guard is defensive.
    if flask.request.method == "POST":
        upload = flask.request.files.get("image")
        if upload is None:
            # No file under the "image" field: report failure in the payload.
            return flask.jsonify(data)

        # Read the uploaded bytes into a PIL image.
        image = Image.open(io.BytesIO(upload.read()))

        # Preprocess into a normalized batch tensor for the classifier.
        image = prepare_image(image, target_size=(224, 224))

        # The __main__ block never runs when the module is imported by a
        # WSGI server, so load the model on first use in that case.
        if model is None:
            load_model()

        # Inference only: disable autograd tracking to save memory and time.
        with torch.no_grad():
            preds = F.softmax(model(image), dim=1)
        top_probs, top_idxs = torch.topk(preds.cpu(), k=1, dim=1)

        # Row 0 is the single image in the batch; tolist() yields plain
        # Python floats/ints that are JSON-serializable.
        data["predictions"] = [
            {"label": idx2label[idx], "probability": prob}
            for prob, idx in zip(top_probs[0].tolist(), top_idxs[0].tolist())
        ]

        # Indicate that the request was a success.
        data["success"] = True

    # Return the data dictionary as a JSON response.
    return flask.jsonify(data)

if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    load_model()
    # With debug=True Flask enables the reloader by default, which re-executes
    # this script in a child process. load_model() would then run twice, and
    # the parent would keep an unused copy of the model (on the GPU when
    # use_gpu is set). Disabling the reloader loads the model exactly once.
    # NOTE(review): debug mode also enables the interactive Werkzeug debugger;
    # do not expose this server on a public interface with debug=True.
    app.run(debug=True, use_reloader=False)

猜你喜欢

转载自blog.csdn.net/zhonglongshen/article/details/112725920