Weakly typed and strongly typed custom UDAF functions


Weakly typed (UserDefinedAggregateFunction): the standard approach in 2.x, deprecated in 3.x
Strongly typed (Aggregator registered via udaf()): available in 3.x; this registration path is not available in 2.x

Using the built-in avg function

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object UserDefinedUDAF {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("test").master("local[4]").getOrCreate()
    import spark.implicits._

    val list = List(
      ("zhangsan", 20, "北京"),
      ("sd", 30, "深圳"),
      ("asd", 40, "北京"),
      ("asd", 50, "深圳"),
      ("asdad", 60, "深圳"),
      ("gfds", 70, "北京"),
      ("dfg", 60, "深圳"),
      ("erw", 80, "上海"),
      ("asd", 18, "广州"),
      ("sdassws", 20, "广州")
    )

    val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
    val df: DataFrame = rdd.toDF("name", "age", "region")
    df.createOrReplaceTempView("person")
    spark.sql(
      """
        |select
        |region,
        |avg(age)
        |from person group by region
        |""".stripMargin).show()
  }

}

Result (screenshot omitted): avg(age) per region: 北京 ≈ 43.33, 深圳 50.0, 上海 80.0, 广州 19.0.
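The same query can also be written with the DataFrame API instead of SQL. A minimal sketch using the built-in functions.avg (assuming the df defined above):

import org.apache.spark.sql.functions.avg

// Equivalent to the SQL above: group by region, then average the age column
df.groupBy("region")
  .agg(avg("age"))
  .show()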

Weakly typed custom UDAF function (AVG)

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}

/**
 * Custom weakly typed UDAF function
 *     1. Create a class extending UserDefinedAggregateFunction
 */
class WeakAvgUDAF extends UserDefinedAggregateFunction {
  /**
   * Specifies the input parameter type of the UDAF function
   * [this custom avg operates on the age column, which is of type Int]
   * @return
   */
  override def inputSchema: StructType = {
    new StructType()
      .add("input", IntegerType)
  }

  /**
   * Specifies the types of the intermediate buffer variables
   * [to average ages per region we track the total age (sum) and the number of
   * people (count), because the final result is the total divided by the count]
   * @return
   */
  override def bufferSchema: StructType = {
    new StructType()
      .add("sum", IntegerType)
      .add("count", IntegerType)
  }

  /**
   * Specifies the type of the UDAF's final result
   * @return
   */
  override def dataType: DataType = DoubleType

  /**
   * Whether the function is deterministic (the same input always produces the same output)
   * @return
   */
  override def deterministic: Boolean = true

  /**
   * Initializes the intermediate buffer [sum = 0, count = 0]
   * @param buffer
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // sum = 0
    buffer(0) = 0
    // count = 0
    buffer(1) = 0
  }

  /**
   * Combiner-like step: processes a single age value within a group
   * @param buffer  the intermediate buffer [sum, count]
   * @param input   one value from the group (age)
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + input.getAs[Int](0)
    buffer(1) = buffer.getAs[Int](1) + 1
  }

  /**
   * Reducer-like step: merges two partial buffers
   * @param buffer1 the intermediate buffer [sum, count]
   * @param buffer2 a combiner result [sum, count]
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum = sum + combiner_sum
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
    // count = count + combiner_count
    buffer1(1) = buffer1.getAs[Int](1) + buffer2.getAs[Int](1)
  }

  /**
   * Computes the final result
   * @param buffer the intermediate buffer [sum, count]
   * @return
   */
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0).toDouble / buffer.getAs[Int](1)
  }
}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object UserDefinedUDAF {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("test").master("local[4]").getOrCreate()
    import spark.implicits._

    val list = List(
      ("zhangsan", 20, "北京"),
      ("sd", 30, "深圳"),
      ("asd", 40, "北京"),
      ("asd", 50, "深圳"),
      ("asdad", 60, "深圳"),
      ("gfds", 70, "北京"),
      ("dfg", 60, "深圳"),
      ("erw", 80, "上海"),
      ("asd", 18, "广州"),
      ("sdassws", 20, "广州")
    )

    val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
    val df: DataFrame = rdd.toDF("name", "age", "region")
    df.createOrReplaceTempView("person")
    spark.udf.register("myavg",new WeakAvgUDAF)
    spark.sql(
      """
        |select
        |region,
        |myavg(age)
        |from person group by region
        |""".stripMargin).show()
  }

}

Result (screenshot omitted): the same averages as the built-in avg: 北京 ≈ 43.33, 深圳 50.0, 上海 80.0, 广州 19.0.
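A weakly typed UDAF does not have to go through SQL registration: UserDefinedAggregateFunction also provides an apply(Column*) method, so the function object can be used directly in the DataFrame API. A minimal sketch (assuming the df defined above):

import org.apache.spark.sql.functions.col

// Use the UDAF as a Column expression instead of registering it for SQL
val myAvg = new WeakAvgUDAF
df.groupBy("region")
  .agg(myAvg(col("age")).as("avg_age"))
  .show()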

Strongly typed custom UDAF function (AVG)

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Custom strongly typed UDAF function
 * 1. Define a class extending Aggregator[IN, BUF, OUT]
 *     IN:  the type of the UDAF function's input parameter
 *     BUF: the type of the intermediate buffer used during computation
 *     OUT: the type of the final result
 * 2. Override the abstract methods
 * Using a strongly typed custom UDAF function:
 *   1. Create the custom UDAF object: val obj = new Xxx
 *   2. Import the conversion method: import org.apache.spark.sql.functions._
 *   3. Convert it: val function = udaf(obj)
 *   4. Register it: spark.udf.register("name", function)
 */
case class AvgBuff(sum: Int, count: Int)
class StrongAvgUDAF extends Aggregator[Int, AvgBuff, Double] {
  /**
   * Initializes the intermediate buffer
   * @return
   */
  override def zero: AvgBuff = AvgBuff(0, 0)

  /**
   * Combiner-like step
   * @param buff  the intermediate buffer
   * @param age   the UDAF input parameter
   * @return      the buffer after accumulating this value
   */
  override def reduce(buff: AvgBuff, age: Int): AvgBuff = AvgBuff(buff.sum + age, buff.count + 1)

  /**
   * Reducer-like aggregation
   * @param b1 the intermediate buffer
   * @param b2 a combiner aggregation result
   * @return   the merged buffer
   */
  override def merge(b1: AvgBuff, b2: AvgBuff): AvgBuff = AvgBuff(b1.sum + b2.sum, b1.count + b2.count)

  /**
   * Computes the final result
   * @param buff
   * @return
   */
  override def finish(buff: AvgBuff): Double = buff.sum.toDouble / buff.count

  /**
   * Specifies the encoder for the intermediate buffer
   * @return
   */
  override def bufferEncoder: Encoder[AvgBuff] = Encoders.product[AvgBuff]

  /**
   * Specifies the encoder for the final result type
   * @return
   */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
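
Besides registering it for SQL with udaf(), an Aggregator can also be applied as a typed column on a Dataset through its toColumn method. A minimal sketch (assuming a SparkSession named spark is in scope; this usage is not part of the original example):

import spark.implicits._

// Apply the Aggregator directly to a strongly typed Dataset[Int] of ages
val ages = List(20, 30, 40).toDS()
ages.select(new StrongAvgUDAF().toColumn.name("avg_age")).show()

This path skips spark.udf.register entirely, and the input type is checked at compile time.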

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object UserDefinedUDAF {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("test").master("local[4]").getOrCreate()
    import spark.implicits._

    val list = List(
      ("zhangsan", 20, "北京"),
      ("sd", 30, "深圳"),
      ("asd", 40, "北京"),
      ("asd", 50, "深圳"),
      ("asdad", 60, "深圳"),
      ("gfds", 70, "北京"),
      ("dfg", 60, "深圳"),
      ("erw", 80, "上海"),
      ("asd", 18, "广州"),
      ("sdassws", 20, "广州")
    )

    val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
    val df: DataFrame = rdd.toDF("name", "age", "region")
    df.createOrReplaceTempView("person")
    // TODO register the weakly typed UDAF
    spark.udf.register("myavg", new WeakAvgUDAF)
    // TODO register the strongly typed UDAF
    import org.apache.spark.sql.functions._
    spark.udf.register("myavg2", udaf(new StrongAvgUDAF))
    spark.sql(
      """
        |select
        |region,
        |myavg2(age)
        |from person group by region
        |""".stripMargin).show()

  }

}

Result (screenshot omitted): identical to the weakly typed version: 北京 ≈ 43.33, 深圳 50.0, 上海 80.0, 广州 19.0.

Origin blog.csdn.net/qq_46548855/article/details/134403970