自定义UDAF(多对一)

package day01

import org.apache.spark.sql.{Row, types}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
  * 自定义一个聚合方法
  * 首先要定义一个类继承UserDefinedAggregateFunction
  * 重写8个方法
  *
  */
class GeometricMean extends  UserDefinedAggregateFunction{
  //UDAF与DataFrame列有关的输入样式,StructField的名字并没有特别要求,完全可以认为是两个内部结构的列名站位符
  //至于UDAF具体要操作DataFrame的那个列,取决于调用者,但前提是数据类型必须符合事先的设置,如这里的Double
  override def inputSchema: StructType = StructType(List(StructField("value",DoubleType)))
  //定义存储聚合运算时产生的中间数据结果的Schema
  override def bufferSchema: StructType = StructType(List(
   //参与运算的个数
    StructField("count",LongType),
    //参与运算的乘积
    StructField("product",DoubleType)
  ))
  //标明了UDAF函数的返回值类型
  override def dataType: DataType = DoubleType
  //用以标记针对给定的一组输入,UDAF是否总是生成相同的结果
  override def deterministic: Boolean = true
  //对聚合运算中间结果的初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
  //参与运算数字的个数
    buffer(0)=0L
    //参与相乘的值
    buffer(1)=1.0
  }
  //每处理一条数据都要执行update
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
   //局部计算
    buffer(0) =buffer.getAs[Long](0) +1
    buffer(1) =buffer.getAs[Double](1) * input.getAs[Double](0)
  }
  //负责合并两个聚合运算的buffer,再将其存储到MutableAggregationBuffer中
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
   //全局  累加    累乘
    buffer1(0)=buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1)=buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
  }
  //完成对聚合Buffer值得运算,得到最后的结果
  override def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
  }
}
object GeometricMean{
  def main(args: Array[String]): Unit = {
   // val r =Math.pow(1*2*3*4*5*6*7*8*9,1.toDouble/9)
   val r =Math.pow(3,1.toDouble/2)
    println(r)
  }
}
package day01

import java.lang

import org.apache.spark.sql.{Dataset, SparkSession}


object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
     //df.show()列叫id
    val df: Dataset[lang.Long] = spark.range(1,10)

    val gm = new GeometricMean
      //写sql需要注册视图
    // df.createTempView("v_num")
    spark.udf.register("gm",gm)
 //   spark.sql("SELECT gm(id) as gm from v_num").show()

    //不用视图来弄,直接使用算子
   // df.select(expr("gm(id) as GeometricMean")).show()
   // df.groupBy().agg(gm(col("id")).as("GeometricMean")).show
  }
}

SparkSQL的自定义函数
UDF 调用函数式输入一行,返回一个值, 1->1 substring
UDAF 调用函数时输入N行,返回一个值 N-> 1 count(*)

使用UDFs之前要先注册
spark.udf.register("ip2Long",(ip:String)=>{
    //返回Long类型
})

spark.udf.register("gn" new UserDefineAggregateFunction(){
    //重新八个方法
})

猜你喜欢

转载自blog.csdn.net/LJ2415/article/details/85016206