Spark 2.0 Machine Learning with the ML Library: ML Pipelines and Cross-Validation

1. Preface

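This article walks through a logistic regression example in Spark ML. It gives a broad overview of the whole Spark ML machine-learning workflow, which helps a great deal when you later study the individual parts of Spark ML in depth.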
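For readers new to logistic regression, the plain-Scala sketch below shows what a binary logistic regression learns: one weight per feature plus an intercept, fitted so that sigmoid(w·x + b) approximates P(label = 1). It is an illustration under simplifying assumptions, not Spark's implementation: it uses plain batch gradient descent, whereas Spark's `LogisticRegression` uses quasi-Newton optimizers such as L-BFGS and standardizes features by default. `LogisticRegressionSketch`, `train` and `stepSize` are hypothetical names; `maxIter` and `regParam` only loosely mirror the `setMaxIter` / `setRegParam` settings used in the code below.

```scala
// Toy binary logistic regression trained with plain batch gradient descent.
// ASSUMPTION: this is NOT Spark's algorithm; Spark uses quasi-Newton
// optimizers such as L-BFGS and standardizes features by default.
object LogisticRegressionSketch {
  type Features = Map[Int, Double] // sparse vector: feature index -> value

  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Raw margin w . x + b, summed over the non-zero entries only
  def margin(w: Array[Double], b: Double, x: Features): Double =
    b + x.iterator.map { case (i, v) => w(i) * v }.sum

  // Minimizes mean log-loss + (regParam / 2) * ||w||^2 (intercept not penalized)
  def train(data: Seq[(Features, Double)], numFeatures: Int, maxIter: Int,
            regParam: Double, stepSize: Double): (Array[Double], Double) = {
    val w = Array.fill(numFeatures)(0.0)
    var b = 0.0
    val n = data.size.toDouble
    for (_ <- 1 to maxIter) {
      val gradW = w.map(_ * regParam) // gradient of the L2 penalty
      var gradB = 0.0
      for ((x, y) <- data) {
        val err = sigmoid(margin(w, b, x)) - y // derivative of log-loss w.r.t. the margin
        x.foreach { case (i, v) => gradW(i) += err * v / n }
        gradB += err / n
      }
      for (i <- w.indices) w(i) -= stepSize * gradW(i)
      b -= stepSize * gradB
    }
    (w, b)
  }

  // P(label = 1) above 0.5 -> predict 1.0, otherwise 0.0
  def predict(w: Array[Double], b: Double, x: Features): Double =
    if (sigmoid(margin(w, b, x)) > 0.5) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    // Hypothetical data: feature 0 appears only in positives, feature 1 only in negatives
    val data: Seq[(Features, Double)] = Seq(
      Map(0 -> 1.0) -> 1.0,
      Map(1 -> 1.0) -> 0.0,
      Map(0 -> 1.0, 2 -> 1.0) -> 1.0,
      Map(1 -> 1.0, 2 -> 1.0) -> 0.0
    )
    val (w, b) = train(data, numFeatures = 3, maxIter = 100, regParam = 0.01, stepSize = 1.0)
    println(w.mkString("w = [", ", ", "]") + s", b = $b")
    println(predict(w, b, Map(0 -> 1.0))) // positive-only feature
  }
}
```

On this toy data the weight for feature 0 becomes positive and the weight for feature 1 negative, so each example is classified by the word-like feature it contains; the L2 penalty keeps the weights from growing without bound on perfectly separable data.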

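Today's spam SMS filtering, as well as the disease classification in a student innovation project I am working on, are refinements and extensions of this kind of example.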

2. Code

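Below, we train a model that recognizes English sentences related to love. (Chinese text must first be segmented into words, for example with the jieba segmenter.)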
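English sentences can go into the pipeline as-is because Spark's `Tokenizer` simply lowercases the text and splits it on whitespace; Chinese has no spaces between words, hence the need for jieba. `HashingTF` then maps each token to a column index by hashing it modulo `numFeatures` and counts occurrences. The plain-Scala sketch below mimics those two steps. Assumption: it uses `scala.util.hashing.MurmurHash3.stringHash`, which is not the hash Spark applies, so the indices will not match Spark's; `HashingSketch` and its methods are hypothetical names.

```scala
import scala.util.hashing.MurmurHash3

// Toy re-implementation of what Tokenizer + HashingTF do to one sentence.
// ASSUMPTION: MurmurHash3.stringHash is NOT the hash Spark uses, so the
// bucket indices differ from Spark's; only the mechanism is the same.
object HashingSketch {
  // Like Spark's Tokenizer: lowercase, then split on whitespace
  def tokenize(text: String): Seq[String] =
    text.toLowerCase.split("\\s").toSeq

  // Map a term to a bucket in [0, numFeatures), even for negative hash codes
  def bucket(term: String, numFeatures: Int): Int = {
    val h = MurmurHash3.stringHash(term) % numFeatures
    if (h < 0) h + numFeatures else h
  }

  // Sparse term-frequency vector: bucket index -> number of tokens in that bucket
  def termFrequencies(tokens: Seq[String], numFeatures: Int): Map[Int, Double] =
    tokens.groupBy(bucket(_, numFeatures)).map { case (i, ts) => i -> ts.size.toDouble }

  def main(args: Array[String]): Unit = {
    val tokens = tokenize("I love you not because who you are,but because who I am when I am with you")
    println(tokens.mkString(" | ")) // "are,but" stays a single token
    println(termFrequencies(tokens, 1000))
    println(termFrequencies(tokens, 10)) // far fewer buckets, so many collisions
  }
}
```

Notice that `are,but` survives as one token because there is no space after the comma, and that a small `numFeatures` such as 10 (one of the values tried in the cross-validation grid later) forces many different words into the same bucket.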

import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

case class Love(id: Long, text: String, label: Double)

case class Test(id: Long, text: String)

/**
  * Spark ML machine learning workflow (pipeline)
  */
object mlBegin {

  /**
    * ML Pipelines
    *
    * @param args
    */
  def main(args: Array[String]): Unit = {

    // 0. Build the SparkSession
    val spark = SparkSession
      .builder()
      .master("local") // run locally; without a master URL Spark fails with "A master URL must be set in your configuration"
      .appName("test")
      .enableHiveSupport() // needs the spark-hive dependency on the classpath; can be dropped, since nothing here uses Hive
      .getOrCreate() // reuse the existing session if there is one, otherwise create it

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // checkpoint directory (an HDFS path is preferable on a cluster); nothing in this example checkpoints, so this line is optional
    import spark.implicits._ // without it, Dataset conversions fail with "Unable to find encoder for type stored in a Dataset"

    // 1. Training samples: love-related = 1.0, unrelated = 0.0
    val training = spark.createDataFrame(
      Seq(
        Love(1L, "I love you", 1.0),
        Love(2L, "There is nothing to do", 0.0),
        Love(3L, "Work hard and you will success", 0.0),
        Love(4L, "We love each other", 1.0),
        Love(5L, "Where there is love, there are always wishes", 1.0),
        Love(6L, "I love you not because who you are,but because who I am when I am with you", 1.0),
        Love(7L, "Never frown,even when you are sad,because youn ever know who is falling in love with your smile", 1.0),
        Love(8L, "Whatever is worth doing is worth doing well", 0.0),
        Love(9L, "The hard part isn’t making the decision. It’s living with it", 0.0),
        Love(10L, "Your happy passer-by all knows, my distressed there is no place hides", 0.0),
        Love(11L, "When the whole world is about to rain, let’s make it clear in our heart together", 0.0)
      )
    ).toDF()
    training.show()
    /**
      * +---+--------------------+-----+
      * | id|                text|label|
      * +---+--------------------+-----+
      * |  1|          I love you|  1.0|
      * |  2|There is nothing ...|  0.0|
      * |  3|Work hard and you...|  0.0|
      * |  4|  We love each other|  1.0|
      * |  5|Where there is lo...|  1.0|
      * |  6|I love you not be...|  1.0|
      * |  7|Never frown,even ...|  1.0|
      * |  8|Whatever is worth...|  0.0|
      * |  9|The hard part isn...|  0.0|
      * | 10|Your happy passer...|  0.0|
      * | 11|When the whole wo...|  0.0|
      * +---+--------------------+-----+
      */

    // 2. Configure the stages: tokenizer, hashingTF, lr
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // 3. Train the model (fits every stage of the pipeline in order)
    val model = pipeline.fit(training)

    // 4. Test data
    val test = spark.createDataFrame(Seq(
      Test(1L, "You love me"),
      Test(2L, "Your happy passer-by all knows, my distressed there is no place hides"),
      Test(3L, "You may be out of my sight, but never out of my mind"),
      Test(4L, "Do you like me")
    )).toDF()
    test.show()

    /**
      * +---+--------------------+
      * | id|                text|
      * +---+--------------------+
      * |  1|         You love me|
      * |  2|Your happy passer...|
      * |  3|You may be out of...|
      * |  4|      Do you like me|
      * +---+--------------------+
      */

    // 5. Predict
    model.transform(test).
      select("id", "text", "probability", "prediction").
      collect().foreach {
      case Row(id: Long, text: String, prob: Vector, prediction: Double) => println(s"($id, $text) --> prob=$prob, prediction=$prediction")
    }

    /**
      * (1, You love me) --> prob=[0.060702374718936816,0.9392976252810632], prediction=1.0
      * (2, Your happy passer-by all knows, my distressed there is no place hides) --> prob=[0.9941770676282915,0.005822932371708409], prediction=0.0
      * (3, You may be out of my sight, but never out of my mind) --> prob=[0.5164321960230588,0.4835678039769412], prediction=0.0
      * (4, Do you like me) --> prob=[0.6971512913787795,0.30284870862122054], prediction=0.0
      */

    // 6. Save the unfitted pipeline (the training workflow)
    pipeline.write.overwrite().save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\pipelineTest")

    // 7. Save the fitted model (PipelineModel)
    model.write.overwrite().save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\modelTest")

    // 8. Load the fitted model back
    val sameModel = PipelineModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\modelTest")

    // 9. Verify that the reloaded model gives the same predictions
    sameModel.transform(test).
      select("id", "text", "probability", "prediction").
      collect().foreach {
      case Row(id: Long, text: String, prob: Vector, prediction: Double) => println(s"again($id, $text) --> prob=$prob, prediction=$prediction")
    }

    /** Identical to the original model's output (all correct)
      * again(1, You love me) --> prob=[0.060702374718936816,0.9392976252810632], prediction=1.0
      * again(2, Your happy passer-by all knows, my distressed there is no place hides) --> prob=[0.9941770676282915,0.005822932371708409], prediction=0.0
      * again(3, You may be out of my sight, but never out of my mind) --> prob=[0.5164321960230588,0.4835678039769412], prediction=0.0
      * again(4, Do you like me) --> prob=[0.6971512913787795,0.30284870862122054], prediction=0.0
      */
  }

}

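Judging the output by hand, all four predictions are correct (the test set carries no labels, so this is a manual check rather than a measured accuracy). The probabilities, however, show that some decisions are borderline, for example the third test sentence:

prob=[0.5164321960230588,0.4835678039769412]    // the model estimates P(unrelated to love) ≈ 0.52 and P(related to love) ≈ 0.48

In other words, the model is still partly guessing. More training samples (leaving out ambiguous sentences) should generally raise the recognition rate.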
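In Spark's binary logistic regression output, `probability` is the vector [P(label = 0), P(label = 1)], and `prediction` is 1.0 only when P(label = 1) exceeds the threshold (0.5 by default). The plain-Scala sketch below reproduces that mapping from the model's raw margin w·x + b, using the third test sentence as input. `ProbabilitySketch`, `logit` and `predict` are hypothetical helpers, not Spark API.

```scala
// Plain-Scala model of how a binary logistic regression turns its raw
// margin (w . x + b) into the `probability` and `prediction` columns.
object ProbabilitySketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Inverse of sigmoid: recovers the margin from P(label = 1)
  def logit(p: Double): Double = math.log(p / (1.0 - p))

  // Same layout as Spark's output vector: (P(label = 0), P(label = 1))
  def probability(margin: Double): (Double, Double) = {
    val p1 = sigmoid(margin)
    (1.0 - p1, p1)
  }

  // 1.0 only when P(label = 1) is strictly above the threshold
  def predict(margin: Double, threshold: Double = 0.5): Double =
    if (probability(margin)._2 > threshold) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    // Third test sentence: P(label = 1) was 0.4835678039769412
    val m = logit(0.4835678039769412)
    println(m)              // negative margin
    println(probability(m)) // approximately (0.5164, 0.4836)
    println(predict(m))     // 0.0
  }
}
```

A margin just below zero lands just below the 0.5 threshold, which is exactly the "partly guessing" situation described above.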

3. Cross-Validation

Adapted from the code above:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{ HashingTF, Tokenizer }
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tuning.{ CrossValidator, ParamGridBuilder }
import org.apache.spark.sql._
import org.apache.spark.sql.SparkSession

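// If this file shares a package with mlBegin, delete these two case classes (they are already defined there)
case class Love(id: Long, text: String, label: Double)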

case class Test(id: Long, text: String)

object CrossValidatorTest {

  def main(args: Array[String]): Unit = {
    // 0. Build the SparkSession
    val spark = SparkSession
      .builder()
      .master("local") // run locally; without a master URL Spark fails with "A master URL must be set in your configuration"
      .appName("test")
      .enableHiveSupport() // needs the spark-hive dependency on the classpath; can be dropped, since nothing here uses Hive
      .getOrCreate() // reuse the existing session if there is one, otherwise create it

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // checkpoint directory (an HDFS path is preferable on a cluster); nothing in this example checkpoints, so this line is optional
    import spark.implicits._ // without it, Dataset conversions fail with "Unable to find encoder for type stored in a Dataset"

    // 1. Training samples: love-related = 1.0, unrelated = 0.0
    val training = spark.createDataFrame(
      Seq(
        Love(1L, "I love you", 1.0),
        Love(2L, "There is nothing to do", 0.0),
        Love(3L, "Work hard and you will success", 0.0),
        Love(4L, "We love each other", 1.0),
        Love(5L, "Where there is love, there are always wishes", 1.0),
        Love(6L, "I love you not because who you are,but because who I am when I am with you", 1.0),
        Love(7L, "Never frown,even when you are sad,because youn ever know who is falling in love with your smile", 1.0),
        Love(8L, "Whatever is worth doing is worth doing well", 0.0),
        Love(9L, "The hard part isn’t making the decision. It’s living with it", 0.0),
        Love(10L, "Your happy passer-by all knows, my distressed there is no place hides", 0.0),
        Love(11L, "When the whole world is about to rain, let’s make it clear in our heart together", 0.0)
      )
    ).toDF()
    training.show()
    /**
      * +---+--------------------+-----+
      * | id|                text|label|
      * +---+--------------------+-----+
      * |  1|          I love you|  1.0|
      * |  2|There is nothing ...|  0.0|
      * |  3|Work hard and you...|  0.0|
      * |  4|  We love each other|  1.0|
      * |  5|Where there is lo...|  1.0|
      * |  6|I love you not be...|  1.0|
      * |  7|Never frown,even ...|  1.0|
      * |  8|Whatever is worth...|  0.0|
      * |  9|The hard part isn...|  0.0|
      * | 10|Your happy passer...|  0.0|
      * | 11|When the whole wo...|  0.0|
      * +---+--------------------+-----+
      */

    // 2. Configure the stages: tokenizer, hashingTF, lr
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // 3. Build the parameter grid for the grid search (3 x 2 = 6 combinations)
    val paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()

    // 4. Create the cross-validator: it tunes the whole pipeline and scores each candidate with a binary-classification evaluator
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(2) // Use 3+ in practice

    // 5. Run cross-validation; the result wraps the pipeline refitted on all training data with the best parameter set
    val cvModel = cv.fit(training)

    // 6. Test data
    val test = spark.createDataFrame(Seq(
      Test(1L, "You love me"),
      Test(2L, "Your happy passer-by all knows, my distressed there is no place hides"),
      Test(3L, "You may be out of my sight, but never out of my mind"),
      Test(4L, "Do you like me")
    )).toDF()
    test.show()

    /**
      * +---+--------------------+
      * | id|                text|
      * +---+--------------------+
      * |  1|         You love me|
      * |  2|Your happy passer...|
      * |  3|You may be out of...|
      * |  4|      Do you like me|
      * +---+--------------------+
      */

    // 7. Predict
    cvModel.transform(test).
      select("id", "text", "probability", "prediction").
      collect().foreach {
      case Row(id: Long, text: String, prob: Vector, prediction: Double) => println(s"($id, $text) --> prob=$prob, prediction=$prediction")
    }
    /** Only half are correct
      * (1, You love me) --> prob=[0.6955159095851398,0.3044840904148603], prediction=0.0
      * (2, Your happy passer-by all knows, my distressed there is no place hides) --> prob=[0.8156738426968556,0.18432615730314444], prediction=0.0
      * (3, You may be out of my sight, but never out of my mind) --> prob=[0.4116382164564237,0.5883617835435763], prediction=1.0
      * (4, Do you like me) --> prob=[0.6723035595096153,0.32769644049038477], prediction=0.0
      */
  }

}

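Accuracy seems to have dropped: only half of the four predictions now match. On such a tiny dataset this is not surprising. With 2 folds, every candidate model is trained on roughly half of the 11 sentences and scored on the other half, so the metric used to pick the "best" of the 6 parameter combinations is very noisy. Cross-validation pays off once there is considerably more training data.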
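It helps to count what `cv.fit(training)` actually does. For every parameter map in the grid (3 values of `numFeatures` × 2 values of `regParam` = 6 maps) and every fold, `CrossValidator` fits the pipeline on the training part of the fold and scores it on the held-out part with the evaluator (area under ROC by default for `BinaryClassificationEvaluator`). It then averages the scores per parameter map and refits the best one on the full training set. The plain-Scala sketch below counts those fits and shows fold sizes for 11 rows. Assumption: it assigns row i to fold i % k for determinism, whereas Spark splits the data randomly, so Spark's fold sizes are only approximately equal; `CrossValidationSketch` is a hypothetical name.

```scala
// Counts the work done by k-fold cross-validation over a parameter grid.
// ASSUMPTION: row i goes to fold i % k; Spark assigns rows to folds randomly.
object CrossValidationSketch {
  // The grid from the article: 3 values of numFeatures x 2 values of regParam
  val grid: Seq[(Int, Double)] =
    for (nf <- Seq(10, 100, 1000); reg <- Seq(0.1, 0.01)) yield (nf, reg)

  // For each fold f: (indices used for training, indices held out for validation)
  def folds(n: Int, k: Int): Seq[(Seq[Int], Seq[Int])] =
    (0 until k).map { f =>
      val (valid, train) = (0 until n).partition(_ % k == f)
      (train, valid)
    }

  // One fit per (fold, param map), plus one final refit on all the data
  def totalFits(numFolds: Int, numParamMaps: Int): Int =
    numFolds * numParamMaps + 1

  def main(args: Array[String]): Unit = {
    println(s"param maps: ${grid.size}")               // 6
    println(s"model fits: ${totalFits(2, grid.size)}") // 13
    folds(11, 2).zipWithIndex.foreach { case ((train, valid), f) =>
      println(s"fold $f: train on ${train.size} rows, validate on ${valid.size} rows")
    }
  }
}
```

With 2 folds each candidate is judged on only 5 or 6 sentences. In Spark itself, `cvModel.avgMetrics` holds the averaged metric for each parameter map (in the same order as `paramGrid`), which is a quick way to see how close the candidates really were.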

4. Other Notes

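This example is quite fun. Readers can try training a classifier for English sentences related to friendship,
or, as a more advanced exercise, for Chinese sentences related to friendship.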



Reposted from blog.csdn.net/larger5/article/details/81659967