Chinese text classification with Spark ML (Naive Bayes)

How the Chinese corpus was obtained and segmented is described at https://www.cnblogs.com/DismalSnail/p/11801742.html.
This post shows how to use Spark's newer machine learning package, ml: the segmentation tool is HanLP (see https://github.com/hankcs/HanLP), terms are weighted with TF-IDF, and the classifier is Naive Bayes. For this experiment, the training and test sets of the Fudan Chinese corpus are merged into a single corpus.
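
Before the full program, here is a minimal sketch (independent of Spark) of what HanLP segmentation returns and why its toString output needs post-processing; the sample sentence is only illustrative:

import com.hankcs.hanlp.HanLP

object SegmentSketch {
  def main(args: Array[String]): Unit = {
    // Disable part-of-speech tags so toString yields plain words only
    HanLP.Config.ShowTermNature = false
    // HanLP.segment returns a java.util.List[Term]; its toString looks like "[词1, 词2, ...]"
    val terms = HanLP.segment("复旦中文语料的分类实验")
    // Stripping brackets and commas leaves a space-separated string,
    // which is what Spark's Tokenizer expects later in the pipeline
    println(terms.toString.replace("[", "").replace("]", "").replace(",", ""))
  }
}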
***

package com.teligen.subject.ML

import java.io.File

import com.hankcs.hanlp.HanLP
import org.apache.commons.io.FileUtils
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, Tokenizer}
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable.ListBuffer

/**
 * Naive Bayes training example
 */
object NBClassDemo {

  // Convert the toString of the segmented term list into a space-separated string
  def toStringList(termString: String): String = {
    termString.replace("[", "").replace("]", "").replace(",", "")
  }

  // Holds, for each document, the Double label value and the segmented string.
  // Note: the Double labels must start at 0.0 and increase consecutively (0.0, 1.0, 2.0, ...);
  // otherwise, even when a prediction is correct, the label's Double value will not match and
  // the accuracy computation will be wrong.
  val labelAndSentenceSeq: ListBuffer[(Double, String)] = ListBuffer[(Double, String)]()

  // Word segmentation function
  def segment(corpusPath: String): Unit = {
    // Double value representing the label, starting from 0.0
    var count: Double = 0.0
    // Configure HanLP so segmentation results carry no part-of-speech tags; the toString of
    // the term list then contains no POS characters, which makes building the word vectors easier
    HanLP.Config.ShowTermNature = false
    // Open the corpus root directory
    val corpusDir: File = new File(corpusPath)
    // One sub-directory per class
    for (classDir: File <- corpusDir.listFiles()) {
      // One text file per document
      for (text <- classDir.listFiles()) {
        // Store the Double label and the segmented string in labelAndSentenceSeq
        labelAndSentenceSeq.append(Tuple2(count,
          // Post-process HanLP.segment().toString so that words are separated by spaces
          toStringList(
            // Segment into words
            HanLP.segment(
              // Read the text file as a string
              FileUtils.readFileToString(text)
                .replace("\r\n", "") // remove CR+LF pairs
                .replace("\r", "") // remove carriage returns
                .replace("\n", "") // remove line feeds
                .replace(" ", "") // remove spaces
                .replace("\u3000", "") // remove full-width (Chinese) spaces
                .replace("\t", "") // remove tabs
                .replaceAll(s"\\pP|\\pS|\\pC|\\pN|\\pZ", "") // remove punctuation, symbols, control characters, numbers and separators via Unicode category regexes
                .trim
              // the term list's toString separates words with a comma + space, so further processing is needed
            ).toString)))
      }
      // Move on to the next class label
      count = count + 1.0
    }
  }
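
  // The segment() function above assumes a corpus layout with one sub-directory per class under
  // the corpus root, each containing plain-text files (the paths below are only illustrative):
  //   ./corpus/all_corpus/<class_0>/doc_0.txt
  //   ./corpus/all_corpus/<class_1>/doc_0.txt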

  // Build word vectors weighted by TF-IDF
  def tfIdf(spark: SparkSession): DataFrame = {
    // Convert the (Double label, segmented string) pairs into a DataFrame
    val sentenceData: DataFrame = spark.createDataFrame(labelAndSentenceSeq.toSeq).toDF("label", "sentence")
    
    // Split each segmented string into individual words. Tokenizer() can only split strings
    // separated by spaces; RegexTokenizer is more powerful -- see the Tokenizer() source for details

    // Create the sentence --> words tokenizer
    val tokenizer: Tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
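    // Alternative (not used here): RegexTokenizer supports a custom split pattern, e.g.
    //   new RegexTokenizer().setInputCol("sentence").setOutputCol("words").setPattern("\\s+")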
    // Perform the split
    // If select() is not used here, the result of every step is kept in the DataFrame,
    // which makes it very large and easily causes a java heap space error
    val wordsData: DataFrame = tokenizer.transform(sentenceData).select("label", "words")
    
    // Create the words --> rawFeatures HashingTF transformer
    val hashingTF: HashingTF = new HashingTF()
      .setInputCol("words").setOutputCol("rawFeatures")
    
    // Run the transform to get the frequency of each word in every sentence, i.e. TF (Term Frequency)
    val featurizedData: DataFrame = hashingTF.transform(wordsData).select("label", "rawFeatures")
    
    // Create the rawFeatures --> features IDF estimator
    val idf: IDF = new IDF().setInputCol("rawFeatures").setOutputCol("features")
    // Compute the IDF (Inverse Document Frequency)
    val idfModel: IDFModel = idf.fit(featurizedData)
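    // The fitted IDFModel holds one weight per feature; Spark computes
    // idf(t) = log((numDocs + 1) / (docFreq(t) + 1))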
    // Compute TF-IDF
    idfModel.transform(featurizedData).select("label", "features")
  }

  // Training and prediction function
  def trainAndPredict(tfIdfData: DataFrame): Unit = {
    // Split the data into training and test sets by ratio
    val Array(trainingData, testData) = tfIdfData.randomSplit(Array(0.7, 0.3), seed = 1234L)
    // Train the Naive Bayes classifier
    val model = new NaiveBayes().fit(trainingData)
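    // NaiveBayes defaults to the multinomial model with Laplace smoothing 1.0; it requires
    // non-negative features, which the TF-IDF vectors satisfy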
    // Predict on the test set
    val predictions = model.transform(testData)
    // Show 50 rows of the predictions
    predictions.show(50)

    // Evaluate the predictions
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    // Accuracy on the test set
    val accuracy = evaluator.evaluate(predictions)
    println(s"Test set accuracy = $accuracy")
  }

  def main(args: Array[String]): Unit = {
    // Create the Spark session
    val spark = SparkSession.builder().master("local[2]").appName("NBC").getOrCreate()
    // Segment the corpus
    segment("./corpus/all_corpus/")
    // Train and predict
    trainAndPredict(tfIdf(spark))
  }
}
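
Since the master is hard-coded to local[2] in main(), the program can be run directly from an IDE or with sbt run, as long as the HanLP dictionaries and the ./corpus/all_corpus/ directory are reachable from the working directory. It can also be packaged and launched with spark-submit; the command below is only a sketch, and the jar path is an assumption:

spark-submit --class com.teligen.subject.ML.NBClassDemo target/nbc-demo.jar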

Origin www.cnblogs.com/DismalSnail/p/11802281.html