Custom UDF and UDAF functions in Spark

Custom functions

Types of custom functions

	- UDF: one row in, one value out
	- UDAF: many rows in, one value out
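
As a quick illustration with Spark's built-in functions (assuming a hypothetical temp view "people" with columns name and age; this is only to show the distinction, not part of the example below): upper is a UDF that transforms each row independently, while avg is a UDAF that folds many rows into a single value.

  spark.sql("select upper(name) from people").show()   // UDF: one value returned per input row
  spark.sql("select avg(age) from people").show()      // UDAF: one value returned for all input rows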

UDF

Process

  • 1. Define a custom UDF function or class (a class must be serializable)
  • 2. Register it: spark.udf.register("name", custom function / method of the custom class _)
  • 3. Call it in the query

Defining and calling a custom UDF

import org.apache.spark.sql.SparkSession
import org.junit.Test

/**
  * @ClassName: MyUDFdemo
  * @Description: pad employee ids shorter than 8 characters with leading zeros
  * @Author: kele
  * @Date: 2021/2/1 20:56
  **/

/**
  * 1. Define a custom UDF function/class (a class must be serializable)
  * 2. Register it: spark.udf.register("name", custom function / method of the custom class _)
  * 3. Call it in the query
  */
class MyUDFdemo extends Serializable{

  @Test
  def emp_info={

    val spark = SparkSession.builder().master("local[4]").appName("UDFdemo").getOrCreate()

    import spark.implicits._    //implicit conversions for rdd.toDF

    val rdd1 = spark.sparkContext.parallelize(List(
      ("00123","zhangsan"),
      ("256","lisi"),
      ("0135","wangwu"),
      ("000368","qianqi"),
      ("00378","zhaoliu")
    ))

    val df = rdd1.toDF("id","name")

    /**
      * Way 1: query the custom function through SQL
      */

//    df.createOrReplaceTempView("user")
//    spark.udf.register("fullId",fullUserId)
//    spark.sql("""select fullId(id) from user """).show()

    /**
      * Register a method of the custom class; the class must be serializable
      */
    df.createOrReplaceTempView("user")
    spark.udf.register("fullId2",fullUserIdclass _)
    spark.sql("""select fullId2(id) from user """).show()

    /**
      * Way 2: query through selectExpr
      */
    df.selectExpr("fullId2(id) id").show()

  }

  //UDF defined as a function value
  val fullUserId = (id : String)=>{
    s"${"0" * (8 - id.length)}${id}"
  }

  //UDF defined as a method of the class (the enclosing class must be serializable)
  def fullUserIdclass(id:String) ={
    s"${"0" * (8 - id.length)}${id}"
  }


}
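
Besides registering a name for SQL / selectExpr, the same padding logic can also be wrapped with org.apache.spark.sql.functions.udf and applied directly to Columns. A minimal sketch (the value name fullIdUdf is illustrative, not from the original post):

    import org.apache.spark.sql.functions.{col, udf}

    // wrap the padding logic as a Column-based UDF; no name registration needed
    val fullIdUdf = udf((id: String) => s"${"0" * (8 - id.length)}$id")

    // use it in the DataFrame API
    df.select(fullIdUdf(col("id")).as("id"), col("name")).show()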


Weak-typed UDAF implementation

Overall process

  • 1. Extend UserDefinedAggregateFunction (no generic parameters)
  • 2. Override the methods
    -1. Specify the type of the column to aggregate (inputSchema)
    -2. Specify the types of the intermediate variables (bufferSchema)
    -3. Specify the return type of the function (dataType)
    -4. Declare the stability (deterministic)
    -5. Initialize the intermediate variables (initialize)
    -6. The computation within a single task (update)
    -7. The computation across partitions (merge)
    -8. The return value of the function (evaluate)
  • 3. Register it with spark.udf.register and bind a name to it

Custom weak-typed UDAF

 import org.apache.spark.sql.Row
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
 
 /**
   * @ClassName: MyUDAF
   * @Description:
   * @Author: kele
   * @Date: 2021/2/1 16:03
   **/
 class MyUDAF extends UserDefinedAggregateFunction{
 
   /**
     * Specify the type of the data to aggregate
     * @return the value is a StructType
     */
   override def inputSchema: StructType = new StructType().add("age",IntegerType)
 
   /**
     * To compute the average we need sum and num, so two intermediate variables are required
     * Specify the types of the intermediate variables; input rows arrive one at a time
     * @return
     */
   override def bufferSchema: StructType = new StructType().add("sum",IntegerType)
     .add("num",IntegerType)
   /**
     * The return type of the function
     * @return
     */
   override def dataType: DataType = DoubleType
 
   /**
     * Stability: whether the same input always produces the same result
     * @return
     */
   override def deterministic: Boolean = true
 
   /**
     * Initialize the buffer values
     * @param buffer
     */
   override def initialize(buffer: MutableAggregationBuffer): Unit = {
     buffer.update(0,0)
     buffer.update(1,0)
   }
 
   /**
     * The computation within a single task:
     *   sum keeps accumulating age
     *   count is increased by 1
     * @param buffer
     * @param input
     */
   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
     buffer.update(0,buffer.getAs[Int](0)+input.getAs[Int](0))
     buffer.update(1,buffer.getAs[Int](1)+1)
   }
 
   /**
     * The computation across partitions
     * @param buffer1
     * @param buffer2
     */
   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
     buffer1.update(0,buffer1.getAs[Int](0)+buffer2.getAs[Int](0))
     buffer1.update(1,buffer1.getAs[Int](1)+buffer2.getAs[Int](1))
   }
 
   /**
     * The return value
     * @param buffer
     * @return
     */
   override def evaluate(buffer: Row): Any = buffer.getAs[Int](0).toDouble/buffer.getAs[Int](1)
 }

Call process

  /**
    * Call the weak-typed UDAF
    */
  @Test
  def avg_Age={

    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("avg_age")
      .getOrCreate()

    val rdd = spark.sparkContext.parallelize(List(
      ("zhangsan",20,"开发部"),
      ("wanwu",25,"产品部"),
      ("aa",26,"开发部"),
      ("lisi",40,"开发部"),
      ("bb",30,"产品部"),
      ("cc",28,"产品部")
    ))

    import spark.implicits._

    val df = rdd.toDF("name","age","dept")


    df.createOrReplaceTempView("user")

    spark.udf.register("myavg",new MyUDAF)

    spark.sql(
      """
        |select myavg(age) from user group by dept
      """.stripMargin).show()
  }
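
For the sample data above, 开发部 averages (20+26+40)/3 ≈ 28.67 and 产品部 averages (25+30+28)/3 ≈ 27.67, so the output should look roughly like this (values rounded here; row order and the exact column header may differ):

  +----------+
  |myavg(age)|
  +----------+
  |     28.67|
  |     27.67|
  +----------+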

Strong-typed UDAF process

  • 1. The custom class extends Aggregator[type of the column to aggregate, type of the intermediate variable, type of the output result]

  • 2. Override the methods

    • 1. Initialize the intermediate variables (zero)
    • 2. The aggregation within each task (reduce)
    • 3. The computation across partitions (merge)
    • 4. Compute and return the final result (finish)
    • 5. Encoders for the intermediate variable and the result; personally I think these guarantee the intermediate data can be transferred (bufferEncoder / outputEncoder).
      Note that the parent class of a case class is Product
  • 3. Register it: spark.udf.register(function name, udaf(custom udaf object)), as shown in the sketch below

      		- import org.apache.spark.sql.functions._         // this import is required, otherwise udaf cannot be resolved
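
A minimal sketch of step 3 (Spark 3.0+; the function name myavg2 is just an example):

    import org.apache.spark.sql.functions._

    // wrap the strong-typed Aggregator as a UDAF and bind a SQL name to it
    spark.udf.register("myavg2", udaf(new MyUDAFStrong))
    spark.sql("select myavg2(age) from user group by dept").show()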

Custom strong-typed UDAF

 package com.atguigu.day05
 
 import org.apache.spark.sql.{Encoder, Encoders}
 import org.apache.spark.sql.expressions.Aggregator
 
 /**
   * @ClassName: MyUDAF2
   * @Description: strong-typed custom class; Aggregator takes the generic parameters [input type, intermediate variable type, output type]
   * @Author: kele
   * @Date: 2021/2/1 16:05
   **/
 
 /**
   * If several intermediate variables are needed, consider using a case class
   */
 
 case class InterVari(var sum:Int,var count:Int)
 class MyUDAFStrong extends Aggregator[Int,InterVari,Double]{
 
   /**
     * Initialize the intermediate variables
     * @return
     */
   override def zero: InterVari = InterVari(0,0)
 
   /**
     * The aggregation within each task
     * @param b
     * @param a
     * @return
     */
   override def reduce(b: InterVari, a: Int): InterVari = {
     b.sum = b.sum+a
     b.count = b.count+1
     b
   }
 
   /**
     * The computation across partitions
     * @param b1
     * @param b2
     * @return
     */
   override def merge(b1: InterVari, b2: InterVari): InterVari = {
     b1.sum = b1.sum + b2.sum
     b1.count = b1.count + b2.count
     b1
   }
 
   /**
     * Return the final result
     * @param reduction
     * @return
     */
   override def finish(reduction: InterVari): Double = reduction.sum.toDouble/reduction.count
 
   /**
     * Encoder for the type of the intermediate variable; personally I think it guarantees the intermediate data can be transferred
     * @return  the parent class of a case class is Product
     */
   override def bufferEncoder: Encoder[InterVari] = Encoders.product
 
   /**
     * Encoder for the type of the result value; personally I think it guarantees the data can be transferred
     * @return
     */
   override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
 }
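
The same Aggregator can also be used without any registration in the typed Dataset API, via toColumn. A minimal sketch, assuming a Dataset[Int] of ages (no grouping here, just the overall average):

    import spark.implicits._

    // select the aggregator as a TypedColumn on a Dataset of its input type
    val ages = List(20, 25, 26, 40, 30, 28).toDS()
    ages.select(new MyUDAFStrong().toColumn.name("avg_age")).show()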

Call process

  /**
    * Call the strong-typed UDAF
    */
  @Test
  def avg_Age={

    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("avg_age")
      .getOrCreate()

    val rdd = spark.sparkContext.parallelize(List(
      ("zhangsan",20,"开发部"),
      ("wanwu",25,"产品部"),
      ("aa",26,"开发部"),
      ("lisi",40,"开发部"),
      ("bb",30,"产品部"),
      ("cc",28,"产品部")
    ))

    import spark.implicits._

    val df = rdd.toDF("name","age","dept")


    df.createOrReplaceTempView("user")

    //wrap the strong-typed Aggregator with udaf before registering it (Spark 3.0+)
    import org.apache.spark.sql.functions.udaf
    spark.udf.register("myavg",udaf(new MyUDAFStrong))

    spark.sql(
      """
        |select myavg(age) from user group by dept
      """.stripMargin).show()
  }

Origin blog.csdn.net/qq_38705144/article/details/113528888