自定义UDF(一进一出)
需求:为查询结果中的每个name加上前缀"Hi:"
/**
 * Custom UDF demo (one value in, one value out): prefixes every `name` with "Hi:".
 *
 * Registers the UDF as `sayHi`, queries the `user` temp view built from a JSON file,
 * prints the result and always stops the Spark session, even on failure.
 *
 * @param args optional; `args(0)` overrides the JSON input path (defaults to the
 *             original local path).
 */
def main(args: Array[String]): Unit = {
  val conf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
  val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
  val inputPath: String =
    args.headOption.getOrElse("D:\\develop\\workspace\\bigdata2021\\spark2021\\input\\test.json")
  try {
    val df: DataFrame = spark.read.json(inputPath)
    df.createOrReplaceTempView("user")
    // For a String parameter Spark hands SQL NULL to the closure as a Scala null.
    // Propagate NULL instead of producing the literal text "Hi:null".
    spark.udf.register("sayHi", (name: String) => Option(name).map(n => "Hi:" + n).orNull)
    spark.sql("select sayHi(name),age from user").show()
  } finally {
    // Release the session even if reading or querying fails.
    spark.stop()
  }
}
自定义UDAF(多进一出)
需求:实现求平均值的功能
自定义弱类型UDAF(适用于Spark SQL风格的查询;注意:UserDefinedAggregateFunction自Spark 3.0起已被弃用,推荐改用Aggregator并通过functions.udaf注册)
/** Weakly-typed UDAF demo: averages `age` from SQL using [[MyAvg]]. */
object SparkSql04_udaf {

  /**
   * Registers [[MyAvg]] as `myAvg`, runs it over the `user` temp view and prints the result.
   * The Spark session is always stopped, even on failure.
   *
   * @param args optional; `args(0)` overrides the JSON input path (defaults to the
   *             original local path).
   */
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    val inputPath: String =
      args.headOption.getOrElse("D:\\develop\\workspace\\bigdata2021\\spark2021\\input\\test.json")
    try {
      val df: DataFrame = spark.read.json(inputPath)
      df.createOrReplaceTempView("user")
      val myAvg = new MyAvg
      // NOTE(review): registering a UserDefinedAggregateFunction is deprecated as of Spark 3.0
      // (Aggregator + functions.udaf is the replacement) — confirm against the Spark version in use.
      spark.udf.register("myAvg", myAvg)
      spark.sql("select myAvg(age) from user").show()
    } finally {
      // Release the session even if reading or querying fails.
      spark.stop()
    }
  }
}
/**
 * Weakly-typed (untyped) UDAF that averages one numeric column, for use from SQL.
 *
 * Buffer layout: slot 0 = running sum, slot 1 = number of non-null inputs.
 * Null inputs are skipped. A group with no non-null input yields SQL NULL, the same
 * as the built-in `avg`, rather than 0.0 / 0 = NaN.
 */
class MyAvg extends UserDefinedAggregateFunction {

  /**
   * A single input column declared as Long. `spark.read.json` infers whole numbers as
   * bigint, and declaring Int would make Spark implicitly down-cast (and possibly
   * truncate) the input. Int columns are widened to Long safely instead.
   */
  override def inputSchema: StructType = {
    StructType(Array(StructField("age", LongType)))
  }

  /** Intermediate state: running sum (slot 0) and count of non-null inputs (slot 1). */
  override def bufferSchema: StructType = {
    StructType(Array(StructField("age", LongType), StructField("count", LongType)))
  }

  /** The result type; nullable, so an empty group can return null. */
  override def dataType: DataType = {
    DoubleType
  }

  /** The same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Starts each group with a zero sum and a zero count. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /** Adds one input row; null inputs are ignored and do not count toward the average. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  /** Combines a partial aggregate from another partition into `buffer1`. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /** Returns sum / count as a Double, or null when no non-null input was seen. */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    // Avoid 0.0 / 0 = NaN; SQL semantics for an empty average is NULL.
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}
自定义强类型UDAF(适用于DSL风格的查询)
/** Strongly-typed aggregation demo: averages `User.age` via the DSL using [[MyAvg2]]. */
object SparkSql05_udaf {

  /**
   * Reads the JSON input as a `Dataset[User]`, selects the [[MyAvg2]] typed column and
   * prints the result. The Spark session is always stopped, even on failure.
   *
   * @param args optional; `args(0)` overrides the JSON input path (defaults to the
   *             original local path).
   */
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._
    val inputPath: String =
      args.headOption.getOrElse("D:\\develop\\workspace\\bigdata2021\\spark2021\\input\\test.json")
    try {
      val df: DataFrame = spark.read.json(inputPath)
      // Decoding requires the JSON columns to match User's fields.
      // NOTE(review): a null `age` cannot be decoded into User's primitive Long field.
      val ds: Dataset[User] = df.as[User]
      val avgAge: TypedColumn[User, Double] = new MyAvg2().toColumn
      ds.select(avgAge).show()
    } finally {
      // Release the session even if reading or querying fails.
      spark.stop()
    }
  }
}
/** Typed record for one input row, with `name`, `age` and `gender` fields. */
case class User(
    name: String,
    age: Long,
    gender: String
)
/**
 * Mutable aggregation state: running sum of ages and number of rows folded in.
 * The fields are `var`s so an aggregator can update a buffer in place.
 */
case class MyBuffer(
    var ageSum: Long,
    var count: Long
)
/**
 * Strongly-typed average of `User.age`, intended for DSL-style use via `toColumn`.
 *
 * Input: one [[User]] per row. Buffer: [[MyBuffer]] (age sum and row count).
 * Output: the mean age as a Double. With zero rows the division is 0.0 / 0, so the
 * result is `Double.NaN`; the primitive Double output cannot represent NULL.
 */
class MyAvg2 extends Aggregator[User, MyBuffer, Double] {

  /** Identity buffer: nothing summed, nothing counted. */
  override def zero: MyBuffer = MyBuffer(ageSum = 0L, count = 0L)

  /** Folds one user's age into `b`, mutating it in place and returning it. */
  override def reduce(b: MyBuffer, a: User): MyBuffer = {
    b.ageSum = b.ageSum + a.age
    b.count = b.count + 1
    b
  }

  /** Adds the partial totals of `b2` into `b1`, mutating it in place and returning it. */
  override def merge(b1: MyBuffer, b2: MyBuffer): MyBuffer = {
    b1.ageSum = b1.ageSum + b2.ageSum
    b1.count = b1.count + b2.count
    b1
  }

  /** Mean age; NaN when no rows were aggregated. */
  override def finish(reduction: MyBuffer): Double = {
    val total: Double = reduction.ageSum.toDouble
    total / reduction.count
  }

  /** Encoder for the case-class buffer. */
  override def bufferEncoder: Encoder[MyBuffer] = Encoders.product[MyBuffer]

  /** Encoder for the primitive Double result. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}