SparkMLlib逻辑斯蒂回归分类器简单案例

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/monkeysheep1234/article/details/69487381
package com.huihex.sparkmllib

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}
/**
  * Created by wall-e on 2017/4/1.
  */
object Logistic_regression {
  /**
    * 逻辑斯蒂回归(logistic regression)是统计学习中的经典分类方法,属于对数线性模型。
    * logistic回归的因变量可以是二分类的,也可以是多分类的。
    * @param args
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("逻辑回归").setMaster("local")
    val sc = new SparkContext(conf)
    //读取数据
    //每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类
    //这里我们用LabeledPoint来存储标签列和特征列
    val data = sc.textFile("data\\iris.txt")
    /*
    LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。
    这里,先把莺尾花的分类进行变换,"Iris-setosa"对应分类0,"Iris-versicolor"对应分类1,其余对应分类2;
    然后获取莺尾花的4个特征,存储在Vector中。
     */
    val parsedData = data.map { line =>
      val parts = line.split(',')
      LabeledPoint(
        if(parts(4)=="Iris-setosa") 0.toDouble
        else if (parts(4) =="Iris-versicolor") 1.toDouble
        else 2.toDouble,
        Vectors.dense(parts(0).toDouble,parts(1).toDouble,parts(2).toDouble,parts(3).toDouble)
      )
    }
    //打印读取并处理后的数据
    parsedData.foreach { x => println(x) }

    /**
      * 首先进行数据集的划分,这里划分60%的训练集和40%的测试集:
      */
    val splits = parsedData.randomSplit(Array(0.6,0.4),seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)
    /**
      * 然后,构建逻辑斯蒂模型,用set的方法设置参数,比如说分类的数目,这里可以实现多分类逻辑斯蒂模型
      */
    val model = new LogisticRegressionWithLBFGS().setNumClasses(3).run(training)
    /**
      * 接下来,调用多分类逻辑斯蒂模型用的predict方法对测试数据进行预测,并把结果保存在MulticlassMetrics中。
      * 这里的模型全名为LogisticRegressionWithLBFGS,加上了LBFGS,表示Limited-memory BFGS。
      * 其中,BFGS是求解非线性优化问题(L(w)​求极大值)的方法,是一种秩-2更新,
      * 以其发明者Broyden, Fletcher, Goldfarb和Shanno的姓氏首字母命名。
      */
    val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
      val prediction = model.predict(features)
      (prediction, label)
    }
    /**
      * 这里,采用了test部分的数据每一行都分为标签label和特征features,
      * 然后利用map方法,对每一行的数据进行model.predict(features)操作,获得预测值。
      * 并把预测值和真正的标签放到predictionAndLabels中。我们可以打印出具体的结果数据来看一下:
      */
    predictionAndLabels.foreach(x =>(println(x)))
    /**
      * 模型评估
      * 模型预测的准确性
      */
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val precision = metrics.precision
    println("Precision = " + precision)
  }
}

iris数据集下载链接(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

......
17/04/06 22:16:29 INFO SparkContext: Created broadcast 1 from broadcast at DAGScheduler.scala:996
17/04/06 22:16:29 INFO DAGScheduler: Submitting 1 missing tasks from ResultStage 0 (MapPartitionsRDD[2] at map at Logistic_regression.scala:29)
17/04/06 22:16:29 INFO TaskSchedulerImpl: Adding task set 0.0 with 1 tasks
17/04/06 22:16:29 INFO TaskSetManager: Starting task 0.0 in stage 0.0 (TID 0, localhost, executor driver, partition 0, PROCESS_LOCAL, 5983 bytes)
17/04/06 22:16:29 INFO Executor: Running task 0.0 in stage 0.0 (TID 0)
17/04/06 22:16:29 INFO HadoopRDD: Input split: file:/D:/huihex-spark/data/iris.txt:0+4698
17/04/06 22:16:29 INFO deprecation: mapred.tip.id is deprecated. Instead, use mapreduce.task.id
17/04/06 22:16:29 INFO deprecation: mapred.task.id is deprecated. Instead, use mapreduce.task.attempt.id
17/04/06 22:16:29 INFO deprecation: mapred.task.is.map is deprecated. Instead, use mapreduce.task.ismap
17/04/06 22:16:29 INFO deprecation: mapred.task.partition is deprecated. Instead, use mapreduce.task.partition
17/04/06 22:16:29 INFO deprecation: mapred.job.id is deprecated. Instead, use mapreduce.job.id
(0.0,[5.1,3.5,1.4,0.2])
(0.0,[4.9,3.0,1.4,0.2])
(0.0,[4.7,3.2,1.3,0.2])
(0.0,[4.6,3.1,1.5,0.2])
......
(1.0,1.0)
(1.0,1.0)
(1.0,1.0)
(1.0,1.0)
17/04/06 22:16:33 INFO Executor: Finished task 0.0 in stage 72.0 (TID 72). 995 bytes result sent to driver
17/04/06 22:16:33 INFO TaskSetManager: Finished task 0.0 in stage 72.0 (TID 72) in 21 ms on localhost (executor driver) (1/1)
17/04/06 22:16:33 INFO TaskSchedulerImpl: Removed TaskSet 72.0, whose tasks have all completed, from pool 
17/04/06 22:16:33 INFO DAGScheduler: ResultStage 72 (foreach at Logistic_regression.scala:66) finished in 0.022 s
(2.0,2.0)
(2.0,2.0)
(2.0,2.0)
......
17/04/06 22:16:33 INFO Executor: Running task 0.0 in stage 76.0 (TID 76)
17/04/06 22:16:33 INFO ShuffleBlockFetcherIterator: Getting 1 non-empty blocks out of 1 blocks
17/04/06 22:16:33 INFO ShuffleBlockFetcherIterator: Started 0 remote fetches in 1 ms
Precision = 0.9615384615384616
17/04/06 22:16:33 INFO Executor: Finished task 0.0 in stage 76.0 (TID 76). 1807 bytes result sent to driver
17/04/06 22:16:33 INFO TaskSetManager: Finished task 0.0 in stage 76.0 (TID 76) in 10 ms on localhost (executor driver) (1/1)
17/04/06 22:16:33 INFO TaskSchedulerImpl: Removed TaskSet 76.0, whose tasks have all completed, from pool 

参考文档:http://mocom.xmu.edu.cn/article/show/58578f482b2730e00d70f9fc/0/1

猜你喜欢

转载自blog.csdn.net/monkeysheep1234/article/details/69487381
今日推荐