目标检测网络中大框包含小框的问题,非极大值抑制算法(Non-Maximum Suppression,NMS)优化

目标检测系列文章


前言

NMS一直是Object Detection很难绕开的步骤,本人的目标检测网络中遇到一个目标检测输出大框包含小框的问题如下图,(大概率是数据没搞好,但是数据多很难查出是哪写数据)。于是对nms做了一个修改。此方法可直接加在任何目标检测网络原版nms后面
在这里插入图片描述
这种情况下原版nms是无法过滤掉目标内的小框的。

一、NMS介绍

NMS 算法的大致过程:每轮选取置信度最大的 Bounding Box(简称 BBox,有时也会看到用 Pc,Possible Candidates 代替讲解的) ,接着关注所有剩下的 BBox 中与选取的 BBox 有着高重叠(IOU)的,它们将在这一轮被抑制。这一轮选取的 BBox 会被保留输出,且不会在下一轮出现。接着开始下一轮,重复上述过程:选取置信度最大 BBox ,抑制高 IOU BBox。
NMS 算法流程:这是一般文章中介绍的 NMS,比较难懂。但实际上 NMS 的实现反而简单很多。
在这里插入图片描述

二、NMS优化小框代码实现

这种方法必须外面加一层循环每类单独进行nms,才不会吧不同类别的框过滤掉

#include <iostream>
#include <vector>
#include <opencv2/opencv.hpp>

using namespace std;
using namespace cv;

static void sort(int n, const vector<float> x, vector<int> indices)
{
    
    
	int i, j;
	for (i = 0; i < n; i++)
		for (j = i + 1; j < n; j++)
		{
    
    
			if (x[indices[j]] > x[indices[i]])
			{
    
    
				//float x_tmp = x[i];
				int index_tmp = indices[i];
				//x[i] = x[j];
				indices[i] = indices[j];
				//x[j] = x_tmp;
				indices[j] = index_tmp;
			}
		}
}

int nonMaximumSuppression(int numBoxes, const vector<Point> points, const vector<Point> oppositePoints,
	const vector<float> score, float overlapThreshold, int& numBoxesOut, vector<Point>& pointsOut,
	vector<Point>& oppositePointsOut, vector<float> scoreOut)
{
    
    
	// 实现检测出的矩形窗口的非极大值抑制nms
	// numBoxes:窗口数目// points:窗口左上角坐标点// oppositePoints:窗口右下角坐标点// score:窗口得分
	// overlapThreshold:重叠阈值控制// numBoxesOut:输出窗口数目// pointsOut:输出窗口左上角坐标点
	// oppositePointsOut:输出窗口右下角坐标点// scoreOut:输出窗口得分
	int i, j, index;
	vector<float> box_area(numBoxes);				// 定义窗口面积变量并分配空间 
	vector<int> indices(numBoxes);					// 定义窗口索引并分配空间 
	vector<int> is_suppressed(numBoxes);			// 定义是否抑制表标志并分配空间 
													// 初始化indices、is_supperssed、box_area信息 
	for (i = 0; i < numBoxes; i++)
	{
    
    
		indices[i] = i;
		is_suppressed[i] = 0;
		box_area[i] = (float)((oppositePoints[i].x - points[i].x + 1) *(oppositePoints[i].y - points[i].y + 1));
	}
	// 对输入窗口按照分数比值进行排序,排序后的编号放在indices中 
	sort(numBoxes, score, indices);
	for (i = 0; i < numBoxes; i++)                // 循环所有窗口 
	{
    
    
		if (!is_suppressed[indices[i]])           // 判断窗口是否被抑制 
		{
    
    
			for (j = i + 1; j < numBoxes; j++)    // 循环当前窗口之后的窗口 
			{
    
    
				if (!is_suppressed[indices[j]])   // 判断窗口是否被抑制 
				{
    
    
					int x1max = max(points[indices[i]].x, points[indices[j]].x);                     // 求两个窗口左上角x坐标最大值 
					int x2min = min(oppositePoints[indices[i]].x, oppositePoints[indices[j]].x);     // 求两个窗口右下角x坐标最小值 
					int y1max = max(points[indices[i]].y, points[indices[j]].y);                     // 求两个窗口左上角y坐标最大值 
					int y2min = min(oppositePoints[indices[i]].y, oppositePoints[indices[j]].y);     // 求两个窗口右下角y坐标最小值 
					int overlapWidth = x2min - x1max;     // 计算两矩形重叠的宽度 
					int overlapHeight = y2min - y1max;     // 计算两矩形重叠的高度 
					float overlap = overlapWidth * overlapHeight;
					float area_i = (oppositePoints[indices[i]].x - points[indices[i]].x) * (oppositePoints[indices[i]].y - points[indices[i]].y);
					float area_j = (oppositePoints[indices[j]].x - points[indices[j]].x) * (oppositePoints[indices[j]].y - points[indices[j]].y);
					if (overlapWidth > 0 && overlapHeight > 0)
					{
    
    
						//float overlapPart = (overlapWidth * overlapHeight) / box_area[indices[j]];    // 计算重叠的比率 
						float overlapPart = overlap / (area_i + area_j - overlap);
						//if (overlapPart > overlapThreshold)   // 判断重叠比率是否超过重叠阈值 
						//{
    
    
						//	is_suppressed[indices[j]] = 1;     // 将窗口j标记为抑制 
						//}
						/*if (overlap == min(area_i, area_j))  //只会过滤掉大框内部小框
						{
							is_suppressed[indices[j]] = 1;   
						}*/
						if (area_i > area_j)   //能过滤掉大框内部小框,或者是上下超界的小框
						{
    
    
							if (overlapWidth == (oppositePoints[indices[j]].x - points[indices[j]].x))
							{
    
    
								is_suppressed[indices[j]] = 1;
							}
						}
						if (area_j > area_i)   //能过滤掉大框内部小框,或者是上下超界的小框
						{
    
    
							if (overlapWidth == (oppositePoints[indices[i]].x - points[indices[i]].x))
							{
    
    
								is_suppressed[indices[j]] = 1;
							}
						}

					}
				}
			}
		}
	}

	numBoxesOut = 0;    // 初始化输出窗口数目0 
	for (i = 0; i < numBoxes; i++)
	{
    
    
		if (!is_suppressed[i]) numBoxesOut++;    // 统计输出窗口数目 
	}
	index = 0;
	for (i = 0; i < numBoxes; i++)            // 遍历所有输入窗口 
	{
    
    
		if (!is_suppressed[indices[i]])       // 将未发生抑制的窗口信息保存到输出信息中 
		{
    
    
			pointsOut.push_back(Point(points[indices[i]].x, points[indices[i]].y));
			oppositePointsOut.push_back(Point(oppositePoints[indices[i]].x, oppositePoints[indices[i]].y));
			scoreOut.push_back(score[indices[i]]);
			index++;
		}

	}

	return true;
}

int main()
{
    
    
	Mat image = Mat::zeros(600, 600, CV_8UC3);
	int numBoxes = 4;
	vector<Point> points(numBoxes);
	vector<Point> oppositePoints(numBoxes);
	vector<float> score(numBoxes);

	points[0] = Point(100, 100); oppositePoints[0] = Point(500, 500); score[0] = 0.99;
	points[1] = Point(300, 110); oppositePoints[1] = Point(350, 480); score[1] = 0.9;
	points[2] = Point(220, 220); oppositePoints[2] = Point(550, 550); score[2] = 0.98;
	points[3] = Point(120, 90); oppositePoints[3] = Point(200, 510); score[3] = 0.98;

	float overlapThreshold = 0.3;
	int numBoxesOut;
	vector<Point> pointsOut;
	vector<Point> oppositePointsOut;
	vector<float> scoreOut;

	nonMaximumSuppression(numBoxes, points, oppositePoints, score, overlapThreshold, numBoxesOut, pointsOut, oppositePointsOut, scoreOut);
	for (int i = 0; i<numBoxes; i++)
	{
    
    
		rectangle(image, points[i], oppositePoints[i], Scalar(0, 255, 255), 6);
		char text[20];
		//_cprintf(text, "%f", score[i]);
		//putText(image, score[i].c_str(), points[i], FONT_HERSHEY_COMPLEX, 1, Scalar(0, 255, 255));
	}
	cout << numBoxesOut << endl;
	for (int i = 0; i<numBoxesOut; i++)
	{
    
    
		rectangle(image, pointsOut[i], oppositePointsOut[i], Scalar(0, 0, 255), 2);
	}

	imshow("result", image);

	waitKey();
	return 0;
}

如下图,没有画红线框极为已经过滤掉了
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/zengwubbb/article/details/119987499