1. 简单的 UDF:注册一个字符串长度函数 strlen
package com.llcc.sparkSql.MyTimeSort
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
object UDFDemo1 {
  /**
   * Registers a simple UDF `strlen` with the HiveContext and applies it to
   * the `category` column of `xtwy.worker`.
   *
   * `strlen` is NULL-safe: Hive NULLs arrive here as Java `null`, for which
   * it returns 0 instead of throwing a NullPointerException.
   *
   * @param args unused command-line arguments
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    try {
      val hiveContext = new HiveContext(sc)
      hiveContext.udf.register("strlen", (str: String) => {
        if (str != null) {
          str.length()
        } else {
          0
        }
      })
      hiveContext.sql("select strlen(category) from xtwy.worker").show()
    } finally {
      // Fix: the original never stopped the SparkContext, leaking the
      // application's cluster resources on every run.
      sc.stop()
    }
  }
}
2. 自定义聚合函数(UDAF):继承 UserDefinedAggregateFunction,实现对薪水求平均值
package com.llcc.sparkSql.MyTimeSort
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._
object UDFDemo extends UserDefinedAggregateFunction {

  /**
   * Input schema: a single nullable Double column — the `salary` field we
   * average over. `StructType` wants a Seq of fields, hence the `:: Nil`.
   */
  override def inputSchema: StructType = StructType(
    StructField("salary", DoubleType, true) :: Nil
  )

  /**
   * Aggregation buffer schema: the running sum of salaries (`total`)
   * and the number of non-NULL rows folded in so far (`count`).
   */
  override def bufferSchema: StructType = StructType(
    StructField("total", DoubleType, true) ::
      StructField("count", IntegerType, true) :: Nil
  )

  /** The aggregate result (the average) is a Double. */
  override def dataType: DataType = DoubleType

  /** The same input rows always produce the same result. */
  override def deterministic: Boolean = true

  /**
   * Initializes a partition's buffer to (sum = 0.0, count = 0).
   *
   * @param buffer the mutable aggregation buffer to reset
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0)
  }

  /**
   * Folds one input row into the buffer.
   *
   * Fix: the column is declared nullable, but the original read
   * `input.getDouble(0)` unconditionally. A NULL salary is unboxed to 0.0
   * while the count is still incremented, silently skewing the average.
   * NULL rows are now skipped entirely, matching SQL AVG semantics.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
      buffer.update(1, buffer.getInt(1) + 1)
    }
  }

  /**
   * Merges two partial buffers (e.g. from different partitions) by adding
   * their sums and counts.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
  }

  /**
   * Produces the final average.
   *
   * Fix: the original divided unconditionally, yielding NaN when no
   * non-NULL rows were aggregated (0.0 / 0). Return SQL NULL instead,
   * matching the behavior of the built-in AVG.
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getInt(1)
    if (count == 0) null else buffer.getDouble(0) / count
  }

  /**
   * Registers this UDAF as `salary_avg` and runs it over `xtwy.worker`.
   *
   * @param args unused command-line arguments
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    try {
      val hiveContext = new HiveContext(sc)
      hiveContext.udf.register("salary_avg", UDFDemo)
      hiveContext.sql("select salary_avg(salary) from xtwy.worker").show()
    } finally {
      // Fix: stop the SparkContext so cluster resources are released.
      sc.stop()
    }
  }
}
对上面的原始数据运行该程序,求薪水的平均值,可以看到输出结果是正确的。