MLlib之KNN算法实例

MLlib之KNN算法实例

knn算法的思想:
邻近算法,或者最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表(近朱者赤近墨者黑)。
求距离公式:
曼哈顿距离
欧几里得距离

需求:
样本数据,
label,f1,f2,f3,f4,f5
0,10,20,30,40,30
0,12,22,29,42,35
0,11,21,31,40,34
0,13,22,30,42,32
0,12,22,32,41,33
0,10,21,33,45,35
1,30,11,21,40,34
1,33,10,20,43,30
1,30,12,23,40,33
1,32,10,20,42,33
1,30,13,20,42,30
1,30,09,22,41,32

– 用spark编程,对下列类别未知的向量,标注预测的类别
b1,b2,b3,b4,b5
11,21,31,44,32
14,26,32,39,30
32,14,21,42,32
34,12,22,42,34

代码实现:
object suanfa {
def main(args: Array[String]): Unit = {
//创建一个spark环境

val spark: SparkSession = SparkSession.builder()
  .appName(this.getClass.getSimpleName)
  .master("local[*]")
  .getOrCreate()

利用new StructType()自定义导入的数据类型

    //自定义样本的数据类型
val schemal: StructType = new StructType()
  .add("label", DataTypes.DoubleType)
  .add("f1", DataTypes.DoubleType)
  .add("f2", DataTypes.DoubleType)
  .add("f3", DataTypes.DoubleType)
  .add("f4", DataTypes.DoubleType)
  .add("f5", DataTypes.DoubleType)

  //自定义未知数据的数据类型
val schema2: StructType = new StructType()
  .add("id", DataTypes.DoubleType)
  .add("b1", DataTypes.DoubleType)
  .add("b2", DataTypes.DoubleType)
  .add("b3", DataTypes.DoubleType)
  .add("b4", DataTypes.DoubleType)
  .add("b5", DataTypes.DoubleType)

两个数据做crossjoin笛卡尔积的join,一对多的放在一起形成zhb(综合表)

//导入样本数据导入
val yb: DataFrame = spark.read.option("header","true").schema(schemal).csv("data/demo/yangben")
//导入为知数据
val wz: DataFrame = spark.read.schema(schema2).option("header","true").csv("data/demo/weizhi")
    //将样本数据和为主数据连接crossjion(交叉join  笛卡尔积join)
    val zhj: DataFrame = wz.crossJoin(yb)

利用 vectors.sqdist 自己先用 udf() 弄一个两个空数据的欧式距离计算表
运行时会报Failed to execute user defined function(数据执行异常,数据不匹配)的错
解决办法:将自己写的普通Array改成mutable.WrappedArray,在dense的时候在改回来

        //自定义一个计算欧式距离的函数  
   import org.apache.spark.sql.functions._
    val osjl: UserDefinedFunction = udf((arr1:mutable.WrappedArray[Double], arr2:mutable.WrappedArray[Double]) =>{
       val v1: linalg.Vector = Vectors.dense(arr1.toArray)
       val v2: linalg.Vector = Vectors.dense(arr2.toArray)
        Vectors.sqdist(v1,v2)//sq(平方)dist(距离)只不过得到的最后的结果的平方不影响计算结果
    })

利用综合表.select()的方法将两个数据传入上方设定的欧式距离计算表

//计算样本和未知之间的距离
    import spark.implicits._
val fra: DataFrame = zhj.select(
  $"label", //如果想用符号就必须用隐式转换(import spark.implicits._),
  //col("label")如果想用这个就必须用(import org.apache.spark.sql.functions._)
  $"id",
  //这个数据不是普通数组是WrappedArray会报错
  osjl(array("f1", "f2", "f3", "f4", "f5"), array("b1", "b2", "b3", "b4", "b5"))as "dist"
)

创建sparksql处理表格

//处理表格找出距离最小的5名
fra.createTempView("top")
    val top5 = spark.sql(
      """
        |
        |select
        |*
        |from
        |(
        |    select
        |    id,
        |    label,
        |    dist,
        |    row_number() over(partition by id order by dist)as rn
        |    from
        |    top
        |)t
        |where rn<6
        |order by id
        |
      """.stripMargin)

    //在距离最小的5名里找出次数最多的label

top5.createTempView("top5")
spark.sql(
  """
    |
    |select
    |id,
    |label
    |from
    |top5
    |group by id,label
    |having count(1)> 2
    |
  """.stripMargin)

      .show(100,false)

    //关流
      spark.stop()
  }
}
发布了48 篇原创文章 · 获赞 11 · 访问量 1506

猜你喜欢

转载自blog.csdn.net/weixin_45896475/article/details/104426110
今日推荐