Implementing Decision Trees and Random Forests in OpenCV

Contents

1. Decision Trees

2. Random Forests


1. Decision Trees

Points to note

     Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data", // file name
                                                      0,   // number of header lines to skip
                                                      0,   // start of the response-column interval
                                                      1,   // end of the interval: [0,1) holds the response column
                                                      "cat[0-22]", // columns 0-22 are all categorical
                                                      ',', // fields are separated by ","
                                                      '?');// missing values are marked with "?"

1. Variable types come in two kinds, cat (categorical) and ord (ordered/numerical); see this definition of statistical data types for details:

  https://zhidao.baidu.com/question/1964314134743418500.html

2. The two response-column indices (the 0 and 1 above) define a half-open interval by default: closed on the left, open on the right, i.e. [0, 1);

3. When splitting the training and test sets, the order of the samples can strongly affect the resulting decision tree:

	data_set->setTrainTestSplitRatio(0.90, false);

4. Setting prior weights can be understood as giving the classifier relatively higher accuracy on one particular class (note the matrix initialization method I use here):

	float _priors[] = { 1.0f, 10.0f };
	Mat priors(1, 2, CV_32F, _priors);
	dtree->setPriors(priors); // set a weight for each response class

5. In OpenCV 3.0 and later, both the decision tree and the random forest here are built through the RTrees class. Compared with the older DTrees, RTrees can handle missing values in the data set; the only modeling difference is that when growing a random forest you must set a termination criterion for tree generation, which defaults to 100 trees:

 forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER
                               + TermCriteria::EPS, 100, 0.01)); // termination criterion for the forest

6. Saving and loading the model:

	dtree->save("dtree_01.xml"); // save
	Ptr<RTrees> dtree = RTrees::load("dtree_01.xml"); // load the trained model

7. If you want to use a loaded model for classification or regression, be sure to build the training, test, and validation sets by hand. The reason is that if your data set contains letter categories, OpenCV by default converts them to ASCII codes and normalizes them; in that case, if you also load the validation set in the default way, the program will inevitably crash when you call predict.
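To make the encoding issue concrete, here is a minimal plain-C++ sketch of the kind of mapping involved. The exact mapping OpenCV applies is internal, so the `CategoryEncoder` helper below is an assumption for illustration only: the point is that a hand-built validation sample must use the same numeric codes the training data was encoded with, or predict will see values from a different feature space.

```cpp
#include <map>

// Hypothetical encoder: assigns each category letter an integer code in
// order of first appearance, the way a CSV loader might do internally.
struct CategoryEncoder {
    std::map<char, float> codes;
    float encode(char c) {
        auto it = codes.find(c);
        if (it == codes.end()) {
            float code = static_cast<float>(codes.size());
            codes[c] = code; // first appearance gets the next free code
            return code;
        }
        return it->second;   // repeated letters reuse their code
    }
};
```

A sample encoded with a fresh encoder instance is meaningless to a model trained under a different encoding, which is why the validation set must reuse the training-time codes.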

Training code for the decision tree:

The poisonous-mushroom data set: https://github.com/oreillymedia/Learning-OpenCV-3_examples/tree/master/mushroom

#include<iostream>
#include<opencv2/opencv.hpp>

using namespace cv;
using namespace ml;
using namespace std;

//1. Create a pointer to the training-data object
	Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data", // file name
	                                                 0,   // number of header lines to skip
	                                                 0,   // start of the response-column interval
	                                                 1,   // end of the interval: [0,1) holds the response column
	                                                 "cat[0-22]", // columns 0-22 are all categorical
	                                                 ',', // fields are separated by ","
	                                                 '?');// missing values are marked with "?"

//2. Verify that the data was read correctly
	int n_samples = data_set->getNSamples();
	if (n_samples == 0)
	{
		cerr << "Could not read file: mushroom.data" << endl;
		exit(-1);
	}
	else
	{
		cout << "Read " << n_samples << " samples from mushroom.data" << endl;
	}

	//3. Split into training and test sets at a 9:1 ratio, without shuffling the samples
	data_set->setTrainTestSplitRatio(0.90, false);
	int n_train_samples = data_set->getNTrainSamples();
	int n_test_samples = data_set->getNTestSamples();
	Mat trainMat = data_set->getTrainSamples();
	//4. Decision tree
	//4.1 Create
	Ptr<RTrees> dtree = RTrees::create();
	//4.2 Set parameters
	dtree->setMaxDepth(8);              // maximum tree depth
	dtree->setMinSampleCount(10);       // minimum number of samples in a node
	dtree->setRegressionAccuracy(0.01f);
	dtree->setUseSurrogates(false);     // whether to use surrogate splits to handle missing data
	dtree->setMaxCategories(15);        // maximum number of pre-clustered categories
	dtree->setCVFolds(0);               // if CVFolds > 1, prune the tree with k-fold cross-validation, k = CVFolds
	dtree->setUse1SERule(true);         // true means harsher pruning: a smaller but less
	                                    // accurate tree, which counters overfitting
	dtree->setTruncatePrunedTree(true); // whether to physically remove the pruned branches
	float _priors[] = { 1.0f, 10.0f };
	Mat priors(1, 2, CV_32F, _priors);
	dtree->setPriors(priors);           // set a weight for each response class
	//4.3 Train
	dtree->train(data_set);
	//4.4 Compute the training error
	Mat results;
	float train_performance = dtree->calcError(data_set,
		false, // true: use the test set; false: use the training set
		results);
	//5. Analyze the results on the training set
	vector<String> names;
	data_set->getNames(names);
	Mat flags = data_set->getVarSymbolFlags();
	Mat expected_responses = data_set->getResponses();
	int good = 0, bad = 0, total = 0;
	for (int i = 0; i < data_set->getNTrainSamples(); ++i)
	{
		float received = results.at<float>(i, 0);
		float expected = expected_responses.at<float>(i, 0);
		String r_str = names[(int)received];
		String e_str = names[(int)expected];
		if (received != expected)
		{
			bad++;
			cout << "Expected: " << e_str << " ,got: " << r_str << endl;
		}
		else good++;
		total++;
	}
	cout << "Correct answers: " << (100.f * good / total) << "%" << endl;
	cout << "Incorrect answers: " << (100.f * bad / total) << "%" << endl;

	//6. Analyze the results on the test set
	float test_performance = dtree->calcError(data_set, true, results);
	cout << "Performance on training data: " << train_performance << "%" << endl;
	cout << "Performance on test data: " << test_performance << "%" << endl;

	//7. Save the model
	dtree->save("dtree_01.xml");

Prediction with the loaded model:

Ptr<RTrees> dtree = RTrees::load("dtree_01.xml");
Mat sample = (Mat_<float>(1, 22)
          << 2, 3, 4, 5, 6, 7, 8, 2, 9, 1, 8, 10, 10, 4, 4, 11, 4, 12, 11, 9, 10, 13);
float result = dtree->predict(sample);
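The float returned by predict is a class code, not a human-readable label. A plain-C++ helper can map it back through the names vector that TrainData::getNames() fills during training; `labelOf` is a hypothetical helper written for this sketch, and it assumes the same names ordering that was used at training time.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: map the float code returned by predict() back to the
// symbolic name it was encoded from, using a stand-in for the names vector
// filled by TrainData::getNames().
std::string labelOf(float predicted, const std::vector<std::string>& names) {
    int idx = static_cast<int>(predicted);
    if (idx < 0 || idx >= static_cast<int>(names.size()))
        return "<unknown>"; // the code does not correspond to any known class
    return names[idx];
}
```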

2. Random Forests

Points to note when using random forests:

1. Parameter setup for a random forest is relatively simple; factors such as tree pruning need not be considered, but computing variable importance does cost extra training time:

	 Ptr<RTrees> forest_mushroom = RTrees::create();
	 forest_mushroom->setMaxDepth(10);                 // maximum tree depth
	 forest_mushroom->setRegressionAccuracy(0.01f);    // regression accuracy
	 forest_mushroom->setMinSampleCount(10);           // minimum number of samples in a node
	 forest_mushroom->setMaxCategories(15);            // maximum number of pre-clustered categories
	 forest_mushroom->setCalculateVarImportance(true); // compute variable importance
	 forest_mushroom->setActiveVarCount(4);            // size of the random feature subset tried at each node
	 forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER +
	                                  TermCriteria::EPS, 100, 0.01)); // termination criterion

2. Prediction differs slightly from the decision tree: a random forest can predict in two ways. The first returns the response directly from the majority vote:

		 Mat sample = testSample.row(i);
		 float r = forest_mushroom->predict(sample);
		 r = fabs((float)r - testResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0;

The second uses getVotes to obtain the vote matrix (note: in the returned matrix the first row holds the class labels and the following rows hold the per-class vote counts for each input sample); we can use the vote result to compute class probabilities:

		 Mat sample = testSample.row(i);
		 Mat result;
		 forest_mushroom->getVotes(sample, result, 0);
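The conversion from votes to probabilities can be sketched in plain C++. The vote counts below are made-up numbers standing in for a vote-count row of the getVotes result; since each of the 100 trees casts one vote, dividing by the total gives an empirical class probability.

```cpp
#include <vector>

// Turn per-class vote counts (one tree, one vote each) into probabilities.
std::vector<float> votesToProbs(const std::vector<int>& votes) {
    int total = 0;
    for (int v : votes) total += v;      // total number of trees that voted
    std::vector<float> probs;
    probs.reserve(votes.size());
    for (int v : votes)
        probs.push_back(total > 0 ? static_cast<float>(v) / total : 0.f);
    return probs;
}
```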

3. Random forests generalize very well: a small number of samples is enough for highly accurate classification. For example, out of the more than 8,400 mushroom samples, I used only 840 to grow the forest, and the final classification accuracy was still as high as 96%.

4. For loading, reading, and the use of a validation set, see the decision tree section above.

Modeling code:

    //Build a random forest for mushroom classification
     //1. Build the training and test sets
     Ptr<TrainData> data_set = TrainData::loadFromCSV("mushroom.data", // file name
                                                      0,   // number of header lines to skip
                                                      0,   // start of the response-column interval
                                                      1,   // end of the interval: [0,1) holds the response column
                                                      "cat[0-22]", // columns 0-22 are all categorical
                                                      ',', // fields are separated by ","
                                                      '?');// missing values are marked with "?"
     //2. Verify that the data was read correctly
	 int n_samples = data_set->getNSamples();
	 if (n_samples == 0)
	 {
		 cerr << "Could not read file: mushroom.data" << endl;
		 exit(-1);
	 }
	 else
	 {
		 cout << "Read " << n_samples << " samples from mushroom.data" << endl;
	 }

	 //3. Split into training and test sets at a 9:1 ratio, shuffling the samples
	 data_set->setTrainTestSplitRatio(0.90, true);
	 int n_train_samples = data_set->getNTrainSamples();
	 int n_test_samples = data_set->getNTestSamples();

	 //4. Random forest
	 Ptr<RTrees> forest_mushroom = RTrees::create();
	 forest_mushroom->setMaxDepth(10);                 // maximum tree depth
	 forest_mushroom->setRegressionAccuracy(0.01f);    // regression accuracy
	 forest_mushroom->setMinSampleCount(10);           // minimum number of samples in a node
	 forest_mushroom->setMaxCategories(15);            // maximum number of pre-clustered categories
	 forest_mushroom->setCalculateVarImportance(true); // compute variable importance
	 forest_mushroom->setActiveVarCount(4);            // size of the random feature subset tried at each node
	 forest_mushroom->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER +
	                                  TermCriteria::EPS, 100, 0.01)); // termination criterion
	 //Train the model
	 forest_mushroom->train(data_set);
	 //Compute the error on the training and test sets
	 float correct_Train_answer = 0;
	 float correct_Test_answer = 0;
	 //1. Training set
	 Mat trainSample = data_set->getTrainSamples();
	 Mat trainResMat = data_set->getTrainResponses();
	 for (int i = 0; i < trainSample.rows; i++)
	 {
		 Mat sample = trainSample.row(i);
		 float r = forest_mushroom->predict(sample);
		 r = fabs((float)r - trainResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0;
		 correct_Train_answer += r;
	 }
	 float r1 = correct_Train_answer / n_train_samples;
	 //2. Test set
	 Mat testSample =  data_set->getTestSamples();
	 Mat testResMat = data_set->getTestResponses();
	 for (int i = 0; i < testSample.rows; i++)
	 {
		 Mat sample = testSample.row(i);
		 float r = forest_mushroom->predict(sample);
		 r = fabs((float)r - testResMat.at<float>(i)) <= FLT_EPSILON ? 1 : 0;
		 correct_Test_answer += r;
	 }
	 float r2 = correct_Test_answer / n_test_samples;
	 //3. Print the results
	 cout << "trainSet Accuracy: " << r1 * 100 << "%" << endl;
	 cout << "testSet Accuracy:  " << r2 * 100 << "%" << endl;
	 //4. Save the model
	 forest_mushroom->save("forest_mushroom.xml");
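Since setCalculateVarImportance(true) was set above, RTrees::getVarImportance() will return one importance score per feature after training. Ranking those scores can be sketched in plain C++; the score values in the test are made-up placeholders for what getVarImportance would actually return on the mushroom data.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Rank features by importance score, highest first; returns (index, score) pairs.
std::vector<std::pair<int, float>> rankImportance(const std::vector<float>& scores) {
    std::vector<std::pair<int, float>> ranked;
    for (int i = 0; i < static_cast<int>(scores.size()); ++i)
        ranked.push_back({i, scores[i]});
    std::sort(ranked.begin(), ranked.end(),
              [](const std::pair<int, float>& a, const std::pair<int, float>& b) {
                  return a.second > b.second; // descending by score
              });
    return ranked;
}
```

On the mushroom data this tells you which of the 22 attributes the forest actually relies on when separating edible from poisonous samples.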

References

1. OpenCV 3.4 official ML module manual:

https://docs.opencv.org/3.4.0/dd/ded/group__ml.html

2. Random forests:

https://blog.csdn.net/akadiao/article/details/79413713

https://www.cnblogs.com/hrlnw/p/3850459.html

https://blog.csdn.net/wishchin/article/details/78662797

3. The poisonous-mushroom data set on the GitHub of OpenCV author Gary Bradski:

https://github.com/oreillymedia/Learning-OpenCV-3_examples/tree/master/mushroom

4. Definition of statistical data types:

https://zhidao.baidu.com/question/1964314134743418500.html

5. An OpenCV machine-learning blog:

https://www.cnblogs.com/denny402/p/5032232.html

6. ASCII code table:

https://blog.csdn.net/u011930916/article/details/79623922

7. A full guide to using the Mat class in OpenCV:

https://blog.csdn.net/guyuealian/article/details/70159660

8. A summary of the bit widths of OpenCV data types:

https://blog.csdn.net/lcgwust/article/details/70770148

Reposted from blog.csdn.net/qiao_lili/article/details/83743840