图像万物分割——Segment Anything算法解析与模型推理

一、概述

在视觉任务中,图像分割任务是一个很广泛的领域,应用于交互式分割,边缘检测,超像素化,感兴趣目标生成,前景分割,语义分割,实例分割,泛视分割等。
交互式分割,这种分割任务,它允许用户手动细化掩码来分割任意类型的对象。然而,这种方法需要用户的不断参与和指导,类似于ps里面的抠图快速选择工具。
在这里插入图片描述

实例分割任务是它能够自动分割特定类别的对象,例如行人,狗,电视或椅子,但需要大量的手动标注数据,标注样本要以上万个样本,然后要经过大量的计算资源和代码算法知识来训练模型。这种方式应用最广泛应该是人像自动抠图:
在这里插入图片描述
为了 解决这些分割任务的局限性,Meta 推出了「分割一切」AI 算法Segment Anything,为分割任务提供一种通用的、全自动的分割解决方案。

二、Segment Anything 万物分割

1.算法摘要

作者介绍了Segment Anything (SA) 项目,这是一个旨在进行图像分割的新任务,同时提供了相应的模型和数据集。在该项目中,作者采用了一种高效的模型来进行数据收集,以构建迄今为止最大的分割数据集。他们在超过1,100万张公开的图像上进行了标注,生成了超过10亿个掩码。
这个模型(SAM)训练完成之后,以使其具备"promptable"(可提示)的性质,因此意味着它可以零样本(zero-shot)地适应新的数据集和任务,而无需先对数据进行标注和训练。作者对该模型进行了广泛的评估,发现它在许多任务上的零样本表现通常与完全监督的性能相媲美,甚至更好。作者公开了他们的模型(SAM),还发布了相应的图像数据集(SA-1B)。
在这里插入图片描述

2. 算法介绍

LLM的出现,让研究人员感受到,使用互联网规模的数据集上预训练的大型语言模型已经改变了自然语言处理(NLP)领域,因为它们表现出强大的零样本和少样本泛化能力,可以应对未在训练中出现的任务和数据分布。这种泛化通常通过提示工程(prompt engineering)来实现,其中手工制作的文本提示可以引导语言模型生成有效的文本响应。这些基础模型在使用丰富的互联网文本语料库进行预训练时,表现出令人惊讶的零样本和少样本性能,有时甚至可以与经过精细调整(fine-tune)的模型相媲美。研究经验表明,这种零样本和少样本性能会随着模型规模、数据集大小和总训练计算量的增加而改善。

在计算机视觉领域也在探索基础模型的应用,例如,CLIP和ALIGN使用对比学习来训练文本和图像编码器,经过训练后,这些编码器可以用于零样本泛化到新的视觉概念和数据分布。这些编码器还可以有效地与其他模块结合,用于解决下游任务,比如图像生成。然而,计算机视觉领域涉及的问题远不止这些,而且许多问题缺乏丰富的训练数据。

在这项研究工作中,SAM作者的目标是建立一个图像分割的基础模型,也就是一个可提示的模型,它可以在广泛的数据集上进行预训练以实现强大的泛化能力。一旦有了这个模型,作者进一步探索如何通过快速流程来解决各种新的数据分布上的下游分割问题。

这个计划的成功取决于三个关键要素:任务、模型和数据。作者需要解决以下关于图像分割的问题:

  1. 什么样的视觉分割任务可以实现零样本泛化?

  2. 为了实现这一个分割任务,对应的模型架构应该是什么样的?

  3. 哪些数据可以支持这个任务和模型的预训练?

作者首先定义了一个可提示的分割任务,这个任务足够通用,可以作为强大的预训练目标,同时也可以支持广泛的下游应用。这个任务要求一个支持多种提示的模型,并能够实时生成分割掩码,以支持交互式使用。然而,互联网上目前尚没有足够大规模的分割数据集来满足这个任务的需求。作者提出了“数据引擎”来应对这个问题,即通过模型辅助数据收集和不断迭代来改进数据,以填补数据的不足。这个方法可以在模型训练和数据收集之间进行交互,以实现更好的性能。

  • 分割任务
    在自然语言处理和计算机视觉领域,基础模型具有很大的前景,因为它们可以用于执行零样本学习和少样本学习,通过利用提示来适应新的数据集和任务。受到这种思路的启发,本文提出了一个称为"可提示分割任务"的新领域,其主要目标是在给定分割提示的情况下生成有效的分割掩码(如图1a所示)。

这些分割提示可以简单地指定图像中要分割的对象,例如,提示可以包括对象的位置信息或文本描述。
这里的"有效输出掩码"意味着,即使提示信息模糊不清,可能指向多个不同对象(例如,在图像上一个点可能表示衬衫或穿衬衫的人),生成的分割掩码也应该合理,至少应该包括这些对象中的一个。

在这项研究中,作者将可提示分割任务作为预训练目标,然后使用提示工程方法来解决各种不同的下游分割任务。这种方法有望为计算机视觉领域带来一种强大的学习范式,可以在面对新任务时从有限的提示信息中进行学习,而不需要大量的标记数据。这对于处理多样化和复杂的视觉任务可能具有很大的潜力。

  • 模型选择
    可提示分割任务对模型的架构提出了一些严格的要求,这包括对提示的支持灵活性、实时计算的需求,以便允许交互使用,以及能够处理歧义。作者提出了一个简单的模型设计,可以满足这些要求,被称为"Segment Anything"模型,简称SAM(见图1b)。SAM的架构包括以下组成部分:
  • 图像编码器:这是一个强大的模型,负责将输入图像转化为图像嵌入(image embedding),以捕捉图像的特征信息。
  • 提示编码器:这是一个用于嵌入提示信息的模型,它将提示信息转化为提示嵌入,以使模型能够理解提示中的内容。
  • 控码解码器:这是一个轻量级的模型,负责将图像嵌入和提示嵌入结合,然后预测分割掩码。这一部分的设计使得SAM可以实现对相同图像嵌入的不同提示信息的分配,从而使模型能够处理多样性的提示。

SAM的设计还允许它在不超过50毫秒的时间内从提示符中预测掩码,实现了实时性能,这对于实际应用和交互式任务非常重要。

作者的主要关注点包括边界框、关键点和分割掩码提示。为了解决歧义问题,SAM被设计成能够预测多个掩码,即使给定相同的提示。这使得SAM可以自然地处理提示中的歧义,比如前文提到的衬衫和穿衬衫的人之间的歧义示例。这个能力对于处理复杂的图像场景和多义性提示非常有帮助。

  • 数据引擎
    为了使SAM能够在新的数据分布上实现强大的泛化能力,需要在一个大型数据集上进行训练,该数据集应该覆盖各种不同的分割任务和场景。然而,典型的训练方法通常依赖于在线获取数据,而掩码标注信息通常相对稀缺,因此需要采用替代策略。作者提出的解决方案是构建一个称为"数据引擎"的系统,这个引擎包括三个主要阶段:辅助手动、半自动和全自动。
  1. 辅助手动阶段:在这个阶段,SAM与人工注释人员协作,类似于传统的交互式分割设置。人工注释人员手动为图像中的对象生成掩码,同时SAM提供辅助信息,例如提示信息,以帮助人工注释人员完成掩码的生成。这一阶段有助于收集一些基本的分割标注。

  2. 半自动阶段:在这个阶段,SAM能够自动为图像中的对象的某些子区域生成掩码。它会根据已有的掩码和提示信息,自动预测可能的对象位置,并生成相应的掩码。这减轻了人工注释人员的工作负担,因为他们可以专注于注释剩余的对象,从而提高了标注的多样性。

  3. 全自动阶段:在最后一个阶段,作者采用一种规则网格提示SAM,用于生成大量高质量掩码。这个提示方法能够为每张图像平均产生约100个掩码,以增加数据的多样性和覆盖不同的情况。

通过这种数据引擎的阶段性设计,作者能够有效地利用协作注释和自动化方法,以构建一个大规模的数据集,为SAM的训练提供了足够丰富和多样的标注数据,从而使其在新的数据分布上实现强大的泛化能力。这种方法有助于克服标注数据稀缺性的问题,尤其是对于复杂的分割任务。

3.数据集

作者最终的数据集SA-1B,包括来自1100万张经许可和隐私保护图像的超过10亿个掩码(见图2)。SA-1B使用作者的数据引擎的最后阶段完全自动收集,比现有的最大分割数据集拥有400多倍的掩码,并且作者广泛验证,掩码具有高质量和多样性。作者希望SA-1B能够成为一种有价值的资源,用于建立新的基础模型。

4.实验

作者广泛地评估SAM。首先,在23个分割数据集上的测试,作者发现SAM从单个前景点生成了高质量的掩码,通常仅略低于手动注释的真实值。其次,作者在使用提示工程的零样本传输协议(zero-shot transfer protocol)下的各种下游任务上发现了持续强大的定量和定性结果,包括边缘检测、感兴趣目标生成、实例分割和文本到掩码预测。这些结果表明,SAM可以在即时工程中开箱即用,解决涉及SAM训练数据之外的图像分布的各种任务。

三、模型C++推理

1.实现代码

#include "include/segment_anything.h"
namespace sam{
    
    
SegmentAnything::~SegmentAnything()
{
    
    
    image_encoder_net_.clear();
    mask_decoder_net_.clear();
}

static inline float intersection_area(const sam_result_t& a, const sam_result_t& b)
{
    
    
    cv::Rect_<float> inter = a.box & b.box;
    return inter.area();
}

static void qsort_descent_inplace(std::vector<sam_result_t>& faceobjects, int left, int right)
{
    
    
    int i = left;
    int j = right;
    float p = faceobjects[(left + right) / 2].iou_pred;

    while (i <= j)
    {
    
    
        while (faceobjects[i].iou_pred > p)
            i++;

        while (faceobjects[j].iou_pred < p)
            j--;

        if (i <= j)
        {
    
    
            // swap
            std::swap(faceobjects[i], faceobjects[j]);

            i++;
            j--;
        }
    }

    #pragma omp parallel sections
    {
    
    
        #pragma omp section
        {
    
    
            if (left < j) qsort_descent_inplace(faceobjects, left, j);
        }
        #pragma omp section
        {
    
    
            if (i < right) qsort_descent_inplace(faceobjects, i, right);
        }
    }
}

static void qsort_descent_inplace(std::vector<sam_result_t>& faceobjects)
{
    
    
    if (faceobjects.empty())
        return;

    qsort_descent_inplace(faceobjects, 0, faceobjects.size() - 1);
}

static void nms_sorted_bboxes(const cv::Mat& bgr,const std::vector<sam_result_t>& faceobjects, std::vector<int>& picked, float nms_threshold)
{
    
    
    picked.clear();

    const int n = faceobjects.size();

    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
    
    
        areas[i] = faceobjects[i].box.area();
    }
    cv::Mat img = bgr.clone();
    for (int i = 0; i < n; i++)
    {
    
    
        const sam_result_t& a = faceobjects[i];

        int keep = 1;
        for (int j = 0; j < (int)picked.size(); j++)
        {
    
    
            const sam_result_t& b = faceobjects[picked[j]];

            // intersection over union
            float inter_area = intersection_area(a, b);
            float union_area = areas[i] + areas[picked[j]] - inter_area;
            // float IoU = inter_area / union_area
            if (inter_area / union_area > nms_threshold){
    
    
                keep = 0;
            }
                
        }

        if (keep)
            picked.push_back(i);
    }
}
int SegmentAnything::NMS(const cv::Mat& bgr, std::vector<sam_result_t>& proposals, std::vector<int>& picked, float nms_threshold)
{
    
    
    qsort_descent_inplace(proposals);
    nms_sorted_bboxes(bgr, proposals, picked, nms_threshold);
    
    return 0;
}

int SegmentAnything::Load(const std::string& image_encoder_param, const std::string& image_encoder_bin, const std::string& mask_decoder_param, const std::string& mask_decoder_bin)
{
    
    
    int ret = 0;
    ret = image_encoder_net_.load_param(image_encoder_param.c_str());
    if (ret < 0)
        return -1;
    ret = image_encoder_net_.load_model(image_encoder_bin.c_str());
    if (ret < 0)
        return -1;
    ret = mask_decoder_net_.load_param(mask_decoder_param.c_str());
    if (ret < 0)
        return -1;
    ret = mask_decoder_net_.load_model(mask_decoder_bin.c_str());
    if (ret < 0)
        return -1;

    return 0;
}
int SegmentAnything::ImageEncoder(const cv::Mat& bgr, ncnn::Mat& image_embeddings, image_info_t& image_info)
{
    
    
    const int target_size = 1024;
    int img_w = bgr.cols;
    int img_h = bgr.rows;

    int w = img_w;
    int h = img_h;
    float scale = 1.f;
    if (w > h)
    {
    
    
        scale = (float)target_size / w;
        w = target_size;
        h = h * scale;
    }
    else
    {
    
    
        scale = (float)target_size / h;
        h = target_size;
        w = w * scale;
    }

    ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);

    int wpad = target_size - w;
    int hpad = target_size - h;
    ncnn::Mat in_pad;
    ncnn::copy_make_border(in, in_pad, 0, hpad, 0, wpad, ncnn::BORDER_CONSTANT, 0.f);

    in_pad.substract_mean_normalize(means_, norms_);

    ncnn::Extractor image_encoder_ex = image_encoder_net_.create_extractor();

    image_encoder_ex.input("image", in_pad);
    image_encoder_ex.extract("image_embeddings", image_embeddings);

    image_info.img_h = img_h;
    image_info.img_w = img_w;
    image_info.pad_h = h;
    image_info.pad_w = w;
    image_info.scale = scale;

    return 0;
}

int SegmentAnything::embed_masks(const prompt_info_t& prompt_info, ncnn::Mat& mask_input, ncnn::Mat& has_mask)
{
    
    
    mask_input = ncnn::Mat(256, 256, 1);
    mask_input.fill(0.f);
    has_mask = ncnn::Mat(1);
    has_mask.fill(0.f);

    return 0;
}
int SegmentAnything::transform_coords(const image_info_t& image_info, ncnn::Mat& point_coords)
{
    
    
    for(int h = 0; h < point_coords.h; ++h){
    
    
        float* ptr = point_coords.row(h);
        ptr[0] *= image_info.scale;
        ptr[1] *= image_info.scale;
    }

    return 0;
}
int SegmentAnything::embed_points(const prompt_info_t& prompt_info, std::vector<ncnn::Mat>& point_labels, ncnn::Mat& point_coords)
{
    
    
    int num_points = prompt_info.points.size() / 2;
    point_coords = ncnn::Mat(num_points * 2, (void*)prompt_info.points.data()).reshape(2, num_points).clone();

    ncnn::Mat point_labels1 = ncnn::Mat(256, num_points);
    ncnn::Mat point_labels2 = ncnn::Mat(256, num_points);
    ncnn::Mat point_labels3 = ncnn::Mat(256, num_points);
    ncnn::Mat point_labels4 = ncnn::Mat(256, num_points);
    ncnn::Mat point_labels5 = ncnn::Mat(256, num_points);
    ncnn::Mat point_labels6 = ncnn::Mat(256, num_points);

    point_labels1.row_range(0, num_points - 1).fill(1.f);
    point_labels1.row_range(num_points - 1, 1).fill(0.f);

    for (int i = 0; i < num_points - 1; ++i) {
    
    
        if (prompt_info.labels[i] == -1)
            point_labels2.row_range(i, 1).fill(1.f);
        else
            point_labels2.row_range(i, 1).fill(0.f);
    }
    point_labels2.row_range(num_points - 1, 1).fill(1.f);

    for (int i = 0; i < num_points - 1; ++i) {
    
    
        if (prompt_info.labels[i] == 0)
            point_labels3.row_range(i, 1).fill(1.f);
        else
            point_labels3.row_range(i, 1).fill(0.f);
    }
    point_labels3.row_range(num_points - 1, 1).fill(0.f);

    for (int i = 0; i < num_points - 1; ++i) {
    
    
        if (prompt_info.labels[i] == 1)
            point_labels4.row_range(i, 1).fill(1.f);
        else
            point_labels4.row_range(i, 1).fill(0.f);
    }
    point_labels4.row_range(num_points - 1, 1).fill(0.f);

    for (int i = 0; i < num_points - 1; ++i) {
    
    
        if (prompt_info.labels[i] == 2)
            point_labels5.row_range(i, 1).fill(1.f);
        else
            point_labels5.row_range(i, 1).fill(0.f);
    }
    point_labels5.row_range(num_points - 1, 1).fill(0.f);

    for (int i = 0; i < num_points - 1; ++i) {
    
    
        if (prompt_info.labels[i] == 3)
            point_labels6.row_range(i, 1).fill(1.f);
        else
            point_labels6.row_range(i, 1).fill(0.f);
    }
    point_labels6.row_range(num_points - 1, 1).fill(0.f);

    point_labels.push_back(point_labels1);
    point_labels.push_back(point_labels2);
    point_labels.push_back(point_labels3);
    point_labels.push_back(point_labels4);
    point_labels.push_back(point_labels5);
    point_labels.push_back(point_labels6);

    return 0;
}
int SegmentAnything::MaskDecoder(const ncnn::Mat& image_embeddings, image_info_t& image_info, 
    const prompt_info_t& prompt_info, std::vector<sam_result_t>& sam_results, float pred_iou_thresh, float stability_score_thresh)
{
    
    
    std::vector<ncnn::Mat> point_labels;
    ncnn::Mat point_coords;
    embed_points(prompt_info, point_labels, point_coords);

    transform_coords(image_info, point_coords);

    ncnn::Mat mask_input, has_mask;
    embed_masks(prompt_info, mask_input, has_mask);

    ncnn::Extractor mask_decoder_ex = mask_decoder_net_.create_extractor();
    mask_decoder_ex.input("mask_input", mask_input);
    mask_decoder_ex.input("point_coords", point_coords);
    mask_decoder_ex.input("point_labels1", point_labels[0]);
    mask_decoder_ex.input("point_labels2", point_labels[1]);
    mask_decoder_ex.input("point_labels3", point_labels[2]);
    mask_decoder_ex.input("point_labels4", point_labels[3]);
    mask_decoder_ex.input("point_labels5", point_labels[4]);
    mask_decoder_ex.input("point_labels6", point_labels[5]);
    mask_decoder_ex.input("image_embeddings", image_embeddings);
    mask_decoder_ex.input("has_mask_input", has_mask);

    ncnn::Mat scores;
    mask_decoder_ex.extract("scores", scores);

    ncnn::Mat masks;
    mask_decoder_ex.extract("masks", masks);

    //postprocess
    std::vector<std::pair<float, int>> scores_vec;
    for (int i = 1; i < scores.w; ++i) {
    
    
        scores_vec.push_back(std::pair<float, int>(scores[i], i));
    }

    std::sort(scores_vec.begin(), scores_vec.end(), std::greater<std::pair<float, int>>());

    if (scores_vec[0].first > pred_iou_thresh) {
    
    
        sam_result_t sam_result;
        ncnn::Mat mask = masks.channel(scores_vec[0].second);
        cv::Mat cv_mask_32f = cv::Mat::zeros(cv::Size(mask.w, mask.h), CV_32F);
        std::copy((float*)mask.data, (float*)mask.data + mask.w * mask.h, (float*)cv_mask_32f.data);
        
        cv::Mat single_mask_32f;
        cv::resize(cv_mask_32f(cv::Rect(0, 0, image_info.pad_w, image_info.pad_h)), single_mask_32f, cv::Size(image_info.img_w,image_info.img_h), 0, 0, 1);

        float stable_score = calculate_stability_score(single_mask_32f);
        if (stable_score < stability_score_thresh)
            return -1;

        single_mask_32f = single_mask_32f > 0;
        single_mask_32f.convertTo(sam_result.mask, CV_8UC1, 1, 0);
        
        if (postprocess_mask(sam_result.mask, sam_result.box) < 0)
            return -1;

        sam_results.push_back(sam_result);
    }
    else {
    
    
        return -1;
    }

    return 0;
}
int SegmentAnything::postprocess_mask(cv::Mat& mask, cv::Rect& box)
{
    
    
    std::vector<std::vector<cv::Point>> contours;
    std::vector<cv::Vec4i> hierarchy;
    cv::findContours(mask.clone(), contours, hierarchy, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
    if(contours.size() == 0)
        return -1;

    if (contours.size() > 1) {
    
    
        float max_area = 0;
        int max_idx = 0;
        std::vector<std::pair<float,int>> areas;
        for (size_t i = 0; i < contours.size(); ++i) {
    
    
            float area = cv::contourArea(contours[i]);
            if (area > max_area) {
    
    
                max_idx = i;
                max_area = area;
            }
            areas.push_back(std::pair<float,int>(area,i));
        }
        
        for (size_t i = 0; i < areas.size(); ++i) {
    
    
            //if (i == max_idx)
            //    continue;
            //else {
    
    
            //    cv::drawContours(mask, contours, i, cv::Scalar(0), -1);
            //}
            if(areas[i].first < max_area * 0.3){
    
    
                cv::drawContours(mask, contours, i, cv::Scalar(0), -1);
            }
            else{
    
    
                box = box | cv::boundingRect(contours[i]);
            }
        }
    }
    else {
    
    
        box = cv::boundingRect(contours[0]);
    }
    return 0;
}
float SegmentAnything::calculate_stability_score(cv::Mat& mask, float mask_threshold, float stable_score_offset)
{
    
    
    float intersections = (float)cv::countNonZero(mask > (mask_threshold + stable_score_offset));
    float unions = (float)cv::countNonZero(mask > (mask_threshold - stable_score_offset));
    
    return intersections / unions;
}
}

2. 交互方法

分割交互方式中有好四种,开放式点,可以多个点组合, 矩形框, 分割一切,还有文字提示这几种方式。但文字提示效果不太稳定,C++代码没有实现这一部分。

开放式点
点击要分割目标的中间,分割包含该点的物体,会按最小分割的结果展示出来,如果想分割的物体大于展示的结果,可以在物体的其他部分也点击下:
在这里插入图片描述
选择矩形框
使用鼠标拖动在目标选择,分割目标:
在这里插入图片描述

分割一切
将图片中所有物体的分割都展示出来:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/matt45m/article/details/132137836