Spark Learning Road (19): SparkSQL Custom Functions (UDF)

Spark SQL, like Hive, supports user-defined functions. Custom functions can be roughly divided into three types:

  • UDF (User-Defined Function), the most basic kind of custom function, which maps one row's input to one output value, similar to built-ins such as to_char and to_date (a minimal sketch follows this list)
  • UDAF (User-Defined Aggregation Function), a user-defined aggregate function, similar to sum and avg used after GROUP BY
  • UDTF (User-Defined Table-Generating Function), a user-defined function that produces multiple output rows per input row, somewhat like flatMap on a stream
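
Before the aggregate example below, here is a minimal sketch of the simplest case, a plain one-in, one-out UDF; the function name to_upper_str and the column it is applied to are illustrative assumptions, not part of the original example:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object SimpleUDFExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SimpleUDFExample")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)

    // Register a one-in, one-out function under the (illustrative) name "to_upper_str"
    hiveContext.udf.register("to_upper_str", (s: String) => s.toUpperCase)

    // It can then be used in SQL like any built-in function, e.g.:
    // SELECT to_upper_str(city) FROM city_info
  }
}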

To define a custom UDAF, you need to extend the UserDefinedAggregateFunction class and implement its 8 methods: inputSchema, bufferSchema, dataType, deterministic, initialize, update, merge, and evaluate.

Example

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

object GetDistinctCityUDF extends UserDefinedAggregateFunction{
  /**
    * input data type
    * */
  override def inputSchema: StructType = StructType(
    StructField("status",StringType,true) :: Nil
  )
  /**
    * Cache field type
    * */
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("buffer_city_info",StringType,true)
      )
    )
  }
/**
  * output result type
  * */
  override def dataType: DataType = StringType
/**
  * Whether the function is deterministic, i.e. the same input always produces the same output
  * */
  override def deterministic: Boolean = true
/**
  * Initialize the aggregation buffer
  * */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0,"")
  }
/**
  * Update the aggregation buffer with each input row
  * */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Get the accumulated value so far
    var last_str = buffer.getString(0)
    // Get the current input value
    val current_str = input.getString(0)
    // Only append the current value if it is not already contained
    if (!last_str.contains(current_str)) {
      // If this is the first value, assign it directly; otherwise append with a comma
      if (last_str.equals("")) {
        last_str = current_str
      } else {
        last_str += "," + current_str
      }
    }
    buffer.update(0, last_str)
  }
/**
  * Merge partial results from different partitions.
  * buffer1 is the result from one partition (e.g. machine hadoop1)
  * buffer2 is the result from another partition (e.g. machine hadoop2)
  * */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var buf1 = buffer1.getString(0)
    val buf2 = buffer2.getString(0)
    // Append the values that are in buf2 but not yet in buf1
    // The values in buf2 are comma-separated
    for (s <- buf2.split(",")) {
      if (!buf1.contains(s)) {
        if (buf1.equals("")) {
          buf1 = s
        } else {
          buf1 += "," + s
        }
      }
    }
    buffer1.update(0, buf1)
  }
/**
  * Final calculation result
  * */
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }
}
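
Besides SQL registration (shown in the next section), a UserDefinedAggregateFunction object can also be applied directly to DataFrame columns through its apply method. A minimal usage sketch, assuming a DataFrame df with string columns area and city (all three names are hypothetical):

import org.apache.spark.sql.functions.col

// Group by area and collect the distinct cities per group
// ("df", "area" and "city" are assumed names, not from the example above)
val result = df.groupBy(col("area"))
  .agg(GetDistinctCityUDF(col("city")).as("distinct_cities"))
result.show()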

Register the custom functions as temporary functions

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

def main(args: Array[String]): Unit = {
  /**
    * The first step is to create a program entry
    */
  val conf = new SparkConf().setAppName("AralHotProductSpark")
  val sc = new SparkContext(conf)
  val hiveContext = new HiveContext(sc)

  // register the UDAF as a temporary function
  hiveContext.udf.register("get_distinct_city", GetDistinctCityUDF)

  // register a plain UDF as a temporary function
  hiveContext.udf.register("get_product_status", (str: String) => {
    var status = 0
    for (s <- str.split(",")) {
      if (s.contains("product_status")) {
        status = s.split(":")(1).toInt
      }
    }
    status // return the parsed status
  })
}

 
