机器学习聚类算法:DBSCAN 对鸢尾花数据分类 C++实现

C++实现DBSCAN算法

该算法的原理主要是先找出每个数据邻域内并且数据数量大于给给定值的数据作为核心数据,然后从任一核心数据触发找到所有的密度可达点,将这些密度可达点设置为一个簇,直到所有的核心数据被遍历过为止,数据集用的是python的鸢尾花数据,接下来直接给出代码:

//DataPoint.h  储存每个数据点
#ifndef _DATA_POINT_
#define  _DATA_POIYT_

#include <string>
#include <iostream>
#include<vector>
using namespace std;

class DataPoint
{
public:
	DataPoint() {}
	DataPoint(unsigned long dpID, vector<double> data_insert, bool isKey);
	bool GetKey(); //获取是否为核心对象
	bool GetIsVisit();//获取是否访问过
	unsigned long GetDpId();//获取数据id
	long GetClusterId(); //获取聚类id
	long GetDataNum(); //获取数据数量
	string GetName(); //获取名字

	void SetKey(bool s);  //设置核心对象
	void SetIsVisit(bool s);  //设置是否访问过
	void SetDpid(unsigned long s); //设置数据点id
	void SetClusterId(long s); //设置聚类id
	void SetDataNum(long s); //设置数据长度
	vector<double> GetInsertData(); //获取数据

	void insert_data(vector<double> v);//加入数据
	void insert_name(string n); //加入花卉名字

	vector<int> InRangeData2; //在邻域范围内的数据

	~DataPoint() {}

private:
	unsigned long dpID;                //数据点ID
	long clusterId;                    //所属聚类ID
	vector<double> data_insert;        //所保存的数据
	bool isKey;                        //是否核心对象
	bool visited;                    //是否已访问
	string name;                     //花卉名字
	long  Data_Num;                 //数据长度
};
#endif // !_DATA_POINT_
//DataPoint.cpp
#include "DataPoint.h"
DataPoint::DataPoint(unsigned long dpID,vector<double> Data, bool isKey):isKey(isKey),dpID(dpID)
{
	for (int i = 0; i < Data.size(); i++)
	{
		this->data_insert.push_back(Data[i]);
	}
}

long DataPoint::GetClusterId()
{
	return this->clusterId;
}
unsigned long DataPoint::GetDpId()
{
	return this->dpID;
}
bool DataPoint::GetKey()
{
	return this->isKey;
}
bool DataPoint::GetIsVisit()
{
	return this->visited;
}
long DataPoint::GetDataNum()
{
	return this->Data_Num;
}
vector<double> DataPoint::GetInsertData()
{
	return this->data_insert;
}
string DataPoint::GetName()
{
	return this->name;
}

void DataPoint::insert_data(vector<double> v)
{
	for (int i = 0; i < v.size(); i++)
	{
		(this->data_insert).push_back(v[i]);
	}
}
void DataPoint::insert_name(string n)
{
	this->name = n;
}


void DataPoint::SetKey(bool s)
{
	this->isKey = s;
}
void DataPoint::SetIsVisit(bool s)
{
	this->visited = s;
}
void DataPoint::SetClusterId(long s)
{
	this->clusterId = s;
}
void DataPoint::SetDpid(unsigned long s)
{
	this->dpID = s;
}
void DataPoint::SetDataNum(long s)
{
	this->Data_Num = s;
}
//DBSCAN.h
#ifndef _DBSCAN_H_
#define  _DBSCAN_H_
#include <iostream>
#include<fstream>
#include<sstream>
#include<string>
#include<vector>
#include"DataPoint.h"

using namespace std;
//将string类型转换为数字
template <class Type>
Type stringToNum(const string& str)
{
	istringstream iss(str);
	Type num;
	iss >> num;
	return num;
}

class DBSCAN
{
public:
	DBSCAN(double radius, unsigned int minPTs);
	~DBSCAN() {}
	void GetData(); //提取数据
	void StartDBSCAN(); //开始算法
	void print();//打印信息
private:
	
	double GetDistance(DataPoint &point,DataPoint &point2); //获取两个数据之间的距离
	void DFS_Find_Cluster(unsigned long dpID,int ClusterNum); //DFS搜索密度可达数据
	vector<DataPoint> DataBase; //保存所有数据
	vector<DataPoint> CoreData;//保存核心数据
	double radius;                    //半径
	unsigned int dataNum;            //数据数量
	unsigned int minPTs;            //邻域最小数据个数
	int ClusterNum = 1;
};
#endif // !_DBSCAN_H_
//DBSCAN.cpp
#include "DBSCAN.h"
//得到两个向量之间的距离
double DBSCAN::GetDistance(DataPoint &point, DataPoint &point2)
{
	double sum = 0;
	for (int i = 0; i < point.GetDataNum(); i++)
	{
		sum += (point.GetInsertData()[i] - point2.GetInsertData()[i])*(point.GetInsertData()[i] - point2.GetInsertData()[i]);
	}
	double result = sqrt(sum);
	return result;
}
//获取数据
void DBSCAN::GetData()
{
	ifstream file;
	string line;
	file.open("iris.csv", ios::in);
	if (file.fail()) {
		cout << "文件打开失败" << endl;
		return;
	}
	while (getline(file, line))
	{
		stringstream ss(line);
		string str;
		vector<string>v;
		vector<double> d;
		DataPoint temp;
		while (getline(ss, str, ','))
		{
			v.push_back(str);
		}
		for (int i = 1; i < 5; i++)
		{
			d.push_back(stringToNum<double>(v[i]));
		}
		temp.insert_data(d);
		temp.insert_name(v[5]);
		temp.SetDpid(dataNum);
		temp.SetClusterId(-1);
		temp.SetIsVisit(false);
		temp.SetKey(false);
		temp.SetDataNum(4);
		DataBase.push_back(temp);
		dataNum++;
	}
	file.close();
	//取样本数据集中距离不大于radius的数据并且这些数据数量大于minPTs的数据
	for (int i = 0; i < dataNum; i++)
	{
		DataPoint &temp = DataBase[i];
		for (int j = 0; j < dataNum; j++)
		{
			double dis = GetDistance(temp, DataBase[j]);
			if (dis <= radius && DataBase[j].GetDpId() != i)
			{
				temp.InRangeData2.push_back(j);
			}
		}
		if (temp.InRangeData2.size() >= minPTs)
		{
			temp.SetKey(true);
			CoreData.push_back(temp);
		}
		else
		{
			temp.SetKey(false);
		}
	}
}

void DBSCAN::StartDBSCAN()
{
	for (int i = 0; i < dataNum; i++)
	{
		if (DataBase[i].GetKey() == true && DataBase[i].GetIsVisit() == false)
		{
			DataBase[i].SetIsVisit(true);
			DataBase[i].SetClusterId(ClusterNum);
			DFS_Find_Cluster(DataBase[i].GetDpId(), ClusterNum); //DFS搜索数据
			ClusterNum++;
		}
	}
}
void DBSCAN::DFS_Find_Cluster(unsigned long dpID, int ClusterNum)
{
	DataPoint &point = DataBase[dpID];
	if (!point.GetKey()) return;
	for (int i = 0; i < point.InRangeData2.size(); i++)
	{
		int temp = point.InRangeData2[i];
		if (DataBase[temp].GetIsVisit() == false)
		{
			DataBase[temp].SetIsVisit(true);
			DataBase[temp].SetClusterId(ClusterNum);
			if (DataBase[temp].GetKey() == true)
			{
				DFS_Find_Cluster(DataBase[temp].GetDpId(), ClusterNum); //递归搜索
			}
		}
	}
}
void DBSCAN::print()
{
	cout << ClusterNum-1 << endl;
	for (int i = 1; i < ClusterNum; i++)
	{
		cout << "聚类" << i << endl;
		for (int j = 0; j < dataNum; j++)
		{
			if (DataBase[j].GetClusterId() == i)
			{
				cout << "数据编号:" << j << " ";
				for (int s = 0; s < 4; s++)
				{
					cout << DataBase[j].GetInsertData()[s] << " ";
				}
				cout << DataBase[j].GetName();
				cout << endl;
			}
		}
	}
}
DBSCAN::DBSCAN(double radius, unsigned int minPTs):radius(radius),minPTs(minPTs)
{

}
//main.cpp
#include "DBSCAN.h"
using namespace std;
int main()
{
	DBSCAN d(0.4, 2);
	d.GetData();
	d.StartDBSCAN();
	d.print();
}

因为DBSACAN算法只对核心数据经行遍历搜索密度可达点,噪声点并不会存在于结果中,所以最终可能会有一些数据缺失。

发布了5 篇原创文章 · 获赞 0 · 访问量 84

猜你喜欢

转载自blog.csdn.net/weixin_43791996/article/details/105586289