Spark-MLlib的快速使用之三(随机森林)

(1)描述信息

随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由多个决策树构成,相比于单个决策树算法,它分类、预测效果更好,不容易出现过度拟合的情况。

随机森林算法基于决策树,在正式讲解随机森林算法之前,先来介绍决策树的原理。决策树是数据挖掘与机器学习领域中一种非常重要的分类器,算法通过训练数据来构建一棵用于分类的树,从而对未知数据进行高效分类。举个相亲的例子来说明什么是决策树、如何构建一个决策树及如何利用决策树进行分类,某相亲网站通过调查相亲历史数据发现,女孩在实际相亲时有如下表现:

(2)测试数据

1 125:145 126:255 127:211 128:31 152:32 153:237 154:253 155:252 156:71 180:11 181:175 182:253 183:252 184:71 209:144 210:253 211:252 212:71 236:16 237:191 238:253 239:252 240:71 264:26 265:221 266:253 267:252 268:124 269:31 293:125 294:253 295:252 296:252 297:108 322:253 323:252 324:252 325:108 350:255 351:253 352:253 353:108 378:253 379:252 380:252 381:108 406:253 407:252 408:252 409:108 434:253 435:252 436:252 437:108 462:255 463:253 464:253 465:170 490:253 491:252 492:252 493:252 494:42 518:149 519:252 520:252 521:252 522:144 546:109 547:252 548:252 549:252 550:144 575:218 576:253 577:253 578:255 579:35 603:175 604:252 605:252 606:253 607:35 631:73 632:252 633:252 634:253 635:35 659:31 660:211 661:252 662:253 663:35

(3)测试代码

public static void main(String[] args) {

// $example on$

SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample").setMaster("local");

JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file.

String datapath = "sample_libsvm_data.txt";

JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

// Split the data into training and test sets (30% held out for testing)

JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});

JavaRDD<LabeledPoint> trainingData = splits[0];

JavaRDD<LabeledPoint> testData = splits[1];

// Train a RandomForest model.

// Empty categoricalFeaturesInfo indicates all features are continuous.

Integer numClasses = 2;

HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();

Integer numTrees = 3; // Use more in practice.

String featureSubsetStrategy = "auto"; // Let the algorithm choose.

String impurity = "gini";

Integer maxDepth = 5;

Integer maxBins = 32;

Integer seed = 12345;

final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,

categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,

seed);

// Evaluate model on test instances and compute test error

JavaPairRDD<Double, Double> predictionAndLabel =

testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {

@Override

public Tuple2<Double, Double> call(LabeledPoint p) {

return new Tuple2<Double, Double>(model.predict(p.features()), p.label());

}

});

System.out.println("----------->" + predictionAndLabel.take(10));

Double testErr =

1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {

@Override

public Boolean call(Tuple2<Double, Double> pl) {

return !pl._1().equals(pl._2());

}

}).count() / testData.count();

System.out.println("Test Error: " + testErr);

System.out.println("Learned classification forest model:\n" + model.toDebugString());

// Save and load model

model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel");

RandomForestModel sameModel = RandomForestModel.load(jsc.sc(),

"target/tmp/myRandomForestClassificationModel");

// $example off$

}

(4)测试结果

[(1.0,1.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0)]

猜你喜欢

转载自blog.csdn.net/tbb_1984/article/details/84138908