SparkMLlib实现K-means

SparkMLlib实现K-means

引言

之前写过一篇关于kmeans的博客,里面详细的介绍了关于K-means的的详细描述,用python是实现的,并且在最后附带数据,了解更改关于K-means的内容详看K-means

今天用scala语言中的spark,使用MLlib库来实现

依赖

<!--mllib依赖,我用的是scala是2.11, spark是2.2.0-->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.10</artifactId>
    <version>1.6.0</version>
</dependency>

注意

和python相比,一样是调函数调参,但喂给model的数据类型和python不同,python中的SKLearing库使用的是矩阵或者是DataFrame,在spark里边要求的data是RDD[Vector]类型

/**
 * Trains a k-means model using specified parameters and the default values for unspecified.
 * 源码中最简单的训练函数要求的是传递三个参数,分别是数据集、族数、迭代次数
 */
def train(
      data: RDD[Vector],
      k: Int,
      maxIterations: Int): KMeansModel = {
    train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
  }

代码

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD


object Kmeans {
  def main(args: Array[String]): Unit = {
    //模板代码,指定两个线程模拟在hadoop端的分布式
    val conf = new SparkConf().setAppName("Kmeans").setMaster("local[2]")
    val sc = new SparkContext(conf)

    //加载数据
    val data = sc.textFile("F:/mllib/kmeans/trainsdata")
    //将数据切分成标志格式,并封装成linalg.Vector类型
    val parsedData: RDD[linalg.Vector] = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))

    //迭代次数1000次、类簇的个数2个,进行模型训练形成数据模型
    val numClusters = 4
    val numIterations = 1000
      
    //进行训练
    val model = KMeans.train(parsedData, numClusters, numIterations)

    //打印数据模型的中心点
    println("四个中心的点:")
    for (point <- model.clusterCenters) {
      println("  " + point.toString)
    }

    //使用误差平方之和来评估数据模型,统计聚类错误的样本比例
    val cost = model.computeCost(parsedData)
    println("聚类错误的样本比例 = " + cost)

    //对部分点做预测分类
   println("点(-3 -3)所属族:" + model.predict(Vectors.dense("-3 -3".split(' ').map(_.toDouble))))
    println("点(-2 3)所属族:" + model.predict(Vectors.dense("-2 3".split(' ').map(_.toDouble))))
    println("点(3 3)所属族:" + model.predict(Vectors.dense("3 3".split(' ').map(_.toDouble))))
    
    sc.stop()
  }
}

运行结果

四个中心的点:
  [-2.4615431500000002,2.78737555]
  [-3.3823704500000007,-2.9473363000000004]
  [2.6265298999999995,3.10868015]
  [2.80293085,-2.7315146]
聚类错误的样本比例 = 149.95430467642632
点(-3 -3)所属族:1
点(-2 3)所属族:0
点(3 3)所属族:2

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/jklcl/article/details/84101883
今日推荐