package day01
import org.apache.spark.sql.{Row, types}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
* 自定义一个聚合方法
* 首先要定义一个类继承UserDefinedAggregateFunction
* 重写8个方法
*
*/
class GeometricMean extends UserDefinedAggregateFunction{
//UDAF与DataFrame列有关的输入样式,StructField的名字并没有特别要求,完全可以认为是两个内部结构的列名站位符
//至于UDAF具体要操作DataFrame的那个列,取决于调用者,但前提是数据类型必须符合事先的设置,如这里的Double
override def inputSchema: StructType = StructType(List(StructField("value",DoubleType)))
//定义存储聚合运算时产生的中间数据结果的Schema
override def bufferSchema: StructType = StructType(List(
//参与运算的个数
StructField("count",LongType),
//参与运算的乘积
StructField("product",DoubleType)
))
//标明了UDAF函数的返回值类型
override def dataType: DataType = DoubleType
//用以标记针对给定的一组输入,UDAF是否总是生成相同的结果
override def deterministic: Boolean = true
//对聚合运算中间结果的初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//参与运算数字的个数
buffer(0)=0L
//参与相乘的值
buffer(1)=1.0
}
//每处理一条数据都要执行update
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//局部计算
buffer(0) =buffer.getAs[Long](0) +1
buffer(1) =buffer.getAs[Double](1) * input.getAs[Double](0)
}
//负责合并两个聚合运算的buffer,再将其存储到MutableAggregationBuffer中
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//全局 累加 累乘
buffer1(0)=buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
buffer1(1)=buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
}
//完成对聚合Buffer值得运算,得到最后的结果
override def evaluate(buffer: Row): Any = {
math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
}
}
object GeometricMean{
def main(args: Array[String]): Unit = {
// val r =Math.pow(1*2*3*4*5*6*7*8*9,1.toDouble/9)
val r =Math.pow(3,1.toDouble/2)
println(r)
}
}
package day01
import java.lang
import org.apache.spark.sql.{Dataset, SparkSession}
object UDAFDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("UDAFDemo")
.master("local[*]")
.getOrCreate()
//df.show()列叫id
val df: Dataset[lang.Long] = spark.range(1,10)
val gm = new GeometricMean
//写sql需要注册视图
// df.createTempView("v_num")
spark.udf.register("gm",gm)
// spark.sql("SELECT gm(id) as gm from v_num").show()
//不用视图来弄,直接使用算子
// df.select(expr("gm(id) as GeometricMean")).show()
// df.groupBy().agg(gm(col("id")).as("GeometricMean")).show
}
}
SparkSQL的自定义函数
UDF 调用函数式输入一行,返回一个值, 1->1 substring
UDAF 调用函数时输入N行,返回一个值 N-> 1 count(*)
使用UDFs之前要先注册
spark.udf.register("ip2Long",(ip:String)=>{
//返回Long类型
})
spark.udf.register("gn" new UserDefineAggregateFunction(){
//重新八个方法
})