Spark 实现KNN算法(二)

Spark 实现KNN算法 – 基于RDD

上一篇 基于DataFrame实现KNN的过程中,由于中间使用了笛卡尔积,以及大规模的排序,对于运算的性能有较大影响,经过一定的调整,笔者找到一个相对较好的实现方法

  def runKnn(trainSet: DataFrame, testSet: DataFrame, k: Int, cl: String) = {

    val testFetures: RDD[Seq[Double]] = testSet
      .drop(cl).map(row => {
      val fetuers: Seq[Double] = row.mkString(",").split(",").map(_.toDouble)
      fetuers
    }).rdd

    val trainFetures: RDD[(String, Seq[Double])] = trainSet.map(row => {
      val cla = row.getAs[String](cl)
      val fetuers: Seq[Double] = row.mkString(",")
        .split(",").filter(NumberUtils.isNumber(_)).map(_.toDouble)
      (cla, fetuers)
    }).rdd

    // 将训练集广播
    val trainBroad = spark.sparkContext.broadcast(trainFetures.collect())

    val resRDD: RDD[Row] = testFetures.map(testTp => {
      //定义一个TreeSet之前 先自定义一个排序规则
      val orderRules: Ordering[(String, Double)] = Ordering.fromLessThan[(String, Double)](_._2 <= _._2)
      //新建一个空的set 传入排序规则
      var set: mutable.TreeSet[(String, Double)] = mutable.TreeSet.empty(orderRules)

      trainBroad.value.foreach(trainTp => {
        val dist = distance.Euclidean(testTp, trainTp._2)
        set += (trainTp._1 -> dist)
        // 设定了set的大小,排序的时候更高效
        if (set.size > k) set = set.slice(0, k) else set
      })

      // 获取 投票数最多的类  (一个Wordcount)
      val cla = set.toArray.groupBy(_._1)
        .map(t => (t._1, t._2.length)).maxBy(_._2)._1

      Row.merge(Row.fromSeq(testTp), Row(cla))

    })

    spark.createDataFrame(resRDD, trainSet.schema)

  }

算法测试

val iris = spark.read
      .option("header", true)
      .option("inferSchema", true)
      .csv(inputFile)

   // 将鸢尾花分成两部分:训练集和测试集
    val Array(testSet, trainSet) = iris.randomSplit(Array(0.3, 0.7), 1234L)

     val knnMode2 = new KNNRunner(spark)
    val res2 = knnMode2.runKnn(trainSet, testSet, 10, "class")
    
    
    res2.show(truncate = false)
    val check = udf((f1: String, f2: String) => {
      if (f1.equals(f2)) 1 else 0
    })
 
    res2.join(testSet.withColumnRenamed("class", "yclass"),
      Seq("sepalLength", "sepalWidth", "petalLength", "petalWidth"))
      .withColumn("check", check($"class", $"yclass"))
       .groupBy("check").count().show()
 
+-----------+----------+-----------+----------+---------------+
|sepalLength|sepalWidth|petalLength|petalWidth|class          |
+-----------+----------+-----------+----------+---------------+
|4.6        |3.2       |1.4        |0.2       |Iris-setosa    |
|4.8        |3.0       |1.4        |0.1       |Iris-setosa    |
|4.8        |3.4       |1.6        |0.2       |Iris-setosa    |

+-----+-----+
|check|count|
+-----+-----+
|    1|   53|
|    0|    2|
+-----+-----+

从结果看,两个实现过程是一致的,但是本文使用的方法更高效。

猜你喜欢

转载自blog.csdn.net/k_wzzc/article/details/84310993