自定义函数分类
UDF 输入一行,输出一行
UDAF 输入多行,输出一行
UDTF 输入一样,输出多行
UDF
//导包
import org.apache.spark.sql.SparkSession
//编写代码
// 1.实例SparkSession
val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
// 2.根据sparkSession获取SparkContext
val sc = spark.sparkContext
//3.读取数据并输出
val datas = spark.read.textFile("./data/udf/udf.txt")
//4.数据展示
datas.show()
//5.编写UDF将小写变成大写
spark.udf.register("smallToBig", (str: String) => str.toUpperCase())
//6.将RDD转换为DataFrame
val dataFrame = datas.toDF()
//7.注册表
dataFrame.createOrReplaceTempView("word")
//8.使用自定义函数查询 并输出
spark.sql("select value, smallToBig(value) from word").show()
UDAF
继承UserDefinedAggregateFunction方法重写说明
InputSchema: 输入数据的类型
bufferSchema: 产生中间结果的数据类型
dataType:最终返回的结果类型
dataeministic: 确保一致性,一般用true
initialize: 指定初始值
update:每有一条数据参与运算就更新一下中间结果(update相当于每一个分区中的运算)
merge:全局聚合(将每个分区的结果进行聚合)
evaluate: 计算最终的结果
//导包
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
//编写自定义UDAF
class MyUDAF extends UserDefinedAggregateFunction {
// 输入的数据类型的schema
override def inputSchema: StructType = {
StructType(StructField("input", LongType) :: Nil)
}
//缓冲去数据类型schema 就是转换字后的数据schema
override def bufferSchema: StructType = {
StructType(StructField("sum", LongType) :: StructField("total", LongType) :: Nil)
}
// 返回值数据类型
override def dataType: DataType = {
DoubleType
}
// 确定是否相同的输入会有相同的输出
override def deterministic: Boolean = true
// 初始化内部数据结构
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
// 更新数据内部结构,区内计算
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// 所有的金额
buffer(0) = buffer.getLong(0) + input.getLong(0)
// 一共多少条数据
buffer(1) = buffer.getLong(1) + 1
}
// 来字不同分区数据进行合并,全局合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算输出数据值
override def evaluate(buffer: Row): Any = {
buffer.getLong(0).toDouble / buffer.getLong(1)
}
}
//编写测试代码
//1.实例SparkSession
val spark = SparkSession.builder().master("local[*]").appName("sql").getOrCreate()
//2.根据SparkSession获取SparkContext 获取上下文对象
val sc = spark.sparkContext
//3.使用SparkContext 读取数据
val dataRDD = spark.read.json("./data/udf/udaf.json")
//4.注册表
dataRDD.createOrReplaceTempView("word")
//5.注册 UDAF 函数
spark.udf.register("myavg", new MyUDAF)
//6.使用自定义UDAF函数
spark.sql("select myavg(salary) from word").show()