Java Implementation of Classification Algorithms in Spark MLlib

1. Brief description

  Spark is a popular data-analysis framework, and its machine learning package MLlib is one of its highlights. For readers who, like me, want to get started with Spark as quickly as possible, below is concise Java code implementing MLlib classification. Along the way, the code comments describe the structure of the training and test sets and the meaning of each algorithm's parameters. The classification models covered are Naive Bayes, SVM, decision trees, and random forests.

 

2. Implementation code

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import java.util.LinkedList;
import java.util.List;
 
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
 
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
 
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
 
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
 
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
 
public class test {
    public static void main(String[] args) {
        // Create the Spark context
        SparkConf conf = new SparkConf();
        conf.set("spark.testing.memory", "2147480000"); // run-time setting: allow roughly 2 GB of memory
        // First argument: local mode, [*] grabs as many CPU cores as possible;
        // second: an arbitrary application name; third: the configuration object
        JavaSparkContext sc = new JavaSparkContext("local[*]", "Spark", conf);

        // Build the training set
        // Each sample is a LabeledPoint: 1.0 is the class label, Vectors.dense(2.0, 3.0, 3.0) the feature vector
        LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0));
        // When the features are sparse, build them with Vectors.sparse(size, indices, values);
        // the indices must be strictly increasing
        LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0}));
        List<LabeledPoint> l = new LinkedList<>(); // collect the training samples in a List
        l.add(neg);
        l.add(pos);
        // Convert to an RDD; the element type is LabeledPoint, not List
        JavaRDD<LabeledPoint> training = sc.parallelize(l);

        // Build the test sample: a single Vector (an RDD of Vectors also works)
        double[] d = {1, 1, 2};
        Vector v = Vectors.dense(d);

        // Naive Bayes
        final NaiveBayesModel nb_model = NaiveBayes.train(training.rdd());
        System.out.println(nb_model.predict(v));              // predicted class label
        System.out.println(nb_model.predictProbabilities(v)); // per-class probabilities

        // Support vector machine
        int numIterations = 100; // number of SGD iterations
        final SVMModel svm_model = SVMWithSGD.train(training.rdd(), numIterations); // build the model
        System.out.println(svm_model.predict(v));

        // Decision tree
        Integer numClasses = 2; // number of classes
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>(); // empty: all features are continuous
        // For classification, impurity can be "entropy" or "gini"; for regression
        // use "variance" (the larger the variance, the more disordered the data)
        String impurity = "gini";
        Integer maxDepth = 5;  // maximum tree depth
        Integer maxBins = 32;  // maximum number of split candidates per feature
        final DecisionTreeModel tree_model = DecisionTree.trainClassifier(training, numClasses,
            categoricalFeaturesInfo, impurity, maxDepth, maxBins); // build the model
        System.out.println("Decision tree prediction:");
        System.out.println(tree_model.predict(v));

        // Random forest (trainClassifier, to match the classification setting above)
        Integer numTrees = 3; // use more in practice
        String featureSubsetStrategy = "auto"; // let the algorithm choose
        Integer seed = 12345;
        // The parameters mirror the decision tree's, plus numTrees, featureSubsetStrategy and seed
        final RandomForestModel forest_model = RandomForest.trainClassifier(training, numClasses,
            categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);
        System.out.println("Random forest prediction:");
        System.out.println(forest_model.predict(v));

        sc.stop();
    }
}
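
The `impurity` parameter used by the decision tree and random forest selects the split criterion. As a quick, Spark-free illustration of what these criteria measure over a node's class distribution (Gini = 1 − Σ pᵢ², entropy = −Σ pᵢ log₂ pᵢ; both are maximal for an even split and zero for a pure node — a sketch for intuition, not MLlib's internal code):

```java
public class ImpurityDemo {
    // Gini impurity: 1 - sum of squared class proportions
    static double gini(double[] p) {
        double s = 0;
        for (double pi : p) s += pi * pi;
        return 1.0 - s;
    }

    // Entropy in bits: -sum p_i * log2(p_i), skipping zero proportions
    static double entropy(double[] p) {
        double h = 0;
        for (double pi : p) if (pi > 0) h -= pi * (Math.log(pi) / Math.log(2));
        return h;
    }

    public static void main(String[] args) {
        System.out.println(gini(new double[]{0.5, 0.5}));    // 0.5 : most impure two-class split
        System.out.println(entropy(new double[]{0.5, 0.5})); // 1.0
        System.out.println(gini(new double[]{1.0, 0.0}));    // 0.0 : a pure node
    }
}
```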

3. Notes

1. When analyzing data with Spark, the data generally has to be converted to an RDD. Reading an external file through the interfaces Spark provides usually yields an RDD automatically, and a MapReduce-style preprocessing step can likewise produce a training set in the format those interfaces expect.
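
To make note 1 concrete: MLlib's text loaders store one labeled sample per line, e.g. the LIBSVM format `label index:value ...`. A minimal sketch of parsing such a line into a label plus dense feature array (plain Java, no Spark required; this hand-rolled parser is illustrative, not MLlib's own loader):

```java
import java.util.Arrays;

public class LibSvmLineDemo {
    // Parse one LIBSVM-style line, e.g. "1.0 1:2.0 3:3.0", into [label, f1..fn].
    // Feature indices in the line are 1-based; absent features default to 0.0.
    static double[] parseLine(String line, int numFeatures) {
        String[] parts = line.trim().split("\\s+");
        double[] out = new double[numFeatures + 1]; // slot 0 holds the label
        out[0] = Double.parseDouble(parts[0]);
        for (int i = 1; i < parts.length; i++) {
            String[] kv = parts[i].split(":");
            out[Integer.parseInt(kv[0])] = Double.parseDouble(kv[1]);
        }
        return out;
    }

    public static void main(String[] args) {
        double[] sample = parseLine("1.0 1:2.0 3:3.0", 3);
        System.out.println(Arrays.toString(sample)); // label 1.0, features [2.0, 0.0, 3.0]
    }
}
```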

2. Training samples are uniformly represented as labeled vectors (LabeledPoint). The sample set is built as a List, but after conversion to an RDD its type is JavaRDD&lt;LabeledPoint&gt; (the training interfaces only accept JavaRDD&lt;LabeledPoint&gt;).

3. The classification predict methods return a class label; the Naive Bayes model can additionally return the probability of each class (the Python API does not expose this interface).
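
As a rough illustration of what the per-class probabilities in note 3 look like (a hand-rolled sketch, not MLlib's implementation): a multinomial Naive Bayes posterior is proportional to exp(log prior + Σ xᵢ · log θ_ci), normalized over the classes. The priors and likelihoods below are made-up values for demonstration:

```java
public class NaiveBayesProbDemo {
    // Posterior P(c | x) for multinomial NB, given log-priors pi[c]
    // and per-class feature log-likelihoods theta[c][i].
    static double[] posterior(double[] pi, double[][] theta, double[] x) {
        double[] logp = new double[pi.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < pi.length; c++) {
            logp[c] = pi[c];
            for (int i = 0; i < x.length; i++) logp[c] += x[i] * theta[c][i];
            max = Math.max(max, logp[c]); // track max for numerically stable softmax
        }
        double sum = 0;
        for (int c = 0; c < logp.length; c++) { logp[c] = Math.exp(logp[c] - max); sum += logp[c]; }
        for (int c = 0; c < logp.length; c++) logp[c] /= sum; // normalize to probabilities
        return logp;
    }

    public static void main(String[] args) {
        double[] pi = {Math.log(0.5), Math.log(0.5)};        // equal class priors
        double[][] theta = {
            {Math.log(0.9), Math.log(0.1)},                  // class 0 favours feature 0
            {Math.log(0.1), Math.log(0.9)}                   // class 1 favours feature 1
        };
        double[] p = posterior(pi, theta, new double[]{3, 0}); // sample dominated by feature 0
        System.out.printf("P(c=0|x)=%.3f P(c=1|x)=%.3f%n", p[0], p[1]);
    }
}
```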
