Spark ML Logistic Regression

A minimal end-to-end example: build a small labeled DataFrame, split it into training and test sets, fit a logistic regression model, and print the test-set predictions.

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

/**
  * Logistic regression example.
  * Created by zhen on 2018/11/20.
  */
object LogisticRegressionExample { // renamed so the object does not shadow the imported LogisticRegression class
  Logger.getLogger("org").setLevel(Level.WARN) // quiet Spark's INFO logging

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("LogisticRegression")
      .master("local[2]")
      .getOrCreate()

    // Build a small labeled dataset of (label, feature vector) pairs
    val data = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (0.0, Vectors.dense(2.0, -1.3, 1.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (1.0, Vectors.dense(2.0, 1.3, 1.1)),
      (0.0, Vectors.dense(-2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (0.0, Vectors.dense(2.0, -1.3, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (0.0, Vectors.dense(-2.0, 1.3, 1.1)),
      (1.0, Vectors.dense(0.0, 1.2, -0.4))
    )).toDF("label", "features")

    // Split 80/20 into training and test sets (random, so results vary per run)
    val weights = Array(0.8, 0.2)
    val split_data = data.randomSplit(weights)

    // Configure the logistic regression estimator
    val lr = new LogisticRegression()
      .setMaxIter(10)    // maximum number of optimization iterations
      .setRegParam(0.01) // regularization strength

    // Fit on the training split, then score the test split
    val model = lr.fit(split_data(0))
    model.transform(split_data(1))
      .select("label", "features", "probability", "prediction")
      .collect()
      .foreach(println)

    // Shut down Spark
    spark.stop()
  }
}
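fit returns a LogisticRegressionModel. Before the row-level output shown below, a quick sanity check is to print the learned parameters; a minimal sketch, reusing the model value from the listing:

// The fitted model exposes its learned parameters directly.
println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")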

Result: running the program prints one line per test-set row with its label, features, probability vector, and predicted class.
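Beyond eyeballing the printed rows, the quality of the classifier can be quantified. A sketch, assuming the model and split_data values from the listing above; BinaryClassificationEvaluator computes areaUnderROC by default:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Score the held-out test split; transform adds rawPrediction,
// probability and prediction columns.
val predictions = model.transform(split_data(1))

// The evaluator's default metric is areaUnderROC.
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("rawPrediction")

println(s"Test AUC = ${evaluator.evaluate(predictions)}")

On a toy dataset this small the AUC will swing widely between runs, since the random split leaves only a handful of test rows.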

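A fitted model can also be persisted and reloaded for later scoring; a minimal sketch, where the save path is a placeholder:

import org.apache.spark.ml.classification.LogisticRegressionModel

// Save the fitted model; "/tmp/spark-lr-model" is a placeholder path.
model.write.overwrite().save("/tmp/spark-lr-model")

// Reload it in another job and reuse it for scoring.
val sameModel = LogisticRegressionModel.load("/tmp/spark-lr-model")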

Reposted from www.cnblogs.com/yszd/p/9988597.html