【Pytorch框架flask部署简单例子—图像识别分类】

写在前面

Flask是一种用Python编写的轻量级Web框架,可以帮助您快速构建Web应用程序。

如果我们正在使用PyTorch框架开发深度学习应用程序,并希望将其部署到Web服务器上,则可以使用Flask框架实现。本文将介绍如何使用Flask对前一篇博客中所编写的基于PyTorch框架的图像分类模型进行本地部署,共包含两个py文件(flask_server.py和flask_predict.py),分别表示服务端和客户端,以实现对该模型的远程访问和使用,下文将会详细介绍。(点击这里:基于PyTorch实现经典网络架构的花卉图像分类模型

在使用Flask部署PyTorch应用程序之前,需要在本地计算机上安装Flask库,若pip install flask下载速度过慢,可换成conda install flask(安装了anaconda3),就能很快下载完毕。
Alt

1.flask_server服务端

1.1 初始化flask app

创建一个名为app的Flask对象,并将__name__作为参数传递给它(__name__是一个特殊变量,它表示当前模块的名称,通常用于确定应用程序根目录的位置)。接着创建一个名为model的变量,并将其初始化为None,该变量将用于存储训练好的PyTorch模型。再创建一个名为use_gpu的布尔变量,并将其初始化为False,这个变量将用于控制是否使用GPU加速模型的计算(GPU不错的小伙伴建议为True)。

初始化的流程较为固定,可作为模板进行使用,代码如下:

app = flask.Flask(__name__)
model = None
use_gpu = False

1.2 加载模型

定义一个load_model函数,传入训练模型model、相应结构和参数。需要注意的是,model的值需与训练时所用模型相同(重要!!),同时将model声明为全局变量。

接着重新定义全连接层(102表示最后输出的类别,需根据自身任务来确定),再加载best.pth文件(best.pth存储着训练时效果最好的参数,与前篇博客是同一文件),再使用model.load_state_dict()函数将保存的模型参数加载到我们定义的模型中.

最后使用model.eval()函数将模型设置为验证模式,这将禁用例如dropout和batch normalization等一些训练时的策略,输出分类的概率值。

def load_model():
    """Load the pre-trained model, you can use your model just as easily.
    """
    global model
    #这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 类别数自己根据自己任务来
    #print(model)
    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    #将模型指定为测试格式
    model.eval()
    #是否使用gpu
    if use_gpu:
        model.cuda()

1.3 数据预处理

数据预处理部分大致与验证集相似。不同之处在于添加了一个格式转换,有可能请求端所给image的格式不同,因此我们需要将其统一至RGB格式(训练时所用格式)。

def prepare_image(image, target_size):
    #针对不同模型,image的格式不同,但需要统一至RGB格式
    if image.mode != 'RGB':
        image = image.convert("RGB")

    # Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
    image = transforms.Resize(target_size)(image)
    image = transforms.ToTensor()(image)

    # Convert to Torch.Tensor and normalize. mean与std   (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    # Add batch_size axis.增加一个维度,用于按batch测试   本次这里一次测试一张
    image = image[None]
    if use_gpu:
        image = image.cuda()
    return Variable(image, volatile=True) #不需要求导

1.4 开启服务

定义一个predict函数用于接收POST请求并进行图像预测的Flask路由处理函数。当POST请求中包含一个名为“image”的文件时,该函数将读取该文件并使用预处理函数prepare_image()进行图像预处理。然后将预处理后的图像作为输入传递给已加载的PyTorch模型,使用softmax函数对预测结果进行归一化,选取前3个最高概率的结果,将它们以标签和概率的形式打包成字典,存入data字典的“predictions”列表中,最终以JSON格式返回该data字典。如果请求成功,则“success”键将被设置为True,代码如下:

@app.route("/predict", methods=["POST"])
def predict():
    # Initialize the data dictionary that will be returned from the view.
    #做一个标志,刚开始无图像传入时为false,传入图像时为true
    data = {
    
    "success": False}

    # 如果收到请求
    if flask.request.method == 'POST':
        #判断是否为图像
        if flask.request.files.get("image"):
            # Read the image in PIL format
            # 将收到的图像进行读取
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image)) #二进制数据

            # 利用上面的预处理函数将读入的图像进行预处理
            image = prepare_image(image, target_size=(64, 64))

            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())

            #将data字典增加一个key,value,其中value为list格式
            data['predictions'] = list()

            # Loop over the results and add them to the list of returned predictions
            for prob, label in zip(results[0][0], results[1][0]):
                #label_name = idx2label[str(label)]
                r = {
    
    "label": str(label), "probability": float(prob)}
                #将预测结果添加至data字典
                data['predictions'].append(r)

            # Indicate that the request was a success.
            data["success"] = True
    # 将最终结果以json格式文件传出
    return flask.jsonify(data)

在最后加上下段代码,这段代码的作用是在服务器启动时加载PyTorch模型,然后启动Flask服务器,监听端口号为5012(自己定义)的请求。

if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    #先加载模型
    load_model()
    #再开启服务
    app.run(port='5012')

2.flask_predict客户端

# url和端口写成自己的
flask_url = 'http://127.0.0.1:5012/predict'


def predict_result(image_path):
    #啥方法都行
    image = open(image_path, 'rb').read()
    payload = {
    
    'image': image}
    #request发给server.
    r = requests.post(flask_url, files=payload).json()

    # 成功的话在返回.
    if r['success']:
        # 输出结果.
        for (i, result) in enumerate(r['predictions']):
            print('{}. {}: {:.4f}'.format(i + 1, result['label'],
                                          result['probability']))
    # 失败了就打印.
    else:
        print('Request failed')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Classification demo')
    parser.add_argument('--file', default='./flower_data/train_filelist/image_06998.jpg', type=str, help='test image file')

    args = parser.parse_args()
    predict_result(args.file)

客户端代码较简单,flask_url的由来下文会讲。首先读取待预测图像,将其转换为二进制格式,并构造请求payload。接着发送POST请求至服务端,等待服务端返回结果,若服务端返回的JSON文件中success字段为True,则表示预测成功,将预测结果输出;否则表示预测失败,输出相应信息。此外可以通过命令行参数–file指定待预测图像的路径。

3.执行流程

首先,打开pycharm下部terminal面板,输出python+服务端文件(python flask_server.py),即会启动服务,并打印相应英文信息,如下图。
在这里插入图片描述
接着,将file属性中地址修改为所要预测图片的地址。
在这里插入图片描述
最后,点击运行flask_predict.py文件,即可输出结果。
在这里插入图片描述
结果如下:
在这里插入图片描述

结尾

各位小伙伴可以关注博主,博客内所涉及代码与数据集,私聊博主可以免费发给大家哦。

猜你喜欢

转载自blog.csdn.net/fly_ddaa/article/details/130437892