Spark 2.0 ML Library: Common Machine Learning Models (Scala Edition)

I. Preface

Designing machine learning algorithms by hand takes a fair amount of accumulated knowledge.
Using a ready-made library such as Spark 2.0 ML, on the other hand, requires almost no such background: it works out of the box.
The best way to get started is to walk through a simple, complete, well-organized example.

Earlier articles (each with short, self-contained examples):
Spark 2.0 ML Library: Feature Extraction, Transformation, and Selection (Scala Edition)
Spark 2.0 ML Library: ML Pipelines and Cross-Validation (Scala Edition)
Spark 2.0 ML Library: Data Analysis Methods (Scala Edition)

II. Code

The code below is adapted from good examples found online, which I have refined and annotated in more detail.

1. Linear Regression

package change

import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql._
import org.apache.spark.sql.SparkSession

/**
  * Linear regression
  */
object linearTest {

  def main(args: Array[String]): Unit = {

    // 0. Build the Spark session
    val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .enableHiveSupport() // requires Hive classes on the classpath; remove if Hive is not needed
      .getOrCreate() // reuse the existing session, or create one if none exists

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // directory for checkpoint reads/writes; HDFS is best in production
    import spark.implicits._

    //1 Prepare the training samples
    val training =  spark.createDataFrame(Seq(
      (5.601801561245534, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.6949189734965766, -0.32697929564739403, -0.15359663581829275, -0.8951865090520432, 0.2057889391931318, -0.6676656789571533, -0.03553655732400762, 0.14550349954571096, 0.034600542078191854, 0.4223352065067103))),
      (0.2577820163584905, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.8386555657374337, -0.1270180511534269, 0.499812362510895, -0.22686625128130267, -0.6452430441812433, 0.18869982177936828, -0.5804648622673358, 0.651931743775642, -0.6555641246242951, 0.17485476357259122))),
      (1.5299675726687754, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(-0.13079299081883855, 0.0983382230287082, 0.15347083875928424, 0.45507300685816965, 0.1921083467305864, 0.6361110540492223, 0.7675261182370992, -0.2543488202081907, 0.2927051050236915, 0.680182444769418))))).toDF("label", "features")
    training.show(false)

    //2 Build the linear regression model
    val lr = new LinearRegression()
      .setMaxIter(100)
      .setRegParam(0.1)
      .setElasticNetParam(0.5)

    //2 Fit the model on the training samples
    val lrModel = lr.fit(training)

    //2 Print model information
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

    /**
      * Coefficients: [0.0,-0.8840148895400428,-4.451571521834594,-0.42090140779272434,0.857395634491616,-1.237347818637769,0.0,0.0,0.0,0.0] Intercept: 3.1417724655192645
      */

    println(s"Intercept: ${lrModel.intercept}")

    /**
      * Intercept: 3.1417724655192645
      */

    //4 Test samples
    val test = spark.createDataFrame(Seq(
      (5.601801561245534, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.6949189734965766, -0.32697929564739403, -0.15359663581829275, -0.8951865090520432, 0.2057889391931318, -0.6676656789571533, -0.03553655732400762, 0.14550349954571096, 0.034600542078191854, 0.4223352065067103))),
      (0.2577820163584905, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.8386555657374337, -0.1270180511534269, 0.499812362510895, -0.22686625128130267, -0.6452430441812433, 0.18869982177936828, -0.5804648622673358, 0.651931743775642, -0.6555641246242951, 0.17485476357259122))),
      (1.5299675726687754, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(-0.13079299081883855, 0.0983382230287082, 0.15347083875928424, 0.45507300685816965, 0.1921083467305864, 0.6361110540492223, 0.7675261182370992, -0.2543488202081907, 0.2927051050236915, 0.680182444769418))))).toDF("label", "features")
    test.show(false)

    /**
      * +------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      * |label             |features                                                                                                                                                                                                                             |
      * +------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      * |5.601801561245534 |(10,[0,1,2,3,4,5,6,7,8,9],[0.6949189734965766,-0.32697929564739403,-0.15359663581829275,-0.8951865090520432,0.2057889391931318,-0.6676656789571533,-0.03553655732400762,0.14550349954571096,0.034600542078191854,0.4223352065067103])|
      * |0.2577820163584905|(10,[0,1,2,3,4,5,6,7,8,9],[0.8386555657374337,-0.1270180511534269,0.499812362510895,-0.22686625128130267,-0.6452430441812433,0.18869982177936828,-0.5804648622673358,0.651931743775642,-0.6555641246242951,0.17485476357259122])     |
      * |1.5299675726687754|(10,[0,1,2,3,4,5,6,7,8,9],[-0.13079299081883855,0.0983382230287082,0.15347083875928424,0.45507300685816965,0.1921083467305864,0.6361110540492223,0.7675261182370992,-0.2543488202081907,0.2927051050236915,0.680182444769418])       |
      * +------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      */

    //5 Apply the model to the test samples
    val test_predict = lrModel.transform(test)
    test_predict
      .select("label","prediction","features")
      .show(false)

    /**
      * +------------------+-------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      * |label             |prediction         |features                                                                                                                                                                                                                             |
      * +------------------+-------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      * |5.601801561245534 |5.493935912726037  |(10,[0,1,2,3,4,5,6,7,8,9],[0.6949189734965766,-0.32697929564739403,-0.15359663581829275,-0.8951865090520432,0.2057889391931318,-0.6676656789571533,-0.03553655732400762,0.14550349954571096,0.034600542078191854,0.4223352065067103])|
      * |0.2577820163584905|0.33788027718672575|(10,[0,1,2,3,4,5,6,7,8,9],[0.8386555657374337,-0.1270180511534269,0.499812362510895,-0.22686625128130267,-0.6452430441812433,0.18869982177936828,-0.5804648622673358,0.651931743775642,-0.6555641246242951,0.17485476357259122])     |
      * |1.5299675726687754|1.557734960360036  |(10,[0,1,2,3,4,5,6,7,8,9],[-0.13079299081883855,0.0983382230287082,0.15347083875928424,0.45507300685816965,0.1921083467305864,0.6361110540492223,0.7675261182370992,-0.2543488202081907,0.2927051050236915,0.680182444769418])       |
      * +------------------+-------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      */

    //6 Model summary
    val trainingSummary = lrModel.summary

    // Objective value at each iteration
    val objectiveHistory = trainingSummary.objectiveHistory
    println(s"numIterations: ${trainingSummary.totalIterations}")

    /**
      * numIterations: 101
      */

    println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")

    /**
      * objectiveHistory: [0.5,0.3343138710353481,0.03498698228406803,0.034394702527331365,0.03361752133820051,0.033440576313009396,0.032492999827506586,0.03209249818672103,0.03201118276878801,0.0318030335506653,0.031556141484809515,0.03146914334471842,0.03132368104987874,0.030906857778152226,0.030829631969772512,0.030792601096269995,0.03075807300477159,0.03064409361649658,0.03057645418974434,0.03048720940080922,0.030450452329432418,0.0303403006892938,0.03022336621283447,0.030105231797686347,0.03005248564337978,0.029952523828252434,0.029901762708870988,0.029901114112460842,0.029897992643680316,0.029897097909156505,0.029892358780083193,0.029890487541861296,0.029883508098656905,0.02986342331315129,0.029846157576330717,0.02983921669719768,0.029837621981381814,0.029832343881027193,0.029818011565517288,0.0298174329753425,0.029816619127868163,0.029815897918569062,0.029815813156609985,0.029815635355907394,0.029814914126549,0.029813735638819686,0.02981357400967502,0.02981340129452729,0.029813363218666296,0.029813104482615992,0.029813066188642295,0.02981290111924657,0.029812867201451012,0.029812730285385426,0.029812706953398726,0.02981259780704471,0.02981258478371474,0.02981249810105761,0.029812492058484363,0.029812414896583955,0.02981239284306545,0.02981217952516655,0.029812093354005524,0.029812078847204722,0.02981204606864486,0.029812029284085127,0.029812008170753846,0.029812001127453244,0.02981198610905457,0.029811978179336476,0.029811968590860403,0.029811960922339894,0.02981195510843637,0.029811951516538388,0.02981194560589678,0.029811931971338676,0.029811927559300986,0.02981192583464405,0.029811923533256,0.02981192147493291,0.029811919101372975,0.0298119178536648,0.029811915692737362,0.02981191417259256,0.029811912340872517,0.02981191111669305,0.02981190922210416,0.029811908328486812,0.029811906376022823,0.029811905682559023,0.02981190386743857,0.029811903165691635,0.02981190159751578,0.029811901021202986,0.02981189985181355,0.02981189892054736,0.029811897724408266,0.02981189698790617,0.02981189562974597,0.029811894938092554,0.029811894064851477]
      */

    trainingSummary.residuals.show(false)

    /**
      * +---------------------+
      * |residuals            |
      * +---------------------+
      * |0.1078656485194962   |
      * |-0.08009826082823523 |
      * |-0.027767387691260526|
      * +---------------------+
      */

    println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")

    /**
      * RMSE: 0.07920807479341203
      */

    println(s"r2: ${trainingSummary.r2}")

    /**
      * r2: 0.998792363204057
      */

    //7 Save and load the model (when deploying to a server such as Django, add this code plus the model files to the view)
    lrModel.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel2")
    val load_lrModel = LinearRegressionModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel2")
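
    // A minimal added sketch: the reloaded model scores new rows exactly like the
    // freshly trained one. The sparse vector below is a made-up 10-dimensional
    // example, not from the original post.
    val newData = spark.createDataFrame(Seq(
      Tuple1(Vectors.sparse(10, Array(0, 3, 7), Array(0.5, -0.2, 0.9)))
    )).toDF("features")
    load_lrModel.transform(newData).select("features", "prediction").show(false)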

  }

}


2. Logistic Regression

package change

import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession

object logicTest {

  def main(args: Array[String]): Unit = {

    // 0. Build the Spark session
    val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .enableHiveSupport() // requires Hive classes on the classpath; remove if Hive is not needed
      .getOrCreate() // reuse the existing session, or create one if none exists

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // directory for checkpoint reads/writes; HDFS is best in production
    import spark.implicits._

    //1 Prepare the training samples
    val training = spark.createDataFrame(Seq(
      (1.0, Vectors.sparse(692, Array(10, 20, 30), Array(-1.0, 1.5, 1.3))),
      (0.0, Vectors.sparse(692, Array(45, 175, 500), Array(-1.0, 1.5, 1.3))),
      (1.0, Vectors.sparse(692, Array(100, 200, 300), Array(-1.0, 1.5, 1.3))))).toDF("label", "features")
    training.show(false)

    /**
      * +-----+----------------------------------+
      * |label|features                          |
      * +-----+----------------------------------+
      * |1.0  |(692,[10,20,30],[-1.0,1.5,1.3])   |
      * |0.0  |(692,[45,175,500],[-1.0,1.5,1.3]) |
      * |1.0  |(692,[100,200,300],[-1.0,1.5,1.3])|
      * +-----+----------------------------------+
      */

    //2 Build the logistic regression model
    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)

    //2 Fit the model on the training samples
    val lrModel = lr.fit(training)

    //2 Print model information
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

    /**
      * Coefficients: (692,[45,175,500],[0.48944928041408226,-0.32629952027605463,-0.37649944647237077]) Intercept: 1.251662793530725
      */

    println(s"Intercept: ${lrModel.intercept}")

    /**
      * Intercept: 1.251662793530725
      */

    //3 Build a multinomial logistic regression model
    val mlr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8).setFamily("multinomial")

    //3 Fit the model on the training samples
    val mlrModel = mlr.fit(training)

    //3 Print model information
    println(s"Multinomial coefficients: ${mlrModel.coefficientMatrix}")

    /**
      * Multinomial coefficients: 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ... (692 total)
      * 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...
      */

    println(s"Multinomial intercepts: ${mlrModel.interceptVector}")

    /**
      * Multinomial intercepts: [-0.6449310568167714,0.6449310568167714]
      */

    //4 Test samples
    val test = spark.createDataFrame(Seq(
      (1.0, Vectors.sparse(692, Array(10, 20, 30), Array(-1.0, 1.5, 1.3))),
      (0.0, Vectors.sparse(692, Array(45, 175, 500), Array(-1.0, 1.5, 1.3))),
      (1.0, Vectors.sparse(692, Array(100, 200, 300), Array(-1.0, 1.5, 1.3))))).toDF("label", "features")
    test.show(false)

    /**
      * +-----+----------------------------------+
      * |label|features                          |
      * +-----+----------------------------------+
      * |1.0  |(692,[10,20,30],[-1.0,1.5,1.3])   |
      * |0.0  |(692,[45,175,500],[-1.0,1.5,1.3]) |
      * |1.0  |(692,[100,200,300],[-1.0,1.5,1.3])|
      * +-----+----------------------------------+
      */

    //5 Apply the model to the test samples
    val test_predict = lrModel.transform(test)
    test_predict
      .select("label", "prediction", "probability", "features")
      .show(false)

    /**
      * +-----+----------+----------------------------------------+----------------------------------+
      * |label|prediction|probability                             |features                          |
      * +-----+----------+----------------------------------------+----------------------------------+
      * |1.0  |1.0       |[0.22241243403014824,0.7775875659698517]|(692,[10,20,30],[-1.0,1.5,1.3])   |
      * |0.0  |0.0       |[0.5539602964649871,0.44603970353501293]|(692,[45,175,500],[-1.0,1.5,1.3]) |
      * |1.0  |1.0       |[0.22241243403014824,0.7775875659698517]|(692,[100,200,300],[-1.0,1.5,1.3])|
      * +-----+----------+----------------------------------------+----------------------------------+
      */

    //6 Model summary
    val trainingSummary = lrModel.summary

    //6 Objective value at each iteration
    val objectiveHistory = trainingSummary.objectiveHistory
    println("objectiveHistory:")
    objectiveHistory.foreach(loss => println(loss))

    /**
      * objectiveHistory:
      * 0.6365141682948128
      * 0.6212055977633174
      * 0.5894552698389314
      * 0.5844805633573479
      * 0.5761098112571359
      * 0.575517297029231
      * 0.5754098875805627
      * 0.5752562156795122
      * 0.5752506337221737
      * 0.5752406742715199
      * 0.5752404945106846
      */

    //6 Compute model metrics
    val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]

    //6 ROC curve / AUC
    val roc = binarySummary.roc
    roc.show(false)

    /**
      * +---+---+
      * |FPR|TPR|
      * +---+---+
      * |0.0|0.0|
      * |0.0|1.0|
      * |1.0|1.0|
      * |1.0|1.0|
      * +---+---+
      */

    val AUC = binarySummary.areaUnderROC
    println(s"areaUnderROC: ${binarySummary.areaUnderROC}")

    //6 Set the model threshold
    // Each threshold yields a different F1 score; find the threshold with the maximum F1 and reset the model's threshold to it.
    val fMeasure = binarySummary.fMeasureByThreshold
    fMeasure.show(false)

    /**
      * +-------------------+---------+
      * |threshold          |F-Measure|
      * +-------------------+---------+
      * |0.7775875659698517 |1.0      |
      * |0.44603970353501293|0.8      |
      * +-------------------+---------+
      */

    // Get the maximum F1 value
    val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
    // Find the threshold corresponding to the maximum F1 (the best threshold)
    val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).select("threshold").head().getDouble(0)
    // Set the model's threshold to the selected best classification threshold
    lrModel.setThreshold(bestThreshold)

    //7 Save and load the model (when deploying to a server such as Django, add this code plus the model files to the view)
    lrModel.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel")
    val load_lrModel = LogisticRegressionModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel")
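
    // A minimal added sketch: score a hypothetical unseen sample with the reloaded
    // model; the indices and values below are invented, in the same 692-dimensional space.
    val unseen = spark.createDataFrame(Seq(
      Tuple1(Vectors.sparse(692, Array(10, 175, 300), Array(-0.5, 1.0, 0.8)))
    )).toDF("features")
    load_lrModel.transform(unseen).select("probability", "prediction").show(false)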

  }

}

(Screenshot: the saved model files on disk)

3. Decision Tree, Random Forest, and GBDT

Tree-based models are among the easier algorithms to understand.

package tree

import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.classification.{ DecisionTreeClassifier, DecisionTreeClassificationModel }
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.{ MulticlassClassificationEvaluator, BinaryClassificationEvaluator }
import org.apache.spark.ml.{ Pipeline, PipelineModel }
import org.apache.spark.sql.SparkSession

object tree {

  def main(args: Array[String]): Unit = {

    // 0. Build the Spark session
    val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .enableHiveSupport() // requires Hive classes on the classpath; remove if Hive is not needed
      .getOrCreate() // reuse the existing session, or create one if none exists

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // directory for checkpoint reads/writes; HDFS is best in production
    import spark.implicits._

    //1 Prepare the training samples
    val data = spark.read.format("libsvm").load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\sample_libsvm_data.txt")
    data.show

    //2 Index-encode the labels
    val labelIndexer = new StringIndexer().
      setInputCol("label").
      setOutputCol("indexedLabel").
      fit(data)

    // Index categorical features, automatically determining which features are categorical:
    // a feature with more than 4 distinct values is treated as continuous; otherwise it is marked as categorical and index-encoded
    val featureIndexer = new VectorIndexer().
      setInputCol("features").
      setOutputCol("indexedFeatures").
      setMaxCategories(4).
      fit(data)

    //3 Split the samples
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    //4 Define the decision tree model
    val dt = new DecisionTreeClassifier().
      setLabelCol("indexedLabel").
      setFeaturesCol("indexedFeatures")

    //4 Define the random forest model
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setNumTrees(10)

    //4 Define the GBDT model
    val gbt = new GBTClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setMaxIter(10)

    //5 Convert indexed labels back to the original labels
    val labelConverter = new IndexToString().
      setInputCol("prediction").
      setOutputCol("predictedLabel").
      setLabels(labelIndexer.labels)

    //6 Build the Pipelines
    val pipeline1 = new Pipeline().
      setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

    val pipeline2 = new Pipeline().
      setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

    val pipeline3 = new Pipeline().
      setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))

    //7 Fit the Pipelines
    val model1 = pipeline1.fit(trainingData)

    val model2 = pipeline2.fit(trainingData)

    val model3 = pipeline3.fit(trainingData)

    //8 Test the model
    val predictions = model1.transform(testData)
    predictions.show(5)

    //8 Test results
    predictions.select("predictedLabel", "label", "features").show(5)

    //9 Classification metrics
    // Accuracy
    val evaluator1 = new MulticlassClassificationEvaluator().
      setLabelCol("indexedLabel").
      setPredictionCol("prediction").
      setMetricName("accuracy")
    val accuracy = evaluator1.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    // f1
    val evaluator2 = new MulticlassClassificationEvaluator().
      setLabelCol("indexedLabel").
      setPredictionCol("prediction").
      setMetricName("f1")
    val f1 = evaluator2.evaluate(predictions)
    println("f1 = " + f1)

    // Precision
    val evaluator3 = new MulticlassClassificationEvaluator().
      setLabelCol("indexedLabel").
      setPredictionCol("prediction").
      setMetricName("weightedPrecision")
    val Precision = evaluator3.evaluate(predictions)
    println("Precision = " + Precision)

    // Recall
    val evaluator4 = new MulticlassClassificationEvaluator().
      setLabelCol("indexedLabel").
      setPredictionCol("prediction").
      setMetricName("weightedRecall")
    val Recall = evaluator4.evaluate(predictions)
    println("Recall = " + Recall)

    // AUC (note: feeding the hard "prediction" column in as the raw score yields an ROC with only a few points, so the AUC is coarse)
    val evaluator5 = new BinaryClassificationEvaluator().
      setLabelCol("indexedLabel").
      setRawPredictionCol("prediction").
      setMetricName("areaUnderROC")
    val AUC = evaluator5.evaluate(predictions)
    println("Test AUC = " + AUC)

    // AUPR (same caveat as the AUC above)
    val evaluator6 = new BinaryClassificationEvaluator().
      setLabelCol("indexedLabel").
      setRawPredictionCol("prediction").
      setMetricName("areaUnderPR")
    val aupr = evaluator6.evaluate(predictions)
    println("Test aupr = " + aupr)

    //10 Print the decision tree
    val treeModel = model1.stages(2).asInstanceOf[DecisionTreeClassificationModel]
    println("Learned classification tree model:\n" + treeModel.toDebugString)

    //11 Save and load the model
    model1.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\dtmodel")
    val load_treeModel = PipelineModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\dtmodel")

  }

}

4. KMeans

The name sounds a bit like KNN, but K-Means is an unsupervised clustering algorithm.

package juhe

import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.sql.SparkSession

object Kmeans {

  def main(args: Array[String]): Unit = {

    // 0. Build the Spark session
    val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .enableHiveSupport() // requires Hive classes on the classpath; remove if Hive is not needed
      .getOrCreate() // reuse the existing session, or create one if none exists

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // directory for checkpoint reads/writes; HDFS is best in production
    import spark.implicits._

    // Load the samples
    val dataset = spark.read.format("libsvm").load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\sample_kmeans_data.txt")
    dataset.show()

    // Train a k-means model
    val kmeans = new KMeans().setK(2).setSeed(1L)
    val model = kmeans.fit(dataset)

    // Compute the model metric (within-set sum of squared errors)
    val WSSSE = model.computeCost(dataset)
    println(s"Within Set Sum of Squared Errors = $WSSSE")

    // Show the results
    println("Cluster Centers: ")
    model.clusterCenters.foreach(println)
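
    // An added sketch: to see which cluster each row was assigned to, transform the
    // dataset; KMeansModel appends a "prediction" column holding the cluster index.
    val assigned = model.transform(dataset)
    assigned.select("features", "prediction").show(false)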

    // Save and load the model
    model.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\kmmodel")
    val load_kmModel = KMeansModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\kmmodel")
    spark.stop()


  }

}

5. LDA

LDA can be used for document topic modeling, e.g. grouping articles by topic.

package juhe

import org.apache.spark.ml.clustering.{LDA, LDAModel}
import org.apache.spark.sql.SparkSession

object ldaTest {

  def main(args: Array[String]): Unit = {

    // 0. Build the Spark session
    val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .enableHiveSupport() // requires Hive classes on the classpath; remove if Hive is not needed
      .getOrCreate() // reuse the existing session, or create one if none exists

    // 1. Load the samples
    val dataset = spark.read.format("libsvm").load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\data.txt")
    dataset.show()

    // 2. Train the LDA model
    val lda = new LDA().setK(10).setMaxIter(10)
    val model = lda.fit(dataset)

    val ll = model.logLikelihood(dataset)
    val lp = model.logPerplexity(dataset)
    println(s"The lower bound on the log likelihood of the entire corpus: $ll")
    println(s"The upper bound on perplexity: $lp")

    // 3. Describe the topics
    val topics = model.describeTopics(3)
    println("The topics described by their top-weighted terms:")
    topics.show(false)

    // Inspect the topic-term matrix and the concentration parameters
    val topicsMatrix = model.topicsMatrix
    model.estimatedDocConcentration
    model.getTopicConcentration

    // 4. Transform the dataset and inspect the results
    val transformed = model.transform(dataset)
    transformed.show(false)
    transformed.columns

    // 5. Save the model
    model.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\ldamodel")
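
    // An added sketch for loading it back: with the default online optimizer, fit
    // returns a LocalLDAModel; an EM-trained model would be a DistributedLDAModel instead.
    import org.apache.spark.ml.clustering.LocalLDAModel
    val load_ldaModel = LocalLDAModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\ldamodel")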

    spark.stop()

  }

}


Reposted from blog.csdn.net/larger5/article/details/81707571