Implementing WordCount with Spark UDFs

User-defined functions (UDFs) are a key feature of most SQL environments, used to extend a system's built-in functionality. UDFs let developers enable new functionality in a higher-level language such as SQL while abstracting away its lower-level implementation. Apache Spark is no exception, and it offers several options for integrating UDFs into Spark SQL workflows.

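For example, besides registering a function by name for use from SQL (as the WordCount example below does with spark.udf.register), a plain Scala function can be wrapped with org.apache.spark.sql.functions.udf and applied directly to DataFrame columns. The following is only a minimal sketch of that alternative, assuming a DataFrame df with a string column word:

import org.apache.spark.sql.functions.{col, udf}

//Wrap an ordinary Scala function as a UDF usable from the DataFrame API
val lengthUdf = udf((s: String) => s.length)

//Apply it to a column; df and its "word" column are assumed to exist
df.select(col("word"), lengthUdf(col("word")).as("length")).show()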
This article works through a WordCount example built with a custom UDF and a custom aggregate function (UDAF):

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object UDF {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("UDF").master("local[2]").getOrCreate()

    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")

    val bigData = Array("Spark", "Spark", "Hadoop", "Spark", "Hadoop", "Spark", "Spark", "Hadoop", "Spark", "Hadoop")

    val bigDataRDD: RDD[String] = sc.parallelize(bigData)

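    //Wrap each word in a Row and define a one-column schema so the RDD can be converted to a DataFrame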
    val bigDataRDDRow: RDD[Row] = bigDataRDD.map(item => Row(item))
    val structType: StructType = StructType(Array(
      StructField("word", StringType, true)
    ))
    val bigDataDF: DataFrame = spark.createDataFrame(bigDataRDDRow, structType)

    bigDataDF.createOrReplaceTempView("bigDataTable")

    spark.udf.register("computeLength",(input:String) => input.length)
    //直接在SQL语句中使用UDF,就像使用SQL内置函数一样
    spark.sql("select word,computeLength(word) as length from bigDataTable").show()

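    //Register the UDAF (implemented below as MyUDAF) and use it in a grouped aggregation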
    spark.udf.register("wordCount", new MyUDAF)
    spark.sql("select word,computeLength(word) as length, wordCount(word) as count from bigDataTable group by word").show()

    sc.stop()
    spark.stop()

  }

}
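
The wordCount aggregate function registered above is implemented as a user-defined aggregate function (UDAF) in the following MyUDAF class:
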
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


class MyUDAF extends UserDefinedAggregateFunction {
  //Specifies the data type of the input to the aggregation
  override def inputSchema: StructType = StructType(Array(StructField("input", StringType, true)))

  //Data type of the intermediate buffer used while aggregating
  override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))

  //Data type of the final result
  override def dataType: DataType = IntegerType

  //Whether the function always returns the same result for the same input
  override def deterministic: Boolean = true

  //Initializes each group's buffer before aggregation starts
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  //How a group's aggregate is updated each time a new value arrives
  //Local (partition-level) aggregation, analogous to the Combiner in the Hadoop MapReduce model
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  //After the local reduce on each node completes, merges the partial results globally
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

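  //Returns the final result for the group once aggregation is complete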
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}
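
With the sample data above, the second query returns the rows ("Spark", 5, 6) and ("Hadoop", 6, 4), in either order, since MyUDAF simply adds 1 to the group's buffer for every row it sees.

Note that on Spark 3.x, UserDefinedAggregateFunction is deprecated in favor of the typed Aggregator API. A minimal sketch of the same counting logic written against that API (the object name WordCountAgg is illustrative) could look like this:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

object WordCountAgg extends Aggregator[String, Int, Int] {
  //Initial buffer value for each group
  override def zero: Int = 0
  //Partition-local update: one more occurrence seen for this group
  override def reduce(buffer: Int, word: String): Int = buffer + 1
  //Merge partial counts computed on different partitions
  override def merge(b1: Int, b2: Int): Int = b1 + b2
  //Final result for the group
  override def finish(reduction: Int): Int = reduction
  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

//Registration then works the same way as before:
//spark.udf.register("wordCount", udaf(WordCountAgg))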

Reposted from blog.csdn.net/LINBE_blazers/article/details/83245827