SparkSQL-自定函数

UDF

        一路输入,一路输出

练习需求:模拟获取字符串长度

        准备文件

def main(args: Array[String]): Unit = {
    //创建Session对象
    val spark = SparkSession
      .builder() //构建器
      .appName("sparkSQL") //序名称程
      .master("local[*]") //执行方式:本地
      .getOrCreate() //创建对象

    //读取数据
    val df: DataFrame = spark.read.json("file:///D:\\spark.test\\datas\\people.json")

    //方法1:自定义UDF并注册
    spark.udf.register("UDFlenth1",(x : String) => x.length)

    //方法2:提供一个函数 再注册 register[提供函数返回值类型,函数参数类型]
    spark.udf.register[Int,String]("UDFlenth2",lenths)
      //定义提供函数实现
    def lenths(x : String): Int ={
      x.length
    }

    //建立数据视图表
    df.createOrReplaceTempView("people")

    spark.sql("select UDFlenth1(name) from people").show()

    spark.stop()
  }

UDAF 

        多路输入,一路输出 

        类似于combineByKey,需要提供一个类继承UserDefinedAggregateFunction,实现抽象方法

练习需求:模拟avg()

object sparkSQL09 {
  def main(args: Array[String]): Unit = {
    //创建Session对象
    val spark = SparkSession
      .builder() //构建器
      .appName("sparkSQL") //序名称程
      .master("local[*]") //执行方式:本地
      .getOrCreate() //创建对象

    //读取数据
    val df: DataFrame = spark.read.json("file:///D:\\spark.test\\datas\\emo.json")
    //建立数据视图表
    df.createOrReplaceTempView("emp")
    //注册UDAF函数
    spark.udf.register("MyAvg",new MyUDAF)

    //使用
    spark.sql("select MyAvg(salary) from emp").show()


    spark.stop()
  }
}
class MyUDAF extends UserDefinedAggregateFunction{

  //输入数据的Schema信息
  override def inputSchema: StructType = StructType(
    List(StructField("salary",DoubleType,true))
  )
  //每一个分区中的共享变量 提供分区中聚合之后得到的结果集存储的位置
  override def bufferSchema: StructType =
    StructType(List(
      StructField("sum",DoubleType,true), //工资的总和
      StructField("count",DoubleType,true) //分区内工资累加的次数
    ))

  //返回值的数据类型,表示UDAF函数输出结果的输出类型
  override def dataType: DataType = DoubleType
  //如果有相同输入 是否有相同输出
  override def deterministic: Boolean = true //默认为true
  //对当前Buffer中属性进行初始化操作,对每个分区进行变量赋值操作
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //工资总和的赋值 sum
    buffer(0) = 0.0
    //工资累加次数的赋值 count
    buffer(1) = 0.0
  }
  //对分区内数据进行聚合操作
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if(!input.isNullAt(0)){ //判断只要薪水不是空的
      //计算一行薪水的工资值
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      //计算次数
      buffer(1) = buffer.getDouble(1) + 1

    }
  }
  //全局聚合 ,将分区内计算的数据再聚合在一起
  //buffer1 存的是最终全局聚合的数据值 buff2 是对应每个分区计算结果值
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //全局聚合总工资
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    //全局聚合总次数
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }
  //最终计算结果
  override def evaluate(buffer: Row): Double = {
    buffer.getDouble(0) / buffer.getDouble(1)
  }
}

猜你喜欢

转载自blog.csdn.net/dafsq/article/details/129638332