用 Spark 实现 TF-IDF

Scala 代码：

package offline

import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

/**
 * Builds TF-IDF feature vectors for the segmented news corpus in Hive table
 * `badou.news_seg` (column `sentence`), in two ways:
 *
 *  1. Spark ML's built-in `HashingTF` + `IDF` (hashed term space of 2^18 slots).
 *  2. A hand-rolled implementation with an explicit word dictionary built from
 *     document frequencies.
 *
 * Only the text before the first `##@@##` marker of each sentence is used; it
 * is split on single spaces into words.
 */
object TfIdfTransform {

  /** Separator between the segmented text and whatever follows it in `sentence`. */
  private val SegSeparator = "##@@##"

  /**
   * Tokenizes one raw `sentence`: keeps the text before the first
   * `SegSeparator` and splits it on single spaces, dropping the empty tokens
   * that repeated spaces produce. A null sentence yields no words.
   *
   * Used for both document frequency and term frequency so the two always
   * agree on what a "word" is.
   */
  private def tokenize(sentence: String): Array[String] =
    Option(sentence).fold(Array.empty[String]) { s =>
      // limit -1 keeps trailing empty strings, so the result is never empty and
      // (0) is safe even when the sentence consists only of separators.
      s.split(SegSeparator, -1)(0).split(" ").filter(_.nonEmpty)
    }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("tf-idf")
      .enableHiveSupport()
      .getOrCreate()

    // Read the already-segmented corpus from Hive.
    val df = spark.sql("select sentence from badou.news_seg")
    val dfSeg = df.selectExpr(s"split(split(sentence,'$SegSeparator')[0],' ') as seg")
    val docSize = dfSeg.count()

    // ---- 1. Spark's built-in TF-IDF ----
    // HashingTF maps each document [word1, word2, ...] to
    // SparseVector(2^18, [hash(word) as index], [word count]).
    // setBinary(true) would record presence (0/1) instead of counts.
    val hashingTF = new HashingTF().setBinary(false)
      .setInputCol("seg").setOutputCol("feature_tf")
      .setNumFeatures(1 << 18) // 262,144 hash buckets
    val dfTf = hashingTF.transform(dfSeg).select("feature_tf")

    // IDF re-weights each hashed term; Spark gives terms seen in fewer than
    // minDocFreq (here 2) documents an IDF of 0.
    val idf = new IDF().setInputCol("feature_tf").setOutputCol("feature_tfidf")
      .setMinDocFreq(2)
    val idfModel = idf.fit(dfTf)
    val dfTfIdf = idfModel.transform(dfTf).select("feature_tfidf")

    // ---- 2. Hand-rolled TF-IDF ----
    // 2a. Document frequency: the number of documents each word occurs in.
    //     The keys of this map form the dictionary.
    val distinctWordsUDF = udf((sentence: String) => tokenize(sentence).distinct)
    val docFreq: Map[String, Long] = df
      .select(explode(distinctWordsUDF(col("sentence"))).as("word"))
      .groupBy("word").count()
      .rdd.map(row => (row.getString(0), row.getLong(1)))
      .collectAsMap()
      .toMap
    // Sort before indexing so the word -> index assignment is stable across runs.
    val wordIndex: Map[String, Int] = docFreq.keys.toSeq.sorted.zipWithIndex.toMap
    val dictSize = docFreq.size

    // Ship the dictionaries to executors once instead of inside every task closure.
    val docFreqBc = spark.sparkContext.broadcast(docFreq)
    val wordIndexBc = spark.sparkContext.broadcast(wordIndex)

    // 2b. Per document: weight(word) = tf(word) * log10(docSize / (df(word) + 1)).
    // NOTE: unlike Spark's log((docSize + 1) / (df + 1)), this weight is
    // negative for a word that occurs in every document.
    val tfIdfUDF = udf { (sentence: String) =>
      val freqs = docFreqBc.value
      val index = wordIndexBc.value
      val entries: Seq[(Int, Double)] = tokenize(sentence)
        .groupBy(identity)
        .toSeq
        .flatMap { case (word, occurrences) =>
          // Same tokenizer as the dictionary, so every word should be known;
          // skip unknown ones rather than collapsing them onto index 0.
          index.get(word).map { i =>
            val idfWeight = math.log10(docSize.toDouble / (freqs.getOrElse(word, 0L) + 1.0))
            (i, occurrences.length * idfWeight)
          }
        }
      Vectors.sparse(dictSize, entries)
    }
    val dfTfIdfCustom = df.withColumn("tf_idf", tfIdfUDF(col("sentence")))

    // dfTfIdf and dfTfIdfCustom are lazy; add an action (e.g. show / write)
    // to actually materialize them.
    spark.stop()
  }

}

猜你喜欢

转载自：https://www.cnblogs.com/xumaomao/p/12763375.html