版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_37050372/article/details/82981177
什么是UDAF?UDAF(User-Defined Aggregate Function,用户自定义聚合函数)接收 N 行输入、输出一个结果,属于聚合函数。
接下来我们编写一个计算几何平均数的自定义聚合函数示例。n 个数的几何平均数是它们乘积的 n 次方根:$G = \sqrt[n]{x_1 x_2 \cdots x_n}$。
我们从头写起,先看看参与计算的数据是如何生成的:
package com.test.SparkSQL
import java.lang
import org.apache.spark.sql.{Dataset, SparkSession}
/** Minimal demo that generates the numbers the geometric-mean UDAF will later aggregate. */
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
    try {
      // range(start, end) is end-exclusive, so this yields the values 1 through 9.
      val ds: Dataset[lang.Long] = spark.range(1, 10)
      ds.show()
    } finally {
      // Stop the session even if the job above throws, so its resources are released.
      spark.stop()
    }
  }
}
生成结果(`spark.range(1,10)` 不包含上界 10,因此 `id` 列为 1 到 9 共 9 个数):
接下来我们使用自定义聚合函数计算几何平均数:
package com.test.SparkSQL
import java.lang
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession, types}
/** Registers [[GeometriMean]] as the SQL aggregate `gm` and applies it to the numbers 1..9. */
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
    try {
      // End-exclusive range: the values 1 through 9 in the column "id".
      val ds: Dataset[lang.Long] = spark.range(1, 10)
      // Unlike createTempView, this does not throw if "v_num" is already registered.
      ds.createOrReplaceTempView("v_num")
      // NOTE: the UserDefinedAggregateFunction API is deprecated as of Spark 3.0
      // (replacement: an Aggregator registered via functions.udaf), but it still works.
      val gm = new GeometriMean
      spark.udf.register("gm", gm)
      spark.sql("select gm(id) as gm from v_num").show()
    } finally {
      // Stop the session even if the query fails, so its resources are released.
      spark.stop()
    }
  }
}
/**
 * Spark SQL aggregate that computes the geometric mean of a numeric column:
 * the n-th root of the product of the n non-null input values.
 *
 * The value is accumulated as a running sum of natural logarithms and evaluated
 * as exp(logSum / n). This is algebraically equal to pow(x1 * ... * xn, 1.0 / n),
 * but a raw running product overflows to Infinity (or underflows to 0.0) once
 * enough values are multiplied together, e.g. 400 copies of 10.
 *
 * Semantics:
 *  - null inputs are ignored, as built-in SQL aggregates do;
 *  - an empty or all-null group yields null;
 *  - any 0 input makes the result 0.0, and a negative input makes it NaN,
 *    because the geometric mean is only defined for non-negative values.
 *
 * The class name keeps its original spelling because callers refer to it.
 */
class GeometriMean extends UserDefinedAggregateFunction {
  // A single input column declared as double; the demo passes the Long "id" column to it.
  override def inputSchema: StructType = StructType(List(StructField("value", DoubleType)))

  // Aggregation buffer: number of non-null values seen, and the sum of their natural logs.
  override def bufferSchema: StructType = StructType(
    List(
      StructField("count", LongType),
      StructField("logSum", DoubleType)
    )
  )

  // Result type of the aggregate.
  override def dataType: DataType = DoubleType

  // The same set of inputs always produces the same result.
  override def deterministic: Boolean = true

  // Empty buffer: zero values, log-sum 0.0 (the log of the empty product 1).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into a partial buffer. Nulls are skipped: reading them with
  // getAs[Double] would unbox to 0.0 and silently force the whole result to 0.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) + math.log(input.getDouble(0))
    }
  }

  // Combine two partial buffers: counts add, and log-sums add (logs of products multiply).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  // exp(mean of logs) == n-th root of the product. Returns null when no non-null
  // value was seen, instead of pow(1.0, 1.0 / 0) == 1.0.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    if (count == 0L) null else math.exp(buffer.getDouble(1) / count)
  }
}
运行结果(1 到 9 的几何平均数为 $\sqrt[9]{9!} = \sqrt[9]{362880} \approx 4.147$):