Weka集成分类器

Weka集成分类器
package cn.edu.xmu.bdm.wekainjava.test;
import java.io.File;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.LibSVM;
import weka.classifiers.meta.Vote;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.SelectedTag;
import cn.edu.xmu.bdm.wekainjava.utils.WekaFactory;
import cn.edu.xmu.bdm.wekainjava.utils.WekaFactoryImpl;
/**
 * Demo program: builds a Weka {@link Vote} ensemble from three base classifiers
 * (J48, NaiveBayes and LibSVM), trains it on one ARFF file and prints evaluation
 * statistics measured on a second ARFF file.
 */
public class EnsembleTest {

    /** Training set used when no path is supplied on the command line. */
    private static final String DEFAULT_TRAIN_PATH =
            "C://Program Files//Weka-3-6//data//segment-challenge.arff";

    /** Test set used when no path is supplied on the command line. */
    private static final String DEFAULT_TEST_PATH =
            "C://Program Files//Weka-3-6//data//segment-test.arff";

    /**
     * Trains the voting ensemble and prints its evaluation on the test set.
     *
     * @param args optional; {@code args[0]} overrides the training ARFF path and
     *             {@code args[1]} overrides the test ARFF path. Missing entries
     *             (or a {@code null} array) fall back to the built-in defaults.
     * @throws Exception propagated from the factory, training or evaluation calls
     */
    public static void main(String[] args) throws Exception {
        File trainFile = new File(pathOrDefault(args, 0, DEFAULT_TRAIN_PATH));
        File testFile = new File(pathOrDefault(args, 1, DEFAULT_TEST_PATH));

        // 1. Obtain the Weka factory.
        WekaFactory wi = WekaFactoryImpl.getInstance();

        // 2. Obtain the training and test instances from the factory.
        //    The last attribute of each data set is used as the class attribute.
        Instances instancesTrain = wi.getInstance(trainFile);
        instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);
        Instances instancesTest = wi.getInstance(testFile);
        instancesTest.setClassIndex(instancesTest.numAttributes() - 1);

        // 3. Obtain the base classifiers from the factory. Three *different*
        //    learners are used: two identically configured members would always
        //    agree and outvote the third under majority voting.
        Classifier j48 = (Classifier) wi.getClassifier(J48.class);
        Classifier naiveBayes = (Classifier) wi.getClassifier(NaiveBayes.class);
        Classifier libSVM = (Classifier) wi.getClassifier(LibSVM.class);

        // 3.1 Collect the ensemble members.
        Classifier[] cfsArray = new Classifier[] {j48, naiveBayes, libSVM};

        // 3.2 Choose how the ensemble combines the members' predictions.
        //     The combination rules listed for Vote are AVERAGE_RULE, PRODUCT_RULE,
        //     MAJORITY_VOTING_RULE, MIN_RULE, MAX_RULE and MEDIAN_RULE; see the
        //     Weka documentation for how each one works. Majority voting is the
        //     usual choice.
        Vote ensemble = new Vote();
        SelectedTag tag = new SelectedTag(Vote.MAJORITY_VOTING_RULE, Vote.TAGS_RULES);
        ensemble.setCombinationRule(tag);
        ensemble.setClassifiers(cfsArray);
        // Set the random seed.
        ensemble.setSeed(2);
        // Train the ensemble on the training instances.
        ensemble.buildClassifier(instancesTrain);

        // 4. Obtain an Evaluation from the factory and test the trained ensemble
        //    on every test instance.
        //    NOTE(review): the factory is handed the test instances here; confirm
        //    that is what WekaFactory.getEvaluation expects.
        Evaluation testingEvaluation = wi.getEvaluation(ensemble, instancesTest);
        int length = instancesTest.numInstances();
        for (int i = 0; i < length; i++) {
            // Evaluate the classifier on one test instance and record the prediction.
            testingEvaluation.evaluateModelOnceAndRecordPrediction(ensemble,
                    instancesTest.instance(i));
        }

        System.out.println(testingEvaluation.toSummaryString());
        System.out.println(testingEvaluation.toMatrixString());
        System.out.println(testingEvaluation.toClassDetailsString());
        System.out.println("分类器的正确率:" + (1 - testingEvaluation.errorRate()));
    }

    /** Returns {@code args[index]} when present, otherwise {@code fallback}. */
    private static String pathOrDefault(String[] args, int index, String fallback) {
        return (args != null && args.length > index) ? args[index] : fallback;
    }
}

猜你喜欢

转载自xuchenglang.iteye.com/blog/1973069
今日推荐