// Spark 3.0 user-defined aggregate function
// Implemented by overriding the Aggregator methods
import org.apache.spark.{SparkConf, sql}
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator
object Spark_basic {

  /**
   * Entry point: builds a local SparkSession, registers a type-safe average
   * UDAF named "mymean", and runs it over a JSON-backed temporary view.
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("waj")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    // TODO user defined function
    // Create a DataFrame from the JSON file of user records.
    val df = spark.read.json("datas/user.json")
    // Register a temporary view so the DataFrame can be queried with SQL.
    df.createTempView("user")
    // Register the strongly-typed Aggregator as a SQL UDAF (Spark 3.0+ API:
    // functions.udaf wraps an Aggregator for use in SQL).
    spark.udf.register("mymean", functions.udaf(new MyAvgUDAF()))
    spark.sql("select mymean(age ) as mean from user").show()
    spark.close()
  }

  /** Mutable accumulation buffer: running sum (`total`) and element count (`count`). */
  case class Buff(var total: Long, var count: Long)

  /**
   * Type-safe average aggregator (Spark 3.0 `Aggregator` API).
   * IN = Long (the column value), BUF = Buff, OUT = Long (truncated integer mean).
   */
  class MyAvgUDAF extends Aggregator[Long, Buff, Long] {

    /** Initial (zero) buffer. Case classes provide `apply`; `new` is unnecessary. */
    override def zero: Buff = Buff(0L, 0L)

    /** Fold one input value into the buffer (in-place mutation is allowed here). */
    override def reduce(b: Buff, a: Long): Buff = {
      b.total += a
      b.count += 1
      b
    }

    /** Combine two partial buffers produced on different partitions. */
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total += b2.total
      b1.count += b2.count
      b1
    }

    /**
     * Final result: integer mean of the accumulated values.
     * Guards against division by zero for an empty group — the original
     * threw ArithmeticException when count == 0; return 0 instead.
     */
    override def finish(reduction: Buff): Long =
      if (reduction.count == 0L) 0L else reduction.total / reduction.count

    /** Encoder for the intermediate buffer (product encoder for the case class). */
    override def bufferEncoder: Encoder[Buff] = Encoders.product

    /** Encoder for the Long output value. */
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}