Spark之hive的UDF自定义函数

1. 简单的 UDF:注册一个匿名函数

package com.llcc.sparkSql.MyTimeSort

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object UDFDemo1 {

  /**
    * Registers a simple UDF named "strlen" on a HiveContext and applies it
    * to the `category` column of the `xtwy.worker` table.
    *
    * The UDF returns the length of the input string, or 0 when the value
    * is null.
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)

    // Null-safe length: Option(null) collapses to None, which falls back to 0.
    hiveContext.udf.register("strlen", (str: String) =>
      Option(str).map(_.length).getOrElse(0)
    )

    hiveContext.sql("select strlen(category) from xtwy.worker" ).show()
  }

}

(图略:strlen UDF 的运行结果截图)

2. 继承 UserDefinedAggregateFunction

package com.llcc.sparkSql.MyTimeSort

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._

object UDFDemo extends UserDefinedAggregateFunction{

  /**
    * Input schema: a single nullable Double column.
    * Only one column is aggregated here, but a StructType requires a list of
    * fields, hence the `:: Nil`. This UDAF averages the Hive `salary` column.
    * @return the input schema
    */
  override def inputSchema: StructType = StructType(
    StructField("salary",DoubleType,true)::Nil
  )

  /**
    * Aggregation buffer: running total of salaries and count of rows folded in.
    * @return the buffer schema
    */
  override def bufferSchema: StructType = StructType(
    StructField("total",DoubleType,true)::
      StructField("count",IntegerType,true)::Nil
  )

  // The final result (the average) is a Double.
  override def dataType: DataType = DoubleType

  // Same input always produces the same output.
  override def deterministic: Boolean = true

  /**
    * Resets the buffer before aggregation starts: zero total, zero count.
    * @param buffer the aggregation buffer to initialize
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0,0.0)
    buffer.update(1,0)
  }

  /**
    * Folds one input row into the buffer.
    * The salary column is declared nullable in `inputSchema`, so null rows
    * must be skipped; calling `getDouble(0)` on a null cell would otherwise
    * corrupt the aggregate.
    * @param buffer the running (total, count) buffer
    * @param input  a single input row holding the salary value
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val total = buffer.getDouble(0)
      val count = buffer.getInt(1)
      val currentSalary = input.getDouble(0)
      buffer.update(0,total+currentSalary)
      buffer.update(1,count+1)
    }
  }

  /**
    * Merges two partial aggregation buffers (from different partitions)
    * by summing their totals and counts.
    * @param buffer1 the buffer merged into (mutable)
    * @param buffer2 the buffer merged from (read-only)
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val total1 = buffer1.getDouble(0)
    val count1 = buffer1.getInt(1)

    val total2 = buffer2.getDouble(0)
    val count2 = buffer2.getInt(1)

    buffer1.update(0,total1+total2)
    buffer1.update(1,count1+count2)

  }

  /**
    * Produces the final average from the buffer.
    * Guards against a zero count (all inputs null or no rows): previously
    * this returned `0.0 / 0 == NaN`; returning null matches SQL avg()
    * semantics on empty input.
    * @param buffer the final (total, count) buffer
    * @return the average salary, or null when no values were aggregated
    */
  override def evaluate(buffer: Row): Any = {
    val total = buffer.getDouble(0)
    val count = buffer.getInt(1)
    if (count == 0) null else total/count
  }

  /**
    * Registers this UDAF as "salary_avg" and runs it over xtwy.worker.
    */
  def main(args:Array[String]):Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)
    hiveContext.udf.register("salary_avg",UDFDemo)
    hiveContext.sql("select salary_avg(salary) from xtwy.worker" ).show()

  }
}

原始数据

(图略:xtwy.worker 表原始数据截图)

求薪水的平均值,可以看到是正确的

(图略:salary_avg 查询结果截图)

猜你喜欢

转载自blog.csdn.net/qq_21383435/article/details/80519068