基于spark用线性回归(linear regression)进行数据预测(谋杀率)

分类和聚类算法很多,但是对数据进行精准预测的算法不是很多,这里介绍下最常用的线性回归例子

spark官网上给出的例子不够全面,网上有些例子也不够具体,自己整理了一份

我的开发环境是ubuntu+IDEA+spark+scala

训练数据LR_data如下

3615, 3624, 2.1, 69.05, 15.1, 41.3, 20, 50708
365, 6315, 1.5, 69.31, 11.3, 66.7, 152, 566432
2212, 4530, 1.8, 70.55, 7.8, 58.1, 15, 113417
2110, 3378, 1.9, 70.66, 10.1, 39.9, 65, 51945
21198, 5114, 1.1, 71.71, 10.3, 62.6, 20, 156361
2541, 4884, 0.7, 72.06, 6.8, 63.9, 166, 103766
3100, 5348, 1.1, 72.48, 3.1, 56, 139, 4862
579, 4809, 0.9, 70.06, 6.2, 54.6, 103, 1982
8277, 4815, 1.3, 70.66, 10.7, 52.6, 11, 54090
4931, 4091, 2, 68.54, 13.9, 40.6, 60, 58073
868, 4963, 1.9, 73.6, 6.2, 61.9, 0, 6425
813, 4119, 0.6, 71.87, 5.3, 59.5, 126, 82677
11197, 5107, 0.9, 70.14, 10.3, 52.6, 127, 55748
5313, 4458, 0.7, 70.88, 7.1, 52.9, 122, 36097
2861, 4628, 0.5, 72.56, 2.3, 59, 140, 55941
2280, 4669, 0.6, 72.58, 4.5, 59.9, 114, 81787
3387, 3712, 1.6, 70.1, 10.6, 38.5, 95, 39650
3806, 3545, 2.8, 68.76, 13.2, 42.2, 12, 44930
1058, 3694, 0.7, 70.39, 2.7, 54.7, 161, 30920
4122, 5299, 0.9, 70.22, 8.5, 52.3, 101, 9891
5814, 4755, 1.1, 71.83, 3.3, 58.5, 103, 7826
9111, 4751, 0.9, 70.63, 11.1, 52.8, 125, 56817
3921, 4675, 0.6, 72.96, 2.3, 57.6, 160, 79289
2341, 3098, 2.4, 68.09, 12.5, 41, 50, 47296
4767, 4254, 0.8, 70.69, 9.3, 48.8, 108, 68995
746, 4347, 0.6, 70.56, 5, 59.2, 155, 145587
1544, 4508, 0.6, 72.6, 2.9, 59.3, 139, 76483
590, 5149, 0.5, 69.03, 11.5, 65.2, 188, 109889
812, 4281, 0.7, 71.23, 3.3, 57.6, 174, 9027
7333, 5237, 1.1, 70.93, 5.2, 52.5, 115, 7521
1144, 3601, 2.2, 70.32, 9.7, 55.2, 120, 121412
18076, 4903, 1.4, 70.55, 10.9, 52.7, 82, 47831
5441, 3875, 1.8, 69.21, 11.1, 38.5, 80, 48798
637, 5087, 0.8, 72.78, 1.4, 50.3, 186, 69273
10735, 4561, 0.8, 70.82, 7.4, 53.2, 124, 40975
2715, 3983, 1.1, 71.42, 6.4, 51.6, 82, 68782
2284, 4660, 0.6, 72.13, 4.2, 60, 44, 96184
11860, 4449, 1, 70.43, 6.1, 50.2, 126, 44966
931, 4558, 1.3, 71.9, 2.4, 46.4, 127, 1049
2816, 3635, 2.3, 67.96, 11.6, 37.8, 65, 30225
681, 4167, 0.5, 72.08, 1.7, 53.3, 172, 75955
4173, 3821, 1.7, 70.11, 11, 41.8, 70, 41328
12237, 4188, 2.2, 70.9, 12.2, 47.4, 35, 262134
1203, 4022, 0.6, 72.9, 4.5, 67.3, 137, 82096
472, 3907, 0.6, 71.64, 5.5, 57.1, 168, 9267
4981, 4701, 1.4, 70.08, 9.5, 47.8, 85, 39780
3559, 4864, 0.6, 71.72, 4.3, 63.5, 32, 66570
1799, 3617, 1.4, 69.48, 6.7, 41.6, 100, 24070
4589, 4468, 0.7, 72.48, 3, 54.5, 149, 54464
376, 4566, 0.6, 70.29, 6.9, 62.9, 173, 97203

预测数据LR_data_for_predict如下(在上面数据的基础上,对predict_target数据和第一个特征数据进行了修改,只修改了前10个记录。由于predict_target是需要预测的数据,事先不知道,但是为了保证数据格式,需要加上,这里默认为0,不影响。)

615, 3624, 2.1, 69.05, 0, 41.3, 20, 50708
65, 6315, 1.5, 69.31, 0, 66.7, 152, 566432
212, 4530, 1.8, 70.55, 0, 58.1, 15, 113417
110, 3378, 1.9, 70.66, 0, 39.9, 65, 51945
1198, 5114, 1.1, 71.71,0, 62.6, 20, 156361
541, 4884, 0.7, 72.06, 0, 63.9, 166, 103766
100, 5348, 1.1, 72.48, 0, 56, 139, 4862
79, 4809, 0.9, 70.06, 0, 54.6, 103, 1982
277, 4815, 1.3, 70.66, 0, 52.6, 11, 54090
931, 4091, 2, 68.54, 0, 40.6, 60, 58073
868, 4963, 1.9, 73.6, 6.2, 61.9, 0, 6425
813, 4119, 0.6, 71.87, 5.3, 59.5, 126, 82677
11197, 5107, 0.9, 70.14, 10.3, 52.6, 127, 55748
5313, 4458, 0.7, 70.88, 7.1, 52.9, 122, 36097
2861, 4628, 0.5, 72.56, 2.3, 59, 140, 55941
2280, 4669, 0.6, 72.58, 4.5, 59.9, 114, 81787
3387, 3712, 1.6, 70.1, 10.6, 38.5, 95, 39650
3806, 3545, 2.8, 68.76, 13.2, 42.2, 12, 44930
1058, 3694, 0.7, 70.39, 2.7, 54.7, 161, 30920
4122, 5299, 0.9, 70.22, 8.5, 52.3, 101, 9891
5814, 4755, 1.1, 71.83, 3.3, 58.5, 103, 7826
9111, 4751, 0.9, 70.63, 11.1, 52.8, 125, 56817
3921, 4675, 0.6, 72.96, 2.3, 57.6, 160, 79289
2341, 3098, 2.4, 68.09, 12.5, 41, 50, 47296
4767, 4254, 0.8, 70.69, 9.3, 48.8, 108, 68995
746, 4347, 0.6, 70.56, 5, 59.2, 155, 145587
1544, 4508, 0.6, 72.6, 2.9, 59.3, 139, 76483
590, 5149, 0.5, 69.03, 11.5, 65.2, 188, 109889
812, 4281, 0.7, 71.23, 3.3, 57.6, 174, 9027
7333, 5237, 1.1, 70.93, 5.2, 52.5, 115, 7521
1144, 3601, 2.2, 70.32, 9.7, 55.2, 120, 121412
18076, 4903, 1.4, 70.55, 10.9, 52.7, 82, 47831
5441, 3875, 1.8, 69.21, 11.1, 38.5, 80, 48798
637, 5087, 0.8, 72.78, 1.4, 50.3, 186, 69273
10735, 4561, 0.8, 70.82, 7.4, 53.2, 124, 40975
2715, 3983, 1.1, 71.42, 6.4, 51.6, 82, 68782
2284, 4660, 0.6, 72.13, 4.2, 60, 44, 96184
11860, 4449, 1, 70.43, 6.1, 50.2, 126, 44966
931, 4558, 1.3, 71.9, 2.4, 46.4, 127, 1049
2816, 3635, 2.3, 67.96, 11.6, 37.8, 65, 30225
681, 4167, 0.5, 72.08, 1.7, 53.3, 172, 75955
4173, 3821, 1.7, 70.11, 11, 41.8, 70, 41328
12237, 4188, 2.2, 70.9, 12.2, 47.4, 35, 262134
1203, 4022, 0.6, 72.9, 4.5, 67.3, 137, 82096
472, 3907, 0.6, 71.64, 5.5, 57.1, 168, 9267
4981, 4701, 1.4, 70.08, 9.5, 47.8, 85, 39780
3559, 4864, 0.6, 71.72, 4.3, 63.5, 32, 66570
1799, 3617, 1.4, 69.48, 6.7, 41.6, 100, 24070
4589, 4468, 0.7, 72.48, 3, 54.5, 149, 54464
376, 4566, 0.6, 70.29, 6.9, 62.9, 173, 97203

实现代码如下

/**
  * Created by wangtuntun on 17-3-8.
  * 利用线性回归模型对数据进行回归预测
  */


import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
object linnear_regression_online {
  def main(args: Array[String]): Unit = {

    //设置环境
    val conf=new SparkConf().setAppName("tianchi").setMaster("local")
    val sc=new SparkContext(conf)
    val sqc=new SQLContext(sc)

    //准备训练集合
    val raw_data=sc.textFile("/home/wangtuntun/LR_data")
    val map_data=raw_data.map{x=>
      val split_list=x.split(",")
      (split_list(0).toDouble,split_list(1).toDouble,split_list(2).toDouble,split_list(3).toDouble,split_list(4).toDouble,split_list(5).toDouble,split_list(6).toDouble,split_list(7).toDouble)
    }
    val df=sqc.createDataFrame(map_data)
    val data = df.toDF("Population", "Income", "Illiteracy", "LifeExp", "Murder", "HSGrad", "Frost", "Area")
    val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")
    val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")
    val vecDF: DataFrame = assembler.transform(data)

    //准备预测集合
    val raw_data_predict=sc.textFile("/home/wangtuntun/LR_data_for_predict")
    val map_data_for_predict=raw_data_predict.map{x=>
      val split_list=x.split(",")
      (split_list(0).toDouble,split_list(1).toDouble,split_list(2).toDouble,split_list(3).toDouble,split_list(4).toDouble,split_list(5).toDouble,split_list(6).toDouble,split_list(7).toDouble)
    }
    val df_for_predict=sqc.createDataFrame(map_data_for_predict)
    val data_for_predict = df_for_predict.toDF("Population", "Income", "Illiteracy", "LifeExp", "Murder", "HSGrad", "Frost", "Area")
    val colArray_for_predict = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")
    val assembler_for_predict = new VectorAssembler().setInputCols(colArray_for_predict).setOutputCol("features")
    val vecDF_for_predict: DataFrame = assembler_for_predict.transform(data_for_predict)

    // 建立模型,预测谋杀率Murder
    // 设置线性回归参数
    val lr1 = new LinearRegression()
    val lr2 = lr1.setFeaturesCol("features").setLabelCol("Murder").setFitIntercept(true)
    // RegParam:正则化
    val lr3 = lr2.setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
    val lr = lr3

    // 将训练集合代入模型进行训练
    val lrModel = lr.fit(vecDF)

    // 输出模型全部参数
    lrModel.extractParamMap()
    // Print the coefficients and intercept for linear regression
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

    // 模型进行评价
    val trainingSummary = lrModel.summary
    println(s"numIterations: ${trainingSummary.totalIterations}")
    println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
    trainingSummary.residuals.show()
    println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
    println(s"r2: ${trainingSummary.r2}")


    val predictions: DataFrame = lrModel.transform(vecDF_for_predict)
    //    val predictions = lrModel.transform(vecDF)
    println("输出预测结果")
    val predict_result: DataFrame =predictions.selectExpr("features","Murder", "round(prediction,1) as prediction")
    predict_result.foreach(println(_))
    sc.stop()
  }
}

猜你喜欢

转载自blog.csdn.net/BigData_Mining/article/details/81142192