Spark SQL custom aggregate functions (UDAF)

One: What UDAF means


UDAF stands for User Defined Aggregate Function, i.e. a user-defined aggregate function.

Compared with a UDF:
A UDF takes a single row as input and returns one output.
A UDAF can take multiple rows as input, aggregate over them, and return one output.
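To make the contrast concrete, here is a minimal sketch against the SQLContext and names table set up later in this article (strLen is a hypothetical example name; strCount / StringCount is the UDAF built in section Four): the UDF produces one output row per input row, while the UDAF collapses the rows of each group into one value.

// UDF: one row in, one value out (strLen is a hypothetical example)
sqlContext.udf.register("strLen", (s: String) => s.length)
sqlContext.sql("select name, strLen(name) from names").show()    // one output row per input row

// UDAF: many rows in, one value out per group (StringCount is defined in section Four)
sqlContext.udf.register("strCount", new StringCount)
sqlContext.sql("select name, strCount(name) from names group by name").show()    // one output row per group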


Two: A common misunderstanding about UDAF


We may subconsciously assume that a UDAF has to be used together with GROUP BY. In fact, a UDAF can be used with GROUP BY, but it does not have to be. This is easy to understand if you think of MySQL functions such as max and min. You can write:


select max(foo) from foobar group by bar;

Here bar is the grouping field, and the maximum value is selected within each group; there may be many groups, and the function is applied to each of them. You can also write:


select max(foo) from foobar;

In this case the whole table is treated as a single group, and the maximum value is selected within that group (which happens to be the entire table). So an aggregate function really operates on groups, regardless of how many records each group contains.

 
Three: The meaning of the update, merge, and evaluate methods in a UDAF


update: aggregates the values within a group on a single node (partition)
merge: combines the partial aggregates of the same group produced on different nodes
evaluate: computes the final aggregate value of a group from its intermediate buffer
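
As an illustration only (plain Scala, not the Spark API), the sketch below shows how the three methods cooperate for a single group whose rows are spread across two partitions; the numbers mirror the "Tom" group in the example of section Four.

// Conceptual sketch: counting the rows of one group ("Tom") split over two partitions
val partition1 = Seq("Tom", "Tom")   // rows of the group processed on partition 1
val partition2 = Seq("Tom")          // rows of the group processed on partition 2

// update: each partition folds its own rows into a local buffer
val buffer1 = partition1.foldLeft(0)((count, _) => count + 1)   // 2
val buffer2 = partition2.foldLeft(0)((count, _) => count + 1)   // 1

// merge: the partial buffers of the same group are combined
val merged = buffer1 + buffer2                                   // 3

// evaluate: the final value of the group is read from the merged buffer
println(merged)                                                  // prints 3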


Four: A custom UDAF in practice


Definition:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * @author Administrator
 */
class StringCount extends UserDefinedAggregateFunction {

  // inputSchema: the type of the input data
  def inputSchema: StructType = {
    StructType(Array(StructField("str", StringType, true)))
  }

  // bufferSchema: the type of the data handled in the intermediate aggregation buffer
  def bufferSchema: StructType = {
    StructType(Array(StructField("count", IntegerType, true)))
  }

  // dataType: the type of the function's return value
  def dataType: DataType = {
    IntegerType
  }

  // deterministic: the same input always produces the same output
  def deterministic: Boolean = {
    true
  }

  // Initialize the aggregation buffer for each group
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  /**
   * update: the values inside a group are passed in one by one and folded into the buffer.
   * It plays the same role as a map-side combiner, which pre-aggregates the output of each
   * map task; the large aggregation happens on the reduce side.
   * In other words: whenever a new value arrives for a group, update defines how the grouped
   * aggregation is computed. The result is written into the buffer, and update is called for
   * every row of every group.
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  /**
   * merge: an update may only cover part of a group's data, processed on one node, because
   * the rows of a group can be spread over multiple nodes. merge combines the partial buffers
   * built on each node, i.e. the global, reduce-side aggregation that runs after the
   * distributed work finishes (it can also combine the buffers of several executors on the
   * same node). The merged result is written into buffer1.
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  // evaluate: how the final aggregate value of a group is produced from its intermediate buffer
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
}

Usage:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object UDAF {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
        .setMaster("local")
        .setAppName("UDAF")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Build some mock data
    val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")
    val namesRDD = sc.parallelize(names, 5)
    val namesRowRDD = namesRDD.map { name => Row(name) }
    val structType = StructType(Array(StructField("name", StringType, true)))
    val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)

    // Register a temporary table called "names"
    namesDF.registerTempTable("names")

    // Define and register the custom function
    // Define the function: write it yourself (here, the StringCount class above)
    // Register the function: SQLContext.udf.register()
    sqlContext.udf.register("strCount", new StringCount)

    // Use the custom function
    sqlContext.sql("select name,strCount(name) from names group by name")
        .collect()
        .foreach(println)
  }
}

/* Result:
 * [Jack,1]
 * [Tom,3]
 * [Leo,2]
 * [Marry,1]
 */
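
To tie this back to section Two: the same registered strCount can also be called without GROUP BY, in which case the whole table is treated as one group. A minimal sketch, assuming the same sqlContext and names table as above:

// No GROUP BY: the whole table is a single group, so a single row is returned
sqlContext.sql("select strCount(name) from names")
    .collect()
    .foreach(println)   // expected: [7], since the names table has 7 rows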

 

Origin: blog.csdn.net/weixin_39966065/article/details/93376843