Label Studio+Yolov5 实现目标检测预标注(二)

政采云技术团队.png

余笙.png

引言

之前我们介绍了如何创建待标注的项目以及如何训练针对自己数据集的 Yolov5 模型。下面我们将介绍通过继承 label_studio_ml.model.LabelStudioMLBase 类,来构造我们的后端预测模型类。

一.构造预测模型类

Label Studio ml 提供了几个例子来实现预测模型与 Label Studio 前端的连接。下面我们将通过修改官方提供的例子,来实现 yolov5 预测模型类。 首先我们先打开 mmdetection.py 文件, 具体路径为: label-studio-ml-backend/label_studio_ml/examples/mmdetection/mmdetection.py。

import os
import logging
import boto3
import io
import json

from mmdet.apis import init_detector, inference_detector

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import get_image_size, \
    get_single_tag_keys, DATA_UNDEFINED_NAME
from label_studio_tools.core.utils.io import get_data_dir
from botocore.exceptions import ClientError
from urllib.parse import urlparse


logger = logging.getLogger(__name__)


class MMDetection(LabelStudioMLBase):
    """Object detector based on https://github.com/open-mmlab/mmdetection"""

    def __init__(self, config_file="../mmdetection/config_file/",
                 checkpoint_file="../mmdetection/checkpoint_file/",
                 image_dir=None,
                 labels_file=None, score_threshold=0.3, device='cpu', **kwargs):
        """
        Load MMDetection model from config and checkpoint into memory.
        (Check https://mmdetection.readthedocs.io/en/v1.2.0/GETTING_STARTED.html#high-level-apis-for-testing-images)

        Optionally set mappings from COCO classes to target labels
        :param config_file: Absolute path to MMDetection config file (e.g. /home/user/mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x.py)
        :param checkpoint_file: Absolute path MMDetection checkpoint file (e.g. /home/user/mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth)
        :param image_dir: Directory where images are stored (should be used only in case you use direct file upload into Label Studio instead of URLs)
        :param labels_file: file with mappings from COCO labels to custom labels {"airplane": "Boeing"}
        :param score_threshold: score threshold to wipe out noisy results
        :param device: device (cpu, cuda:0, cuda:1, ...)
        :param kwargs:
        """
        super(MMDetection, self).__init__(**kwargs)
        config_file = config_file or os.environ['config_file']
        checkpoint_file = checkpoint_file or os.environ['checkpoint_file']
        self.config_file = config_file
        self.checkpoint_file = checkpoint_file
        self.labels_file = labels_file
        # default Label Studio image upload folder
        upload_dir = os.path.join(get_data_dir(), 'media', 'upload')
        self.image_dir = image_dir or upload_dir
        logger.debug(f'{self.__class__.__name__} reads images from {self.image_dir}')
        if self.labels_file and os.path.exists(self.labels_file):
            self.label_map = json_load(self.labels_file)
        else:
            self.label_map = {}

        self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys(
            self.parsed_label_config, 'RectangleLabels', 'Image')
        schema = list(self.parsed_label_config.values())[0]
        self.labels_in_config = set(self.labels_in_config)

        # Collect label maps from `predicted_values="airplane,car"` attribute in <Label> tag
        self.labels_attrs = schema.get('labels_attrs')
        if self.labels_attrs:
            for label_name, label_attrs in self.labels_attrs.items():
                for predicted_value in label_attrs.get('predicted_values', '').split(','):
                    self.label_map[predicted_value] = label_name

        print('Load new model from: ', config_file, checkpoint_file)
        self.model = init_detector(config_file, checkpoint_file, device=device)
        self.score_thresh = score_threshold

    def _get_image_url(self, task):
        image_url = task['data'].get(self.value) or task['data'].get(DATA_UNDEFINED_NAME)
        if image_url.startswith('s3://'):
            # presign s3 url
            r = urlparse(image_url, allow_fragments=False)
            bucket_name = r.netloc
            key = r.path.lstrip('/')
            client = boto3.client('s3')
            try:
                image_url = client.generate_presigned_url(
                    ClientMethod='get_object',
                    Params={'Bucket': bucket_name, 'Key': key}
                )
            except ClientError as exc:
                logger.warning(f'Can\'t generate presigned URL for {image_url}. Reason: {exc}')
        return image_url

    def predict(self, tasks, **kwargs):
        assert len(tasks) == 1
        task = tasks[0]
        image_url = self._get_image_url(task)
        image_path = self.get_local_path(image_url)
        model_results = inference_detector(self.model, image_path)
        results = []
        all_scores = []
        img_width, img_height = get_image_size(image_path)
        for bboxes, label in zip(model_results, self.model.CLASSES):
            output_label = self.label_map.get(label, label)

            if output_label not in self.labels_in_config:
                print(output_label + ' label not found in project config.')
                continue
            for bbox in bboxes:
                bbox = list(bbox)
                if not bbox:
                    continue
                score = float(bbox[-1])
                if score < self.score_thresh:
                    continue
                x, y, xmax, ymax = bbox[:4]
                results.append({
                    'from_name': self.from_name,
                    'to_name': self.to_name,
                    'type': 'rectanglelabels',
                    'value': {
                        'rectanglelabels': [output_label],
                        'x': x / img_width * 100,
                        'y': y / img_height * 100,
                        'width': (xmax - x) / img_width * 100,
                        'height': (ymax - y) / img_height * 100
                    },
                    'score': score
                })
                all_scores.append(score)
        avg_score = sum(all_scores) / max(len(all_scores), 1)
        return [{
            'result': results,
            'score': avg_score
        }]


def json_load(file, int_keys=False):
    with io.open(file, encoding='utf8') as f:
        data = json.load(f)
        if int_keys:
            return {int(k): v for k, v in data.items()}
        else:
            return data
复制代码

mmdetection.py 文件定义了 mmdetection 预测模型类,该类包含了4个函数,分别为构造函数、获取图片的函数 _get_image_url(self, task)、获取预测结果的函数 predict(self, tasks, kwargs) 以及 json 文件读取的函数 json_load(file, int_keys=False)。我们需要对构造函数以及获取预测结果的函数进行修改。具体如下:

class MyModel(LabelStudioMLBase):

    def __init__(self, image_dir=None, labels_file=None, device='cpu', **kwargs):
        '''
        loading models to global objects
        '''
        super(MyModel, self).__init__(**kwargs)

        # 封装好的推理模型
        self.detector  = Yolo5Engine('_source_/bans8')
        upload_dir = os.path.join(get_data_dir(), 'media', 'upload')
        self.image_dir = image_dir or upload_dir
        self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys(
            self.parsed_label_config, 'RectangleLabels', 'Image')
        schema = list(self.parsed_label_config.values())[0]
        self.labels_in_config = set(self.labels_in_config)

    def _get_image_url(self, task):
        image_url = task['data'].get(self.value) or task['data'].get(DATA_UNDEFINED_NAME)
        if image_url.startswith('s3://'):
            # presign s3 url
            r = urlparse(image_url, allow_fragments=False)
            bucket_name = r.netloc
            key = r.path.lstrip('/')
            client = boto3.client('s3')
            try:
                image_url = client.generate_presigned_url(
                    ClientMethod='get_object',
                    Params={'Bucket': bucket_name, 'Key': key}
                )
            except ClientError as exc:
                # logger.warning(f'Can\'t generate presigned URL for {image_url}. Reason: {exc}')
                pass
        return image_url


    def predict(self, tasks, **kwargs):
        # assert len(tasks) == 1
        task = tasks[0]
        image_url = self._get_image_url(task)
        image_path = self.get_local_path(image_url)
        self.boxes = []
        self.labels = []
        self.confs = []
        for i in range(len(self.detector.detection(image_path))):
            self.boxes.append(self.detector.detection(image_path)[i]["box"])
            self.labels.append(self.detector.detection(image_path)[i]["key"])
            self.confs.append(self.detector.detection(image_path)[i]["score"])
        results = []  # results需要放在list中再返回
        img_width, img_height = get_image_size(image_path)  # 用于转换坐标

        for id, bbox in enumerate(self.boxes):
            label = self.labels[id]  # 注意:后端和前端标签名保持一致
            conf = self.confs[id]  # 若模型不返回置信度,可指定一个默认值
            x, y, x2, y2 = bbox
            w = x2-x
            h = y-y2
            y = y2

            if label not in self.labels_in_config:
                print(label + ' label not found in project config.')
                continue  # 忽略前端不需要的类别

            results.append({
                'from_name': self.from_name,
                'to_name': self.to_name,
                'type': 'rectanglelabels',
                'value': {
                    'rectanglelabels': [label],
                    'x': int(x / img_width * 100),  # 坐标需要转换
                    'y': int(y / img_height * 100),  # 数值类型返回整型
                    'width': int(w / img_width * 100),
                    'height': int(h / img_height * 100)
                },
                'score': int(conf * 100)  #
            })

        avg_score = int(sum(self.confs) / max(len(self.confs), 1))
        return [{
            'result': results,
            'score': avg_score
        }]
复制代码

在修改 predict 函数时,我们发现 mmdetection 模型输出的 x 和 y 为 yolov5 模型输出的 box 的左上点的 x 和右下点的 y,这是需要注意的。

二.启动ML后端

启动 ML 后端的命令如下:

label-studio-ml init my_ml_backend --script label_studio_ml/examples/simple_text_classifier/simple_text_classifier.py
label-studio-ml start my_ml_backend
复制代码

其中第一行命令为初始化ML后端,--script 为 ML 后端类的位置。第二行命令为启动 ML 后端。启动成功后会出现如下提示:

其默认端口为 9090,如果想要更改端口,可以在 _wsgi.py 文件中将 9090 更改为其他端口。

f __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Label studio')
    parser.add_argument(
        '-p', '--port', dest='port', type=int, default=9090,
        help='Server port')
    parser.add_argument(
        '--host', dest='host', type=str, default='0.0.0.0',
        help='Server host')
    parser.add_argument(
        '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
        help='Additional LabelStudioMLBase model initialization kwargs')
    parser.add_argument(
        '-d', '--debug', dest='debug', action='store_true',
        help='Switch debug mode')
    parser.add_argument(
        '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
        help='Logging level')
    parser.add_argument(
        '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
        help='Directory where models are stored (relative to the project directory)')
    parser.add_argument(
        '--check', dest='check', action='store_true',
        help='Validate model instance before launching server')
复制代码

三.Label StudioML 设置

启动 ML 后端服务后,我们还需要在 label studio 前端的项目页面进行设置,具体步骤如下:

  1. 打开项目设置界面选到 Machine Learning 并选择 Add Model 来添加模型。

  1. 在 Title 写上模型的名字,在 URL 上填写模型的地址,具体地址可以在 ML 后端服务开启时可以看到。

  1. 当构建成功时会显示 Connected。

  1. 回项目就可以查看预标注的结果。

四.总结

本文介绍了 Label Studio 和 Yolov5 如何实现目标检测预标注。 除此以外,Label Studio 还可以实现音频和文本的预标注,具体例子可以参考官网:labelstud.io/guide/。以上就是本文主要内容,如有问题,欢迎指正!

推荐阅读

Label Studio+Yolov5 实现目标检测预标注(一)

业务交互网关洪峰应对之道

Flink checkpoint 算法(上)

从线上死锁分析到 Next-Key Lock 理解

招贤纳士

政采云技术团队(Zero),一个富有激情、创造力和执行力的团队,Base 在风景如画的杭州。团队现有300多名研发小伙伴,既有来自阿里、华为、网易的“老”兵,也有来自浙大、中科大、杭电等校的新人。团队在日常业务开发之外,还分别在云原生、区块链、人工智能、低代码平台、中间件、大数据、物料体系、工程平台、性能体验、可视化等领域进行技术探索和实践,推动并落地了一系列的内部技术产品,持续探索技术的新边界。此外,团队还纷纷投身社区建设,目前已经是 google flutter、scikit-learn、Apache Dubbo、Apache Rocketmq、Apache Pulsar、CNCF Dapr、Apache DolphinScheduler、alibaba Seata 等众多优秀开源社区的贡献者。如果你想改变一直被事折腾,希望开始折腾事;如果你想改变一直被告诫需要多些想法,却无从破局;如果你想改变你有能力去做成那个结果,却不需要你;如果你想改变你想做成的事需要一个团队去支撑,但没你带人的位置;如果你想改变本来悟性不错,但总是有那一层窗户纸的模糊……如果你相信相信的力量,相信平凡人能成就非凡事,相信能遇到更好的自己。如果你希望参与到随着业务腾飞的过程,亲手推动一个有着深入的业务理解、完善的技术体系、技术创造价值、影响力外溢的技术团队的成长过程,我觉得我们该聊聊。任何时间,等着你写点什么,发给 [email protected]

微信公众号

文章同步发布,政采云技术团队公众号,欢迎关注

政采云技术团队.png

猜你喜欢

转载自juejin.im/post/7132263223817928734