Spark Sql之UDAF自定义聚合函数

一:UDAF含义


UDAF:User Defined Aggregate Function。用户自定义聚合函数

对比UDF:
UDF,其实更多的是针对单行输入,返回一个输出
UDAF,则可以针对多行输入,进行聚合计算,返回一个输出


二:关于UDAF的一个误区


我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:

1

select max(foo) from foobar group by bar;

表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:

1

select max(foo) from foobar;

这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。

 
三:UDAF中:update,merge,evaluate方法的含义


update:各个分组的值内部聚合
merge:各个节点的同一分组的值聚合
evaluate:聚合各个分组的缓存值


四:自定义UDAF实战


定义:

/**
 * @author Administrator
 */
class StringCount extends UserDefinedAggregateFunction {  
  
  // inputSchema,指的是,输入数据的类型
  def inputSchema: StructType = {
    StructType(Array(StructField("str", StringType, true)))   
  }
  
  // bufferSchema,指的是,中间进行聚合时,所处理的数据的类型
  def bufferSchema: StructType = {
    StructType(Array(StructField("count", IntegerType, true)))   
  }
  
  // dataType,指的是,函数返回值的类型
  def dataType: DataType = {
    IntegerType
  }
  
  def deterministic: Boolean = {
    true
  }

  // 为每个分组的数据执行初始化操作
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }
  
  /**
     * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
     * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
     * 大聚和发生在reduce端.
     * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
     * update的结果写入buffer中,每个分组中的每一行数据都要进行update操作
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }
  
/**
     * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
     * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
     * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
     * 也可以是一个节点里面的多个executor合并 reduce端大聚合
     * merge后的结果写如buffer1中
     */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)  
  }
  
  // 最后,指的是,一个分组的聚合值,如何通过中间的缓存聚合值,最后返回一个最终的聚合值
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)    
  }

使用:

object UDAF {
  
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
        .setMaster("local") 
        .setAppName("UDAF")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
  
    // 构造模拟数据
    val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")  
    val namesRDD = sc.parallelize(names, 5) 
    val namesRowRDD = namesRDD.map { name => Row(name) }
    val structType = StructType(Array(StructField("name", StringType, true)))  
    val namesDF = sqlContext.createDataFrame(namesRowRDD, structType) 
    
    // 注册一张names表
    namesDF.registerTempTable("names")  
    
    // 定义和注册自定义函数
    // 定义函数:自己写匿名函数
    // 注册函数:SQLContext.udf.register()
    sqlContext.udf.register("strCount", new StringCount) 
    
    // 使用自定义函数
    sqlContext.sql("select name,strCount(name) from names group by name")  
        .collect()
        .foreach(println)  
  }
  
/* 结果:
 * [Jack,1] 
   [Tom,3] 
   [Leo,2] 
   [Marry,1]
 */

猜你喜欢

转载自blog.csdn.net/weixin_39966065/article/details/93376843