Spark MLlib implements K-means
Introduction
I wrote a blog post about K-means before that describes the algorithm in detail; it is implemented in Python, with the data set attached at the end. For the principles of K-means, see that earlier K-means post.
Today we implement it in Scala, using Spark's MLlib library.
Dependencies
<!-- MLlib dependency; this artifact is built for Scala 2.10 and Spark 1.6.0 -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.6.0</version>
</dependency>
Note
Compared with Python, the function call itself is much the same, but the data type fed to the model differs: Python's scikit-learn accepts a matrix or a DataFrame, while Spark's MLlib requires an RDD[Vector].
/**
 * Trains a k-means model using specified parameters and the default values for unspecified.
 * The simplest train function in the source takes three parameters: the data set,
 * the number of clusters (k), and the maximum number of iterations.
 */
def train(
    data: RDD[Vector],
    k: Int,
    maxIterations: Int): KMeansModel = {
  train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
}
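Besides the static `train` helper shown above, MLlib also exposes a builder-style API on the `KMeans` class itself, which makes the extra parameters (initialization mode, seed) explicit. A minimal sketch, assuming an RDD[Vector] like the `parsedData` built in the code below (`trainWithBuilder` is a hypothetical wrapper name, not part of MLlib):

```scala
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Builder-style configuration, equivalent to KMeans.train(parsedData, 4, 1000)
// but with the extra knobs spelled out explicitly.
def trainWithBuilder(parsedData: RDD[Vector]): KMeansModel = {
  new KMeans()
    .setK(4)                                        // number of clusters
    .setMaxIterations(1000)                         // iteration cap
    .setInitializationMode(KMeans.K_MEANS_PARALLEL) // "k-means||", the default
    .setSeed(42L)                                   // fix the seed for reproducible centers
    .run(parsedData)                                // returns a KMeansModel
}
```

Setting the seed is the main practical reason to prefer this form: K-means centers depend on random initialization, so repeated runs of `KMeans.train` can produce differently ordered (or slightly different) centers.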
Code
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD
object Kmeans {
  def main(args: Array[String]): Unit = {
    // Boilerplate: run locally with two threads to simulate distributed execution
    val conf = new SparkConf().setAppName("Kmeans").setMaster("local[2]")
    val sc = new SparkContext(conf)
    // Load the data
    val data = sc.textFile("F:/mllib/kmeans/trainsdata")
    // Split each line on spaces and wrap it as a linalg.Vector
    val parsedData: RDD[linalg.Vector] = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
    // Train a model with 4 clusters and up to 1000 iterations
    val numClusters = 4
    val numIterations = 1000
    val model = KMeans.train(parsedData, numClusters, numIterations)
    // Print the cluster centers of the model
    println("The four cluster centers:")
    for (point <- model.clusterCenters) {
      println(" " + point.toString)
    }
    // Evaluate the model with the within-set sum of squared errors (WSSSE)
    val cost = model.computeCost(parsedData)
    println("Within Set Sum of Squared Errors = " + cost)
    // Predict the cluster for a few individual points
    println("Point (-3 -3) belongs to cluster: " + model.predict(Vectors.dense("-3 -3".split(' ').map(_.toDouble))))
    println("Point (-2 3) belongs to cluster: " + model.predict(Vectors.dense("-2 3".split(' ').map(_.toDouble))))
    println("Point (3 3) belongs to cluster: " + model.predict(Vectors.dense("3 3".split(' ').map(_.toDouble))))
    sc.stop()
  }
}
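Besides scoring single points, `KMeansModel.predict` also accepts a whole RDD[Vector], which is the natural way to label the entire data set at once. A short sketch, assuming the `model` and `parsedData` values from the code above:

```scala
// Assign every point in parsedData to its nearest center in one distributed pass.
val labels: RDD[Int] = model.predict(parsedData)

// Pair each point with its cluster index and print a small sample.
parsedData.zip(labels).take(5).foreach { case (point, cluster) =>
  println(s"$point -> cluster $cluster")
}
```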
Result
The four cluster centers:
[-2.4615431500000002,2.78737555]
[-3.3823704500000007,-2.9473363000000004]
[2.6265298999999995,3.10868015]
[2.80293085,-2.7315146]
Within Set Sum of Squared Errors = 149.95430467642632
Point (-3 -3) belongs to cluster: 1
Point (-2 3) belongs to cluster: 0
Point (3 3) belongs to cluster: 2
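The WSSSE reported above always shrinks as k grows, so on its own it does not tell you whether k = 4 is a good choice; a common heuristic is to train a model for a range of k values and look for the "elbow" where the cost stops improving sharply. A sketch, assuming the same `parsedData` RDD as in the code above:

```scala
// Train one model per candidate k and print its WSSSE; choose k at the "elbow"
// where adding more clusters stops reducing the cost substantially.
for (k <- 2 to 8) {
  val m = KMeans.train(parsedData, k, 1000)
  println(s"k=$k  WSSSE=${m.computeCost(parsedData)}")
}
```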