SparkSQL之自定义函数UDF和UDAF

SparkSQL中有两种自定函数,在我们使用自带的函数时无法满足自己的需求时,可以使用自定义函数,SparkSQL中有两种自定义函数,一种是UDF,另一种是UDAF,和Hive 很类似,但是hive中还有UDTF,一进多出,但是sparkSQL中没有,这是因为spark中用 flatMap这个函数,可以实现和udtf相同的功能
UDF函数是针对的是一进一出
UDAF针对的是多进一出

udf很简单,只需要注册一下,然后写一个函数,就可以在sql查询中使用了

    df1.createTempView("user")
    //注册
    spark.udf.register("lengthStr",(str:String)=>str.length)//自定义函数
    //直接在sql中就可以使用啦
    val df2 = spark.sql("select lengthStr(name) from user")

udaf相对来说比较复杂一点,需要继承一个 UserDefinedAggregateFunction类,在重写其中的方法,自定义函数求平均值,详细的步骤在下面的代码中

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession, types}

object UDAFavg {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("avg").master("local").getOrCreate()
    val sc = spark.sparkContext
    val sqlContext = spark.sqlContext
    val files: RDD[String] = sc.textFile("D:\\read\\teacher.txt")
    val rowRDD: RDD[Row] = files.map(row => {
      val split = row.split(" ")
      Row(split(0), split(1),split(2).toLong)
    })
  /*  rowRDD.foreach(row =>{
      println(row.getString(0)+" "+row.getString(1)+row.get(2))
    })*/
    val structType = StructType(List(StructField("subject",StringType,true),StructField("tname",StringType,true),
      StructField("age",LongType,true)))
    val df1: DataFrame = spark.createDataFrame(rowRDD,structType)
    df1.createTempView("teacher")
    //注册函数, 自定义一个函数,实现求平均数
    spark.udf.register("TeacherAvg",new UDAFavg)
    //df1.show()
    spark.sql("select subject,TeacherAvg(age) as avgAGE from teacher group by subject ").show()

  }
}
//自定义UDAF函数
class UDAFavg extends UserDefinedAggregateFunction{
//输入数据类型,求平均值,所以数据类型是LongType(StructType中的类型)
  override def inputSchema: StructType = {
    StructType(List(StructField("age",LongType,true)))}
    //中间结果的类型,这里定义了两个中间的类型,因为在求平均值时,首先一个存总的和,一个计算个数,最后的结果是两者相除
  override def bufferSchema: StructType = {
    StructType(List(StructField("age",LongType),StructField("count",LongType)))}
    //输出返回类型
  override def dataType: DataType = {LongType}
  //是否数据同一性,一般都是true
  override def deterministic: Boolean = true
  //初始化定义两个中间值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
  	//类型要和上面定义的位置相对应
    buffer(0) = 0L  //初始化 总和
    buffer(1) = 0L  // 个数
     }
    //进行计算,
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {  //input是每次的输入Row类型
    buffer(1) = buffer.getAs[Long](1)+ 1  //个数 每次加1  
    buffer(0) = buffer.getAs[Long](0) + input.getLong(0) 
   // 把每个传的值进行累加
    }
    //有可能有多个分区,多个task ,总后把进行合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(1) = buffer1.getAs[Long](1)+ buffer2.getAs[Long](1)//多台机器中的count的值进行相加
    buffer1(0) = buffer1.getAs[Long](0)  + buffer2.getLong(0)
    }
    //返回的最终结果
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Long](0) / buffer.getAs[Long](1)
  }
}

猜你喜欢

转载自blog.csdn.net/Lu_Xiao_Yue/article/details/83958391