朴素贝叶斯算法,对男女分类

使用朴素贝叶斯算法进行男女分类

训练模型,并进行保存

原数据:
(原文此处为训练数据截图,图片缺失。由下方代码可知,每行数据的格式为:数字标签<TAB>文本。)

maven 依赖
<dependencies>

        <!-- Spark SQL (DataFrame / Dataset API) -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.scala-lang.modules</groupId>
            <artifactId>scala-parser-combinators_2.11</artifactId>
            <version>1.0.1</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-compiler</artifactId>
            <version>2.11.8</version>
        </dependency>

        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.11.8</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib -->
        <!-- Spark ML; uses the same version property as spark-sql so the two
             Spark modules can never end up on different releases. -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <!-- Pins the hadoop-client API version -->
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>${hadoop.version}</version>
        </dependency>

    </dependencies>

import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg.{SparseVector, Vectors}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

/**
  * Trains a Naive Bayes gender classifier on TF-IDF text features, prints its
  * accuracy on a separate test file and saves the fitted model.
  *
  * Every input line is expected to be `label<TAB>text`, where `label` parses
  * as a number. Lines without a tab-separated second field are skipped.
  *
  * Optional positional arguments (each falls back to the original hard-coded
  * value when absent): training file, test file, model output directory.
  *
  * @Author: zxl
  * @Date: 2019/1/15 13:12
  * @Version 1.0
  */
object ModelTrain {

  /** One raw input line: the class label (still text) and the sentence to classify. */
  case class RawDataRecord(label: String, text: String)

  /** Number of hash buckets used for the term-frequency vectors. */
  private val NumFeatures = 5000

  /**
    * Entry point: load data, build TF-IDF features, train, evaluate and save.
    *
    * @param args optional `trainPath testPath modelPath`
    */
  def main(args: Array[String]): Unit = {
    val trainPath = args.lift(0).getOrElse("D:\\Yue\\hux\\aaa.txt")
    val testPath = args.lift(1).getOrElse("D:\\Yue\\hux\\a1.txt")
    val modelPath = args.lift(2).getOrElse("model")

    val config = new SparkConf().setAppName("createModel").setMaster("local[4]")
    val sc = new SparkContext(config)
    try {
      val sqlContext = new SQLContext(sc)
      // Brings .toDF and the $"col" syntax into scope.
      import sqlContext.implicits._

      // Reads a tab-separated file into a (label, text) DataFrame. Lines that
      // lack a second field (e.g. a trailing blank line) are dropped instead
      // of failing the whole job with an ArrayIndexOutOfBoundsException.
      def load(path: String): DataFrame =
        sc.textFile(path)
          .map(_.split("\t"))
          .filter(_.length >= 2)
          .map(fields => RawDataRecord(fields(0), fields(1)))
          .toDF()

      // Tokenizer and HashingTF are stateless, so one instance of each serves
      // both the training and the test data.
      val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
      val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("rawFeatures")
        .setNumFeatures(NumFeatures)

      def termFrequencies(df: DataFrame): DataFrame =
        hashingTF.transform(tokenizer.transform(df))

      val trainTf = termFrequencies(load(trainPath))
      val testTf = termFrequencies(load(testPath))

      // IDF weights are learned from the training data ONLY and then applied
      // unchanged to the test data. Fitting a second IDF on the test set would
      // put train and test features on different scales and skew the accuracy.
      val idfModel: IDFModel =
        new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(trainTf)

      // Converts to the (label: Double, features: Vector) shape NaiveBayes
      // expects. The sparse vector is passed through as-is; densifying a
      // 5000-wide vector per row only costs memory.
      def toLabeledPoints(df: DataFrame): Dataset[LabeledPoint] =
        idfModel.transform(df).select($"label", $"features").map { row =>
          LabeledPoint(row.getAs[String](0).toDouble, row.getAs[SparseVector](1))
        }

      val trainSet = toLabeledPoints(trainTf)
      val testSet = toLabeledPoints(testTf)

      // Train the model and score the test set.
      val model = new NaiveBayes().fit(trainSet)
      val predictions = model.transform(testSet)
      predictions.show()

      // Evaluate the model.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val accuracy = evaluator.evaluate(predictions)
      println("准确率:" + accuracy)

      // Persist the model, replacing any previous one at the same path.
      model.write.overwrite().save(modelPath)
    } finally {
      // Always release the local Spark resources, even when training fails.
      sc.stop()
    }
  }
}





import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable
/*import org.ansj.recognition.impl.StopRecognition*/
import org.apache.spark.ml.classification.NaiveBayesModel
/**
  * Loads the saved Naive Bayes model and prints a predicted gender for every
  * line of an input text file.
  *
  * Optional positional arguments (each falls back to the original hard-coded
  * value when absent): input text file, model directory.
  *
  * NOTE(review): the IDF weights here are fitted on the prediction input
  * itself, not taken from the training run, so the features are not weighted
  * the same way as those the model was trained on. Fixing that requires the
  * training job to persist its IDFModel so it can be loaded here.
  *
  * @Author: zxl
  * @Date: 2019/1/15 14:19
  * @Version 1.0
  */
object Demo {

  /** One input line; the sentence is stored in the column named `label`. */
  case class RawDataRecord(label: String)

  /**
    * Entry point: featurize the input lines, apply the model, show the result.
    *
    * @param args optional `inputPath modelPath`
    */
  def main(args: Array[String]): Unit = {
    val inputPath = args.lift(0).getOrElse("D:\\hux\\test.txt")
    val modelPath = args.lift(1).getOrElse("model")

    val conf = new SparkConf().setMaster("local[4]").setAppName("Demo")
    // The session owns its SparkContext; no separate context is needed.
    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      import spark.implicits._

      // Punctuation filtering via ansj's StopRecognition was disabled in the
      // original and remains so.
      val sentences = spark.sparkContext.textFile(inputPath).map(RawDataRecord(_)).toDF()

      val tokenizer = new Tokenizer().setInputCol("label").setOutputCol("words")
      val wordsData = tokenizer.transform(sentences)

      // More hash buckets mean fewer collisions but larger vectors.
      val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("rawFeatures")
        .setNumFeatures(5000)
      val termFrequencies = hashingTF.transform(wordsData)

      val idfModel = new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(termFrequencies)
      val features = idfModel.transform(termFrequencies)

      // Load the trained model and map the numeric prediction to a label.
      val model = NaiveBayesModel.load(modelPath)
      model.transform(features).select("words", "prediction").map { row =>
        val sex = if (row.getAs[Double]("prediction") == 0.0) "女" else "男"
        val words = row.getAs[scala.collection.mutable.WrappedArray[String]]("words")
        Raw(words, sex)
      }.show()
    } finally {
      // Always release the local Spark resources, even when prediction fails.
      spark.stop()
    }
  }
}

/**
  * One output row of the prediction job.
  *
  * @param word the tokens of the input sentence
  * @param sex  the predicted gender label
  */
case class Raw(
  word: mutable.WrappedArray[String],
  sex: String
)

结果

(原文此处为运行结果截图,图片缺失。程序会以表格形式输出每行文本的分词结果 word 及预测性别 sex。)
借鉴博客 http://www.cnblogs.com/feiyumo/p/9230186.html

猜你喜欢

转载自blog.csdn.net/Lu_Xiao_Yue/article/details/86493921