Spark SQL custom UDFs and UDAFs

Custom UDF (one input in, one value out)

Requirement: prefix every name in the query result with "Hi:".

def main(args: Array[String]): Unit = {

    // Run Spark locally, using as many worker threads as available cores.
    val sparkConf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
    val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Load the sample JSON and expose it as a temporary view named "user".
    val users: DataFrame = session.read.json("D:\\develop\\workspace\\bigdata2021\\spark2021\\input\\test.json")
    users.createOrReplaceTempView("user")

    // Register a one-in / one-out UDF that prefixes each name with "Hi:".
    session.udf.register("sayHi", (name: String) => "Hi:" + name)

    // Run a SQL query that applies the UDF to every row.
    session.sql("select sayHi(name),age from user").show()

    session.stop()
}

Custom UDAF (multiple inputs in, one value out)

Requirement: implement a function that computes the average age.

Weakly-typed custom UDAF (suited to Spark SQL-style queries)

object SparkSql04_udaf {

  def main(args: Array[String]): Unit = {

    // Local Spark session for the demo.
    val sparkConf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
    val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Load the sample JSON and register it as the temporary view "user".
    val users: DataFrame = session.read.json("D:\\develop\\workspace\\bigdata2021\\spark2021\\input\\test.json")
    users.createOrReplaceTempView("user")

    // Built-in equivalent, kept for comparison:
    // session.sql("select avg(age) from user").show()

    // Register the custom UDAF under the name "myAvg" and call it from SQL.
    session.udf.register("myAvg", new MyAvg)
    session.sql("select myAvg(age) from user").show()

    session.stop()
  }
}

// Weakly-typed custom UDAF that computes the average of an integer column.
// NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
// in favor of Aggregator + functions.udaf; it is used here for the
// SQL-style registration demo.
class MyAvg extends UserDefinedAggregateFunction {

  // Schema of the input arguments (a UDAF may accept several columns;
  // this one takes a single IntegerType column).
  override def inputSchema: StructType = {
    StructType(Array(StructField("age", IntegerType)))
  }

  // Schema of the intermediate aggregation buffer: running sum and row count.
  override def bufferSchema: StructType = {
    StructType(Array(StructField("age", LongType), StructField("count", LongType)))
  }

  // Type of the final result.
  override def dataType: DataType = {
    DoubleType
  }

  // Same input always yields the same output, so Spark may cache/reuse results.
  override def deterministic: Boolean = true

  // Reset the buffer before aggregation starts.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // running sum of ages
    buffer(0) = 0L
    // number of non-null ages seen
    buffer(1) = 0L
  }

  // Fold one input row into the buffer. Null ages are skipped so they do not
  // distort the average (same convention as the built-in avg()).
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      // The input column is IntegerType; widen to Long for the running sum.
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Merge two partial buffers (combines results across partitions).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Final result: sum / count as a Double. Returns null (not NaN) when no
  // non-null values were aggregated, matching the built-in avg() behavior;
  // the original `sum * 1.0 / count` produced NaN for an empty group.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}

Strongly-typed custom UDAF (suited to DSL-style queries)

object SparkSql05_udaf {

  def main(args: Array[String]): Unit = {

    val sparkConf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
    val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    // Required for the DataFrame -> Dataset conversion below.
    import session.implicits._

    val raw: DataFrame = session.read.json("D:\\develop\\workspace\\bigdata2021\\spark2021\\input\\test.json")
    // Convert the untyped DataFrame into a typed Dataset of User.
    val users: Dataset[User] = raw.as[User]

    // Turn the strongly-typed aggregator into a column usable in a DSL query;
    // every row of the Dataset is passed to the aggregator as a User instance.
    val avgCol: TypedColumn[User, Double] = new MyAvg2().toColumn
    users.select(avgCol).show()

    session.stop()
  }
}

// Case class backing each row of the typed Dataset.
// JSON loading reads integer values as bigint, so age must be declared Long.
case class User(name:String, age: Long, gender: String)

// Intermediate aggregation buffer for MyAvg2: running sum of ages and row
// count. Fields are vars because the aggregator updates the buffer in place.
case class MyBuffer(var ageSum: Long, var count: Long)

// Strongly-typed custom UDAF that averages the age field of User rows.
// Aggregator type parameters: IN = User (input row type),
// BUF = MyBuffer (intermediate buffer), OUT = Double (final result).
class MyAvg2 extends Aggregator[User, MyBuffer, Double] {

  // Fresh, zero-valued buffer used to start an aggregation.
  override def zero: MyBuffer = MyBuffer(0L, 0L)

  // Fold one input row into the buffer (aggregation within a partition).
  override def reduce(b: MyBuffer, a: User): MyBuffer =
    MyBuffer(b.ageSum + a.age, b.count + 1)

  // Combine two partial buffers (merging results across partitions).
  override def merge(b1: MyBuffer, b2: MyBuffer): MyBuffer =
    MyBuffer(b1.ageSum + b2.ageSum, b1.count + b2.count)

  // Produce the final result from the fully-merged buffer.
  override def finish(reduction: MyBuffer): Double =
    reduction.ageSum * 1.0 / reduction.count

  // Encoder for the custom buffer type; Encoders.product handles any case class.
  override def bufferEncoder: Encoder[MyBuffer] = Encoders.product

  // Encoder for the Double result.
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

Recommended reading

Origin blog.csdn.net/FlatTiger/article/details/115251249