说明:UDAF使用的六个步骤,难点在第三步骤,代码中都有注释
1.声明用户自定义聚合函数
2.继承UserDefinedAggregateFunction
3.实现方法
4.创建聚合函数对象
5.注册函数
6.使用函数
本案例实现平均年龄的计算(即内置 avg 函数的功能)
自定义UDAF函数
//声明函数
//ageAve(age):函数的入参:age
// User-defined aggregate function computing the average age, mirroring the
// behavior of the built-in `avg`. Register it via
// spark.udf.register("ageAve", new MyAgeAvg) and call ageAve(age) in SQL.
// NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0 in
// favor of org.apache.spark.sql.expressions.Aggregator — consider migrating.
class MyAgeAvg extends UserDefinedAggregateFunction {

  // Structure of the function's input: one Long column named "age".
  override def inputSchema: StructType =
    new StructType().add("age", LongType)

  // Structure of the aggregation buffer: running sum and row count.
  override def bufferSchema: StructType =
    new StructType().add("sum", LongType).add("count", LongType)

  // Result type returned by evaluate().
  override def dataType: DataType = DoubleType

  // Deterministic: identical inputs always yield identical output.
  override def deterministic: Boolean = true

  // Initialize the buffer before aggregation begins.
  // MutableAggregationBuffer is addressed by index, not by field name.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // sum
    buffer(1) = 0L // count
  }

  // Fold one input row (shape defined by inputSchema) into the buffer.
  // Skip SQL NULL ages, as the built-in avg does; previously a NULL age would
  // make input.getLong(0) throw a NullPointerException.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0) // add current value
      buffer(1) = buffer.getLong(1) + 1                // one more row seen
    }
  }

  // Merge partial buffers produced on different partitions/executors.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Final result: sum / count. Return SQL NULL (not NaN) when no non-null rows
  // were aggregated, matching the built-in avg semantics.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}
UDAF 的使用
// Demonstrates registering and invoking the custom UDAF: reads user records
// from JSON, registers ageAve, and computes the average age via Spark SQL.
object SparkSql05_UDAF {
  def main(args: Array[String]): Unit = {
    // Was setAppName(""): an empty application name makes the job impossible
    // to identify in the Spark UI / history server.
    val conf = new SparkConf().setMaster("local[*]").setAppName("SparkSql05_UDAF")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // NOTE(review): assumes "./json" contains records with {age, name} — confirm path.
    val df1: DataFrame = spark.read.json("./json")
    df1.show()
    /* |age| name|
       +---+--------+
       | 20|zhangsan|
       | 21|    lisi|
       | 22|  wangwu|
       +---+--------+ */

    // Expose the DataFrame to SQL as a temporary view.
    df1.createOrReplaceTempView("user")

    // Register the UDAF under the name referenced in the query below.
    val udaf = new MyAgeAvg
    spark.udf.register("ageAve", udaf)

    // Use the function.
    val df2 = spark.sql("select ageAve(age) from user")
    df2.show()
    // |myageavg(age)|
    // +-------------+
    // |         21.0|
    // +-------------+

    spark.stop()
  }
}