大数据学习之路90-sparkSQL自定义聚合函数UDAF

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_37050372/article/details/82981177

什么是UDAF?就是输入N行得到一个结果,属于聚合类的。

接下来我们就写一个求几何平均数的一个自定义聚合函数的例子

我们从开头写起,先来看看需要进行计算的数如何产生:

package com.test.SparkSQL

import java.lang

import org.apache.spark.sql.{Dataset, SparkSession}

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
    val ds: Dataset[lang.Long] = spark.range(1,10)
    ds.show()
  }
}

生成结果:

接下来我们使用自定义聚合函数计算几何平均数:

package com.test.SparkSQL

import java.lang

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession, types}

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
    val ds: Dataset[lang.Long] = spark.range(1,10)
    //ds.show()
    ds.createTempView("v_num")
    val gm = new GeometriMean
    spark.udf.register("gm",gm)
    spark.sql("select gm(id) as gm from v_num").show()
  }
}

class GeometriMean extends UserDefinedAggregateFunction{
  //定义输入数据的类型
  override def inputSchema: StructType = StructType(List(StructField("value",DoubleType)))
  //定义存储聚合运算时产生的中间数据结果的类型
  override def bufferSchema: StructType = StructType(
    List(
      StructField("count",LongType),
      StructField("product",DoubleType)
    )
  )
  //表名了UDAF函数的返回值类型
  override def dataType: DataType = DoubleType
  //用以标记针对给定的一组输入,UDAF是否总是生成相同的结果
  override def deterministic: Boolean = true
  //对聚合运算中间结果的初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }
  //每处理一条数据都要执行update,相当于局部计算
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
       buffer(0) = buffer.getAs[Long](0)+1
       buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)
  }
  //负责合并两个聚合运算的buffer,再将其存储到MutableAggregationBuffer,相当于全局计算
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
  }
  //完成对聚合Buffer值的运算,得到最后的结果
  override def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(1),1.toDouble/buffer.getLong(0))
  }
}

运行结果:

猜你喜欢

转载自blog.csdn.net/qq_37050372/article/details/82981177
今日推荐