Spark SQL's UDF and UDAF functions

Spark SQL already provides a wealth of built-in functions for us programmer friends to use, so why would anyone need user-defined functions? Because real business scenarios can be very complicated, and the built-in functions can't always cope. So Spark SQL exposes an extension interface for custom functions, as if to say: buddy, your business logic is too twisted for me to satisfy, so write a SQL function yourself according to my specification and toss it around however you like!

Here we'll use Scala to implement a simple hello-world-level example and get a feel for how UDFs and UDAFs are used.

Problem

Given the following array:

val bigData = Array("Spark","Hadoop","Flink","Spark","Hadoop","Flink",
"Spark","Hadoop","Flink","Spark","Hadoop","Flink")

Group the strings in the array and, for each distinct string, compute its length and the number of times it occurs. The expected result
is as follows:

+------+-----+------+
|  name|count|length|
+------+-----+------+
| Spark|    4|     5|
| Flink|    4|     5|
|Hadoop|    4|     6|
+------+-----+------+

Note: the string "Spark" has a length of 5 and occurs 4 times in total.

Analysis

  • UDF: a custom SQL function that returns the length of a string. It is written
    just like an ordinary Scala function; the only difference is that it has to be registered (here through the SparkSession) before it can be used in SQL.
  • UDAF: a custom aggregate function. After the rows
    are grouped by the string column name, the aggregate function is called on each group to accumulate a count.
    Ah, that sounds abstract, so just look at the code!

Code

package com.hand.datasafe

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType

/**
  * Spark SQL UDAF:user defined aggregation function
  * UDF: The input of the function is a specific data record, and the implementation is an ordinary scala function - it just needs to be registered
  * UDAF: User-defined aggregate function, the function itself acts on the data set, and can perform custom operations on the basis of specific operations
  */
object SparkSQLUDF {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("datasafe").master("local").getOrCreate()

    val bigData = Array("Spark", "Hadoop", "Flink", "Spark", "Hadoop", "Flink", "Spark", "Hadoop", "Flink", "Spark", "Hadoop", "Flink")
    val bigDataRDD = spark.sparkContext.parallelize(bigData)

    val bigDataRowRDD: RDD[Row] = bigDataRDD.map(line => Row(line))
    val structType = StructType(Array(StructField("name", StringType, true)))
    val bigDataDF = spark.createDataFrame(bigDataRowRDD, structType)
    bigDataDF.printSchema()
    bigDataDF.createTempView("bigDataTable")

    /*
     * Register the UDF with the SparkSession; a Scala UDF can accept up to 22 input parameters
     * (see the multi-argument sketch after this listing)
     */
    spark.udf.register("computeLength", (input: String) => input.length)
    spark.sql("select name,computeLength(name)  as length from bigDataTable").show

    //while(true){}

    spark.udf.register("wordCount", new MyUDAF)
    spark.sql("select name,wordCount(name) as count,computeLength(name) as length from bigDataTable group by name ").show
    spark.sql("select name,wordCount(name) as count,computeLength(name) as length from bigDataTable group by name ").printSchema()

  }
}
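As a quick aside on the 22-parameter limit mentioned in the registration comment: a UDF with several arguments is registered the same way, just by passing a Scala function of the matching arity. A minimal sketch (the function name tagWith and its logic are made up here purely for illustration):

// hypothetical two-argument UDF, registered the same way as computeLength
spark.udf.register("tagWith", (prefix: String, word: String) => prefix + "_" + word)
spark.sql("select name, tagWith('lang', name) as tagged from bigDataTable").show()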
The UDAF implementation lives in a separate source file:

package com.hand.datasafe

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * User-defined aggregate function (UDAF)
  */
class MyUDAF extends UserDefinedAggregateFunction
{
  /**
    * Specify the type of specific input data
    * The field name only serves to identify the input argument: it can be "name" or any other string
    */
  override def inputSchema:StructType = StructType(Array(StructField("name",StringType,true)))

  /**
    * The intermediate result type of the data to be processed when performing the aggregation operation
    */
  override def bufferSchema:StructType = StructType(Array(StructField("count",IntegerType,true)))

  /**
    * return type
    */
  override def dataType:DataType = IntegerType

  /**
    * whether given the same input,
    * always return the same output
    * true: yes
    */
  override def deterministic:Boolean = true

  /**
    * Initializes the given aggregation buffer
    */
  override def initialize(buffer:MutableAggregationBuffer):Unit = {buffer(0)=0}

  /**
    * When performing aggregation, how to calculate the grouped aggregation whenever a new value comes in
    * Local aggregation operation, equivalent to Combiner in Hadoop MapReduce model
    */
  override def update(buffer:MutableAggregationBuffer,input:Row):Unit={
    buffer(0) = buffer.getInt(0) + 1
  }

  /**
    * Finally, a global-level merge operation needs to be performed after the local reduce on the distributed nodes is completed
    */
  override def merge(buffer1:MutableAggregationBuffer,buffer2:Row):Unit={
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
  }

  /**
    * Return the final calculation result of UDAF
    */
  override def evaluate(buffer:Row):Any = buffer.getInt(0)
}

Summary

    • Implementing a UDAF in Spark
      To write my own SQL aggregate function I had to extend UserDefinedAggregateFunction and implement 8 abstract methods. Eight methods! What a pain! Still, if you need an aggregate (that is the "A" in UDAF) function in SQL that fits a specific business scenario, a UDAF is what it takes.
      How should MutableAggregationBuffer be understood? It simply holds the intermediate result; aggregation means accumulating (or otherwise combining) values from many records into that buffer. (See the sketch after this list for a newer alternative.)
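As an aside, since Spark 3.0 UserDefinedAggregateFunction is deprecated in favour of the typed Aggregator API, which is registered through functions.udaf and needs far less boilerplate. A minimal sketch of the same word count, assuming Spark 3.x (not part of the original code above):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// buffer and result are both a plain Int, so no StructType boilerplate is needed
object WordCountAgg extends Aggregator[String, Int, Int] {
  def zero: Int = 0                                          // initial buffer value
  def reduce(buffer: Int, input: String): Int = buffer + 1   // per-row update
  def merge(b1: Int, b2: Int): Int = b1 + b2                 // combine partial counts
  def finish(reduction: Int): Int = reduction                // final result
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// registration would then replace `new MyUDAF`:
// spark.udf.register("wordCount", functions.udaf(WordCountAgg))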
