C++ 复现朴素贝叶斯算法

C++复现朴素贝叶斯代码

#include <algorithm>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

#include <windows.h>
using namespace std;

// Value of pi used by the Gaussian density in BeiYeSi::continuous_p.
// A typed constant instead of a macro: it respects scope, cannot silently
// rewrite unrelated identifiers spelled `pi`, and carries full double
// precision (the former literal 3.14159265 stopped after 8 decimals).
constexpr double pi = 3.14159265358979323846;
/// Parse a value of type `Type` from the start of `str` via stream extraction.
/// @param str  Text such as "0.697"; leading whitespace is skipped.
/// @return     The parsed value, or a value-initialized `Type` (0 for
///             arithmetic types) when nothing can be extracted, e.g. for an
///             empty or whitespace-only string.
template <class Type>
Type stringToNum(const string& str)
{
	istringstream iss(str);
	// Value-initialize: when the stream holds no characters, operator>> never
	// assigns to its argument, and the old `Type num;` was returned uninitialized.
	Type num{};
	iss >> num;
	return num;
}

/// Naive Bayes classifier ("BeiYeSi" = Bayes) for a fixed tab-separated
/// dataset: GetData() reads "a.txt", treating columns 1-6 as discrete
/// attributes, columns 7-8 as continuous attributes and the last column as
/// the class label ("是" = positive, "否" = negative).
class BeiYeSi
{
public:
	BeiYeSi() = default;

	/// Read the training file into the per-class containers below.
	void GetData();
	/// Count how often each distinct string occurs in v.
	map<string, int> count_list(vector<string> v);
	/// Turn per-value counts into frequencies count / num.
	map<string, double> discrete_p(map<string, int> mmp, int num);
	/// Arithmetic mean of v.
	double mu_of_list(vector<double> v);
	/// Sample variance (n - 1 denominator) of v around the given mean.
	double var_of_list(vector<double> v, double mu);
	/// Load the data, then estimate priors and class-conditional parameters.
	void train();
	/// Classify one sample (discrete values, then numeric strings) and print the verdict.
	void predict(vector<string> v);
	/// Gaussian density of num under N(mu, var).
	double continuous_p(double num, double mu, double var);

private:
	// Class priors: [0] = P(positive), [1] = P(negative).
	vector<double> class_p;

	// Training rows split by label: continuous part (columns 7-8) and
	// discrete part (columns 1-6).
	vector<vector<double>> good_con;
	vector<vector<double>> bad_con;
	vector<vector<string>> good_discrete;
	vector<vector<string>> bad_discrete;
	// Split rows collected by GetData().
	vector<vector<string>> AllData;

	// One map per discrete attribute: attribute value -> P(value | class).
	vector<map<string, double>> discrete_attris_with_good_p;
	vector<map<string, double>> discrete_attris_with_bad_p;

	// Stream GetData() uses to read the training file.
	ifstream read_csv;

	// Sample counts used for the priors (total, positive, negative).
	// Initialized so they never hold indeterminate values before train().
	double len = 0.0;
	double len_pos = 0.0;
	double len_neg = 0.0;

	// Per continuous attribute, per class: mean and sample variance.
	vector<double> good_mus, good_vars, bad_mus, bad_vars;
};
/// Count how many times each distinct string occurs in `v`.
/// @param v  Attribute values (taken by value to match the class declaration).
/// @return   Map from each distinct value to its number of occurrences.
map<string,int> BeiYeSi::count_list(vector<string> v)
{
	map<string, int> counts;
	// operator[] value-initializes a missing count to 0, so one increment
	// covers both "first occurrence" and "seen before" without the separate
	// find() lookup the previous version performed.
	for (const string& value : v)
	{
		++counts[value];
	}
	return counts;
}
/// Load the training set from "a.txt" (tab-separated) into the per-class
/// containers, replacing anything loaded before. Expected columns:
///   0     : sample id (ignored)
///   1..6  : discrete attributes   -> good_discrete / bad_discrete
///   7..8  : continuous attributes -> good_con / bad_con
///   last  : label, "是" (positive) or "否" (negative)
/// Rows that are too short or carry any other label (e.g. a header line) are
/// skipped and not added to AllData. Prints "open files error" and returns
/// if the file cannot be opened.
void BeiYeSi::GetData()
{
	const string positive = "是";
	const string negative = "否";
	const size_t kFirstDiscrete = 1;
	const size_t kLastDiscrete = 6;
	const size_t kFirstContinuous = 7;
	const size_t kLastContinuous = 8;
	// Id + attributes + label: anything shorter cannot be indexed safely.
	const size_t kMinFields = kLastContinuous + 2;

	// open() fails on a stream that is still open, so close/reset first to
	// allow GetData() to be called more than once.
	if (read_csv.is_open())
	{
		read_csv.close();
	}
	read_csv.clear();
	read_csv.open("a.txt", ios::in);
	if (read_csv.fail())
	{
		cout << "open files error" << endl;
		return;
	}

	// Replace, rather than append to, previously loaded data.
	AllData.clear();
	good_discrete.clear();
	bad_discrete.clear();
	good_con.clear();
	bad_con.clear();

	string line;
	while (getline(read_csv, line))
	{
		// Tolerate Windows (CRLF) line endings; otherwise the label would be
		// "是\r" and match neither class.
		if (!line.empty() && line.back() == '\r')
		{
			line.pop_back();
		}

		vector<string> fields;
		stringstream ss(line);
		string field;
		while (getline(ss, field, '\t'))
		{
			fields.push_back(field);
		}
		// Guard before indexing: blank or truncated lines previously caused
		// out-of-bounds reads.
		if (fields.size() < kMinFields)
		{
			continue;
		}

		const string& label = fields.back();
		const bool is_positive = (label == positive);
		if (!is_positive && label != negative)
		{
			continue;  // header row or unknown label
		}
		AllData.push_back(fields);

		vector<string> discrete(fields.begin() + kFirstDiscrete,
		                        fields.begin() + kLastDiscrete + 1);
		vector<double> continuous;
		for (size_t j = kFirstContinuous; j <= kLastContinuous; ++j)
		{
			continuous.push_back(stringToNum<double>(fields[j]));
		}

		if (is_positive)
		{
			good_discrete.push_back(discrete);
			good_con.push_back(continuous);
		}
		else
		{
			bad_discrete.push_back(discrete);
			bad_con.push_back(continuous);
		}
	}
	read_csv.close();
}
/// Arithmetic mean of the values in `v`: their sum divided by the count.
/// An empty `v` gives 0.0 / 0, i.e. NaN.
double BeiYeSi::mu_of_list(vector<double> v)
{
	// accumulate folds left-to-right from 0.0, the same order as a manual loop.
	const double total = accumulate(v.begin(), v.end(), 0.0);
	return total / v.size();
}
/// Unbiased sample variance of `v` around the supplied mean `mu`:
/// sum((x - mu)^2) / (n - 1).
/// @param v   Sample values.
/// @param mu  Mean of `v` (as produced by mu_of_list).
/// @return    The sample variance, or 0.0 when `v` has fewer than two
///            elements. In that case n - 1 is 0: for a single element the
///            old code computed 0/0 = NaN, and for an empty vector it relied
///            on unsigned wraparound.
double BeiYeSi::var_of_list(vector<double> v ,double mu)
{
	if (v.size() < 2)
	{
		return 0.0;
	}
	double sum_sq = 0.0;
	for (const double x : v)
	{
		const double d = x - mu;
		sum_sq += d * d;
	}
	return sum_sq / static_cast<double>(v.size() - 1);
}

/// Convert per-value occurrence counts into relative frequencies.
/// @param mmp  Map from attribute value to its count within one class.
/// @param num  Number of samples in that class (the denominator).
/// @return     Map from attribute value to count / num; empty when `mmp` is
///             empty or `num` is not positive (no division by zero).
map<string, double> BeiYeSi::discrete_p(map<string, int> mmp,int num)
{
	map<string, double> frequencies;
	if (num <= 0)
	{
		return frequencies;
	}
	for (const auto& entry : mmp)
	{
		// Divide in double precision; the old (float) cast truncated the
		// ratio to float precision (~7 significant digits) before storing
		// it as double.
		frequencies[entry.first] = static_cast<double>(entry.second) / num;
	}
	return frequencies;
}
void BeiYeSi::train()
{
	GetData();
	len = AllData.size();
	len_pos = good_discrete.size();
	len_neg = bad_discrete.size();

	class_p.push_back(double(len_pos / len));
	class_p.push_back(double(len_neg / len));

	for (int i = 0; i < 6;i++)
	{
		int good_var, good_mu;
		int bad_var, bad_mu;


		vector<string> attr_with_bad;
		vector<string> attr_with_good;

		map<string, int> unique_with_good;
		map<string, int> unique_with_bad;

		for (int j = 0; j < good_discrete.size(); j++)
		{
			attr_with_good.push_back(good_discrete[j][i]);
		}
		for (int j = 0; j < bad_discrete.size(); j++)
		{
			attr_with_bad.push_back(bad_discrete[j][i]);
		}

		unique_with_good = count_list(attr_with_good);
		unique_with_bad = count_list(attr_with_bad);

		discrete_attris_with_good_p.push_back(discrete_p(unique_with_good,good_discrete.size()));
		discrete_attris_with_bad_p.push_back(discrete_p(unique_with_bad,bad_discrete.size()));
	}

	

	for (int i = 0; i < 2; i++)
	{
		vector<double> attr_with_bad;
		vector<double> attr_with_good;
		
		for (int j = 0; j < bad_con.size(); j++)
		{
			
			attr_with_bad.push_back(bad_con[j][i]);
		}
		for (int j = 0; j < good_con.size(); j++)
		{
			attr_with_good.push_back(good_con[j][i]);
		}

		double good_mu = mu_of_list(attr_with_good);
		double bad_mu = mu_of_list(attr_with_bad);
		double good_var = var_of_list(attr_with_good, good_mu);
		double bad_var = var_of_list(attr_with_bad, bad_mu);


		good_mus .push_back(good_mu);
		bad_mus .push_back(bad_mu);
		good_vars.push_back(good_var);
		bad_vars.push_back(bad_var);
	}
}
double BeiYeSi::continuous_p(double num, double mu, double var)
{
	double  p = (1.0 / (sqrt(2 * pi) * sqrt(var))) * exp(-(((num - mu)*(num - mu)) / (2 * var)));
	return p;
}
/// Classify one sample and print "(p_good , p_bad) 是" or "(p_good , p_bad) 否".
/// @param v  Attribute values in training-column order: the 6 discrete values
///           first, then the 2 continuous values as decimal strings,
///           e.g. {"青绿", ..., "0.697", "0.460"}.
/// The printed scores are unnormalized: prior times the product of the
/// per-attribute likelihoods. Ties go to "否". A discrete value never seen
/// with a class contributes probability 0 (no Laplace smoothing is applied).
void BeiYeSi::predict(vector<string> v)
{
	const size_t kDiscreteAttrs = 6;
	const size_t kContinuousAttrs = 2;

	// Refuse to index past the end of the model or the sample: calling
	// predict() before train(), or with a short sample, read out of bounds.
	if (class_p.size() < 2
		|| discrete_attris_with_good_p.size() < kDiscreteAttrs
		|| discrete_attris_with_bad_p.size() < kDiscreteAttrs
		|| good_mus.size() < kContinuousAttrs || good_vars.size() < kContinuousAttrs
		|| bad_mus.size() < kContinuousAttrs || bad_vars.size() < kContinuousAttrs
		|| v.size() < kDiscreteAttrs + kContinuousAttrs)
	{
		cout << "predict: model not trained or sample has too few attributes" << endl;
		return;
	}

	// Look up P(value | class) without copying the map or inserting missing
	// keys; an unseen value yields 0, as the old operator[] lookup did.
	auto prob_of = [](const map<string, double>& table, const string& value) -> double
	{
		const map<string, double>::const_iterator it = table.find(value);
		return it == table.end() ? 0.0 : it->second;
	};

	double p_good = class_p[0];
	double p_bad = class_p[1];
	for (size_t i = 0; i < kDiscreteAttrs; i++)
	{
		p_good *= prob_of(discrete_attris_with_good_p[i], v[i]);
		p_bad *= prob_of(discrete_attris_with_bad_p[i], v[i]);
	}
	for (size_t i = 0; i < kContinuousAttrs; i++)
	{
		const double num = stringToNum<double>(v[kDiscreteAttrs + i]);
		p_good *= continuous_p(num, good_mus[i], good_vars[i]);
		p_bad *= continuous_p(num, bad_mus[i], bad_vars[i]);
	}

	const char* verdict = (p_good > p_bad) ? "是" : "否";
	cout << "(" << p_good << " , " << p_bad << ") " << verdict << endl;
}
/// Demo driver: train on "a.txt", classify one hard-coded sample, then print
/// the elapsed wall-clock time in milliseconds.
int main()
{
	const chrono::steady_clock::time_point start = chrono::steady_clock::now();

	BeiYeSi F;
	// Same order as the training columns: six discrete values, then the two
	// continuous values as strings.
	const vector<string> v = {"青绿", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", "0.697", "0.460"};
	F.train();
	F.predict(v);

	const chrono::steady_clock::time_point end = chrono::steady_clock::now();
	// Elapsed time is end - start. The old `Start - End` on an unsigned DWORD
	// wrapped around to a huge number. steady_clock is monotonic and portable,
	// unlike GetTickCount.
	cout << chrono::duration_cast<chrono::milliseconds>(end - start).count() << " ms" << endl;
	return 0;
}

可能有些地方写得还不够好，代码也可以更简洁，之后会阅读《More Effective C++》来改进。数据集是《机器学习》（周志华著，俗称“西瓜书”）中的西瓜数据集 3.0，程序从同目录下以制表符分隔的 a.txt 文件读取。

发布了5 篇原创文章 · 获赞 0 · 访问量 84

猜你喜欢

转载自blog.csdn.net/weixin_43791996/article/details/104863744