spark 机器学习(ml pipeline)

1、业务目标,通过训练模型给待处理数据打上标签

给定训练样本中对包含hello的字符串文本打上标签1,否则打上0.
期望,通过训练模型用机器学习的方式对待测试数据做同样的操作。

2、训练样本sample.txt

三列(id,文本,标签),hello文本标签为1
0,why hello world JAVA,1.0
1,what llo java jsp,0.0
2,test hello2 scala,0.0
3,abc spark hello,1.0
4,j hello c#,1.0
5,i java hell spark,0.0
6,i java hell spark

3、待测试数据样本w1.txt

0,hello world
1,hello java test num
2,test hello scala
3,j hello spark
4,abc hello c#
5,hell java spark
6,hello java spark
7,num he he java spark
8,hello2 java spark
9,hello do some thing java spark
10,world hello java spark

4、code

4.1依赖

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>2.4.4</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>2.4.4</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.11</artifactId>
            <version>2.4.4</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.4.4</version>
        </dependency>

4.2 实现

package com.home.spark.ml

import org.apache.spark.SparkConf
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector

/**
  * @Description: 机器学习,训练样本数据,给生产数据打标签
  *              样本训练数据中带有hello的文本,打标签为1,否则为0
  *              通过训练模型,我们希望待测试数据同样用这种方式打上标签。
  **/
object Ex_label {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(true).setMaster("local[*]").setAppName("spark ml label")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    val error_count = spark.sparkContext.longAccumulator("error_count")

    //载入训练数据,数据手工训练,给带有hello的数据打上1.0的标签,给没有hello的数据打上0.0
    val lineRDD: RDD[String] = spark.sparkContext.textFile("input/sample.txt")

    //rdd转换成df或者ds需要SparkSession实例的隐式转换
    //导入隐式转换,注意这里的spark不是包名,而是SparkSession的对象名
    import spark.implicits._


    //生成训练数据,标签数据必须为double
    val training: DataFrame = lineRDD.map(line => {
      val strings: Array[String] = line.split(",")
      if (strings.length == 3) {
        (strings(0), strings(1), strings(2).toDouble)
      }
      else {
        error_count.add(1)
        ("-1", strings.mkString(" "), 0.0)
      }

    }).filter(s => !s._1.equals("-1"))
      .toDF("id", "text", "label")

    training.printSchema()
    training.show()

    println(s"错误数据计数 : ${error_count.value}")


    //Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    //Transformer,转换器,字符解析,转换输入文本,以空格分隔,转成小写词
    val tokenizer: Tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")

    //Transformer,转换器,哈希转换,以哈希方式将词转换成词频,转成特征向量
    val hashTF: HashingTF = new HashingTF().setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol).setOutputCol("features")

    //Estimator,预测器或评估器,逻辑回归,10次最大迭代
    val lr: LogisticRegression = new LogisticRegression().setMaxIter(10).setRegParam(0.01)

    //预测器通过 fit() 方法,接收一个 DataFrame 并产出一个模型
    //封装流水线,包含两个转换器(实际包含两个模型),一个评估器(包含一个算法)
    //因为还有评估器,所以需要训练生成最终模型
    val pipeline: Pipeline = new Pipeline().setStages(Array(tokenizer, hashTF, lr))


    // Fit the pipeline to training documents.
    //训练,生成最终模型
    val model: PipelineModel = pipeline.fit(training)

    // 可以选择保存模型到磁盘
    model.write.overwrite().save("/tmp/spark-logistic-regression-model")
    // 重新加载回来
    //    val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model")


    // 保存未训练(unfit)的流水线到底盘
    //    pipeline.write.overwrite().save("/tmp/unfit-lr-model")

    //重新加载流水线
    //    val samePipeline = Pipeline.load("/tmp/unfit-lr-model")


    //加载待分析数据
    val testRDD: RDD[String] = spark.sparkContext.textFile("input/w1.txt")
    val test: DataFrame = testRDD.map(line => {
      val strings: Array[String] = line.split(",")
      if (strings.length == 2) {
        (strings(0), strings(1))
      }
      else {
        //        error_count.add(1)
        ("-1", strings.mkString(" "))
      }

    }).filter(s => !s._1.equals("-1"))
      .toDF("id", "text")

    //对给定数据进行预测
    model.transform(test)
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach {
        case Row(id: String, text: String, prob: Vector, prediction: Double) =>
          println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }

    spark.stop()

  }
}

/* 运行结果
(0, hello world) --> prob=[0.02467400198786794,0.975325998012132], prediction=1.0
(1, hello java test num) --> prob=[0.48019580016300345,0.5198041998369967], prediction=1.0
(2, test hello scala) --> prob=[0.6270035488150222,0.3729964511849778], prediction=0.0 //这条分析错误,样本数据不够,或者样本干扰
(3, j hello spark) --> prob=[0.031182836719302286,0.9688171632806978], prediction=1.0
(4, abc hello c#) --> prob=[0.006011466954209337,0.9939885330457907], prediction=1.0
(5, hell java spark) --> prob=[0.9210765571223096,0.07892344287769032], prediction=0.0
(6, hello java spark) --> prob=[0.1785326777978406,0.8214673222021593], prediction=1.0
(7, num he he java spark) --> prob=[0.6923088930430097,0.30769110695699026], prediction=0.0
(8, hello2 java spark) --> prob=[0.9016001424620457,0.09839985753795444], prediction=0.0
(9, hello do some thing java spark) --> prob=[0.1785326777978406,0.8214673222021593], prediction=1.0
(10, world hello java spark) --> prob=[0.05144953292014106,0.9485504670798589], prediction=1.0
*/
//probability 是预测概率向量,第一个值是不符合度,第二个值是符合度,
//prediction的标签取决于模型的阀值设置严格度
 

猜你喜欢

转载自www.cnblogs.com/asker009/p/12145408.html