自定义函数的分类
UDF:输入一参数,返回一个参数
UDTF:输入一参数,返回多个参数(hive中存在,sparkSQL中没有,因为spark中用flatMap即可实现该功能)
UDAF 输入多个参数,返回一个参数 aggregate(聚合) count、sum这些是sparkSQL自带的聚合函数,但是复杂的业务,要自己定义。
spark自定义函数的步骤
定义一个类,该类必须继承UserDefinedAggregateFunction类,并且实现该抽象类的8个方法。
注册函数,使用SparkSession对象来将该类注册成函数,函数名可以随便起。
UserDefinedAggregateFunction抽象类的抽象方法
class GeoMean extends UserDefinedAggregateFunction {
//输入数据的类型
override def inputSchema: StructType = StructType(List(
StructField("value", DoubleType)
))
//产生中间结果的数据类型
override def bufferSchema: StructType = StructType(List(
//相乘之后返回的积
StructField("product", DoubleType),
//参与运算数字的个数
StructField("counts", LongType)
/**
* 此处有几个中间数据,
* 就定义几个数据类型,
* 并且"指定初始值"、"update"
* 中数据的位置要与"merge"
* 此处的位置顺序保持一致。
*/
))
//最终返回的结果类型
override def dataType: DataType = DoubleType
//确保一致性 一般用true
override def deterministic: Boolean = true
//指定初始值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//相乘的初始值
buffer(0) = 1.0
//参与运算数字的个数的初始值
buffer(1) = 0L
}
//每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//每有一个数字参与运算就进行相乘(包含中间结果)
buffer(0) = buffer.getDouble(0) * input.getDouble(0)
//参与运算数据的个数也有更新
buffer(1) = buffer.getLong(1) + 1L
}
//全局聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//每个分区计算的结果进行相乘
buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
//每个分区参与预算的中间结果进行相加
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
//计算最终的结果
override def evaluate(buffer: Row): Double = {
math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))
}
}
注:此函数的功能是求几何平均数。
聚会代码
import java.lang.Long
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
object UdafTest {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.master("local[*]")
.getOrCreate()
//添加数据
val range: Dataset[Long] = spark.range(1, 11)
//创建自定义函数类对象
val geomean = new GeoMean
/**
* SparkSQL风格
*/
//注册函数,"gm"函数别名,"geomean"自定义函数类
spark.udf.register("gm", geomean)
//将range这个Dataset[Long]注册成视图
range.createTempView("v_range")
val result = spark.sql("SELECT gm(id) result FROM v_range")
/**
* DSL风格
*/
import spark.implicits._
val result1 = range.agg(geomean($"id").as("geomean"))
result1.show()
spark.stop()
}
}