Spark MLlib - Logistic Regression on the Iris Dataset

Data used: http://download.csdn.net/download/dr_guo/9946656
Environment: Spark 1.6.1; Scala 2.10.4; JDK 1.7

See the inline comments for details.
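
The parsing code below assumes each line of data/IrisData2.txt is comma-separated: four numeric features followed by a double label that starts at 0.0. Hypothetical sample lines (the actual contents are in the download above):

  5.1,3.5,1.4,0.2,0.0
  7.0,3.2,4.7,1.4,1.0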

package com.beagledata.test

import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionWithSGD}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.SQLContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint


object IrisLogisticModelTest extends App {

  val conf = new SparkConf()
    .setAppName("IrisLogisticModelTest")
    .setMaster("local")

  val sc = new SparkContext(conf)
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._

  // load data  
  val rddIris = sc.textFile("data/IrisData2.txt")
  //rddIris.foreach(println)

  case class Iris(a: Double, b: Double, c: Double, d: Double, target: Double)

  // The label (target) column used by LabeledPoint must be of type Double and start
  // from 0.0; for binary classification the two classes are 0.0 and 1.0
  val dfIris = rddIris.map(_.split(","))
    .map(l => Iris(l(0).toDouble, l(1).toDouble, l(2).toDouble, l(3).toDouble, l(4).toDouble))
    .toDF()

  dfIris.registerTempTable("Iris")  

  //sqlContext.sql("""SELECT * FROM Iris""").show

  // Map feature names to indices
  val featInd = List("a", "b", "c", "d").map(dfIris.columns.indexOf(_))

  // Get index of target
  val targetInd = dfIris.columns.indexOf("target") 

  val labeledPointIris = dfIris.rdd.map(r => LabeledPoint(
    r.getDouble(targetInd),                            // Get target value
    Vectors.dense(featInd.map(r.getDouble(_)).toArray) // Map feature indices to values
  ))

  // Split data into training (80%) and test (20%).
  val splits = labeledPointIris.randomSplit(Array(0.8, 0.2), seed = 11L)
  val trainingData = splits(0)
  val testData = splits(1)
  /*println("trainingData--------------------------->")
  trainingData.take(5).foreach(println)
  println("testData------------------------------->")
  testData.take(5).foreach(println)*/


  /*// Run training algorithm to build the model  
  val lr = new LogisticRegressionWithSGD().setIntercept(true)
  lr.optimizer
    .setStepSize(10.0)
    .setRegParam(0.0)
    .setNumIterations(20)
    .setConvergenceTol(0.0005)
  val model = lr.run(trainingData)*/
  val numClasses = 2
  //val model = LogisticRegressionWithSGD.train(trainingData, numIterations = 2)
  val model = new LogisticRegressionWithLBFGS().setNumClasses(numClasses).run(trainingData)

  // Predict on the test set
  val labelAndPreds = testData.map { point =>
    val prediction = model.predict(point.features)
    (point.label, prediction)
  }
  println("labelAndPreds------------------------->")
  labelAndPreds.take(5).foreach(println)
  // Compute precision (for MulticlassMetrics the overall precision equals accuracy,
  // since every test point receives exactly one prediction)
  val metrics = new MulticlassMetrics(labelAndPreds)
  val precision = metrics.precision
  println("Precision = " + precision)

}
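
The same MulticlassMetrics object also exposes the confusion matrix and per-label metrics, and the trained model can score a single new measurement. A minimal follow-up sketch (the feature values below are hypothetical):

  // Confusion matrix and per-label precision/recall
  println(metrics.confusionMatrix)
  metrics.labels.foreach { l =>
    println(s"label $l: precision = ${metrics.precision(l)}, recall = ${metrics.recall(l)}")
  }

  // Score one hypothetical flower (sepal length/width, petal length/width)
  val sample = Vectors.dense(5.1, 3.5, 1.4, 0.2)
  println("Predicted class: " + model.predict(sample))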
I ran into two errors:

1. Input validation failed

On inspection, the labels in the dataset did not start from 0. LabeledPoint requires binary-classification labels to be doubles starting from 0.0, i.e. the two classes must be 0.0 and 1.0, not 1.0 and 2.0.
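
If the raw labels are, say, 1.0 and 2.0, they can be shifted into the expected range before training. A minimal sketch, assuming exactly that 1.0/2.0 coding:

  // Remap 1.0/2.0 labels down to the 0.0/1.0 range MLlib expects
  val remapped = labeledPointIris.map(p => LabeledPoint(p.label - 1.0, p.features))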

2. bad symbolic reference. A signature in GeneralizedLinearAlgorithm.class refers to term internal
in package org.apache.spark which is not available. It may be completely missing from the current classpath,
or the version on the classpath might be incompatible with the version used when compiling GeneralizedLinearAlgorithm.class.

An extra spark-mllib jar had been added to the classpath, but spark-assembly-1.6.1-hadoop2.6.0.jar already bundles spark-mllib, so the duplicate copies conflicted. Removing the extra jar fixed the error.
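
When dependencies are managed by sbt instead of hand-copied jars, the same problem can be avoided by pulling one consistent set of Spark artifacts and marking them provided (the assembly/cluster supplies them at runtime). A hypothetical build.sbt sketch:

  scalaVersion := "2.10.4"

  libraryDependencies ++= Seq(
    "org.apache.spark" %% "spark-core"  % "1.6.1" % "provided",
    "org.apache.spark" %% "spark-sql"   % "1.6.1" % "provided",
    "org.apache.spark" %% "spark-mllib" % "1.6.1" % "provided"
  )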


Reposted from blog.csdn.net/dr_guo/article/details/77506641