[Spark] Structured Data Queries: Custom UDAFs

1. Custom untyped (weakly typed) UDAF

  1.1 Defining a UDAF

    An untyped UDAF extends the UserDefinedAggregateFunction abstract class and implements the following members:

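    override def inputSchema: StructType = schema of the input columns

    override def bufferSchema: StructType = schema of the intermediate aggregation buffer

    override def dataType: DataType = type of the return value

    override def deterministic: Boolean = whether the function always returns the same result for the same input

    override def initialize(buffer: MutableAggregationBuffer): Unit = initialization function that sets the buffer's zero value

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = how the rows within a partition are folded into the buffer

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = how the buffers of different partitions are combined

    override def evaluate(buffer: Row): Any = how the final result is computed from the buffer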

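    The overall UDAF flow closely mirrors the RDD aggregate operator: initialize supplies zeroValue, update plays the role of seqOp, and merge plays the role of combOp. evaluate then turns the final buffer into the result.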

      aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U
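
To make the analogy concrete, here is a minimal sketch of the same average computation in plain Scala, with no Spark dependency. The object name AggregateSketch, the Buffer alias, and the use of nested Seqs as a stand-in for RDD partitions are all assumptions made for illustration; a real RDD would call aggregate(zeroValue)(seqOp, combOp) directly.

```scala
object AggregateSketch {
  // Buffer type U = (sum, count); element type T = Double.
  type Buffer = (Double, Long)

  // Corresponds to initialize: the zero value of the aggregation.
  val zeroValue: Buffer = (0.0, 0L)

  // Corresponds to update: fold one element into a partition-local buffer.
  def seqOp(b: Buffer, x: Double): Buffer = (b._1 + x, b._2 + 1L)

  // Corresponds to merge: combine the buffers of two partitions.
  def combOp(a: Buffer, b: Buffer): Buffer = (a._1 + b._1, a._2 + b._2)

  // Corresponds to evaluate: derive the final result from the buffer.
  def evaluate(b: Buffer): Double = b._1 / b._2

  // Hypothetical stand-in for an RDD: each inner Seq is one partition.
  def average(partitions: Seq[Seq[Double]]): Double = {
    val perPartition = partitions.map(_.foldLeft(zeroValue)(seqOp))
    evaluate(perPartition.foldLeft(zeroValue)(combOp))
  }

  def main(args: Array[String]): Unit = {
    // Three partitions: sum = 21.0, count = 6
    println(average(Seq(Seq(1.0, 2.0), Seq(3.0), Seq(4.0, 5.0, 6.0)))) // prints 3.5
  }
}
```

The buffer carries both the sum and the count because an average cannot be merged from partial averages alone; the division happens only once, at the end.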

    An example of a custom UDAF that computes an average

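            import org.apache.spark.sql.Row
            import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
            import org.apache.spark.sql.types._

            object MyAvgUDAF extends UserDefinedAggregateFunction
            {
            // Input schema: a single Double column
            override def inputSchema: StructType = StructType(StructField("input",DoubleType)::Nil)
            // Aggregation buffer schema: running sum and running count
            override def bufferSchema: StructType = StructType(StructField("Sum",DoubleType)::StructField("Count",LongType)::Nil)
            // Return type
            override def dataType: DataType = DoubleType
            
            // Whether the same input always produces the same result
            override def deterministic: Boolean = true
            
            // Initialization function
            override def initialize(buffer: MutableAggregationBuffer): Unit = {
                // Set the zero value of the aggregation, i.e. the zeroValue (0, 0) of aggregate.
                // The literals must match bufferSchema (Double and Long), otherwise getDouble/getLong can fail.
                buffer(0) = 0.0 // sum starts at 0
                buffer(1) = 0L  // count starts at 0
            }
            
            override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
                // Only aggregate rows whose first column (Row[0]) is not null
                if(!input.isNullAt(0)){
                // The seqOp: (U, T) => U part of aggregate
                buffer(0) = buffer.getDouble(0) + input.getDouble(0)
                buffer(1) = buffer.getLong(1) + 1
                }
            }
            
            override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
                // The combOp: (U, U) => U part of aggregate.
                // Both results go into buffer1; buffer2 is a read-only Row.
                buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
                buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
            }
            
            // Final result: sum / count (NaN if no non-null row was seen)
            override def evaluate(buffer: Row): Any = buffer.getDouble(0)/buffer.getLong(1)
            }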
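
As a usage sketch, a UDAF like this can be called either from SQL, after registering it by name, or through the DataFrame API. The example assumes the MyAvgUDAF object above is in scope and that a Spark version is on the classpath where UserDefinedAggregateFunction is still available (it is deprecated since Spark 3.0). The object name MyAvgUDAFUsage, the function name my_avg, the view name scores, and the sample data are all hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.functions.col

object MyAvgUDAFUsage {
  // Hypothetical helper: runs the given UDAF through the SQL API and the DataFrame API.
  def run(spark: SparkSession, avgUdaf: UserDefinedAggregateFunction): (DataFrame, DataFrame) = {
    import spark.implicits._

    // Hypothetical sample data with a single Double column.
    val df = Seq(1.0, 2.0, 6.0).toDF("score")
    df.createOrReplaceTempView("scores")

    // SQL API: register the UDAF under a name, then call it like a built-in aggregate.
    spark.udf.register("my_avg", avgUdaf)
    val viaSql = spark.sql("SELECT my_avg(score) AS avg_score FROM scores")

    // DataFrame API: a UserDefinedAggregateFunction can be applied to columns directly.
    val viaDf = df.agg(avgUdaf(col("score")).as("avg_score"))

    (viaSql, viaDf)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("MyAvgUDAFUsage")
      .getOrCreate()

    val (viaSql, viaDf) = run(spark, MyAvgUDAF)
    viaSql.show()
    viaDf.show()

    spark.stop()
  }
}
```

Passing the UDAF in as a parameter keeps the helper independent of any one implementation, so the same code can exercise other UserDefinedAggregateFunction instances.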


Reposted from www.cnblogs.com/NightPxy/p/9269171.html