SparkSQL 用户自定义函数(UDF、UDAF、开窗)

UDF函数

通过spark.udf.register("funcName", func) 来进行注册

使用:select funcName(name) from people 来直接使用

UDAF函数

弱类型

需要继承UserDefineAggregateFunction并实现相关方法

使用:同样是注册一个udf函数

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/*
{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}
求平均工资
 */

class AverageSal extends UserDefinedAggregateFunction{
  // 输入数据
  override def inputSchema: StructType = StructType(StructField("salary", LongType) :: Nil)

  // 每一个分区中的 共享变量
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", IntegerType) :: Nil)

  // 表示UDAF的输出类型
  override def dataType: DataType = DoubleType

  // 表示如果有相同的输入是否存在相同的输出,如果是则true
  override def deterministic: Boolean = true

  // 初始化每个分区中的 共享变量
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L  // 就是sum
    buffer(1) = 0   // 就是count
  }

  // 每一个分区中的每一条数据  聚合的时候需要调用该方法
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // 获取这一行中的工资,然后将工资加入到sum    buffer(0) = buffer.getLong(0) + input.getLong(0)
    // 将工资的个数加1
    buffer(1) = buffer.getInt(1) + 1
  }

  // 将每一个分区的输出合并,形成最后的数据
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 合并总的工资
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // 合并总的工资个数
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // 给出计算结果
  override def evaluate(buffer: Row): Any = {
    // 取出总的工资 / 总工资个数
    buffer.getLong(0).toDouble / buffer.getInt(1)
  }
}

object AverageSal {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("udaf").setMaster("local[*]")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .getOrCreate()
    val employee = spark.read.json("employee.json")

    employee.createOrReplaceTempView("employee")

    spark.udf.register("average", new AverageSal)

    spark.sql("select average(salary) from employee").show()

    spark.stop()
  }
}
强类型
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

case class Employee(name: String, salary: Long)

case class Aver(var sum: Long, var count: Int)

class Average extends Aggregator[Employee, Aver, Double] {

  // 初始化方法 初始化每一个分区中的 共享变量
  override def zero: Aver = Aver(0L, 0)

  // 每一个分区中的每一条数据聚合的时候需要调用该方法
  override def reduce(b: Aver, a: Employee): Aver = {
    b.sum = b.sum + a.salary
    b.count = b.count + 1
    b
  }

  // 将每一个分区的输出 合并 形成最后的数据
  override def merge(b1: Aver, b2: Aver): Aver = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // 给出计算结果
  override def finish(reduction: Aver): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // 主要用于对共享变量进行编码
  override def bufferEncoder: Encoder[Aver] = Encoders.product

  // 主要用于将输出进行编码
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

}

object Average{

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf().setAppName("udaf").setMaster("local[*]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    import spark.implicits._

    val employee = spark.read.json("D:\\JetBrains\\workspace\\sparkcore\\sparksql\\src\\main\\resources\\employee.json").as[Employee]

    val aver = new Average().toColumn.name("average")

    employee.select(aver).show()

    spark.stop()
  }

}

开窗函数

rank()跳跃排序,有两个第二名时后边跟着的是第四名
dense_rank() 连续排序,有两个第二名时仍然跟着第三名
over()开窗函数:
       在使用聚合函数后,会将多行变成一行,而开窗函数是将一行变成多行;
       并且在使用聚合函数后,如果要显示其他的列必须将列加入到group by中,
       而使用开窗函数后,可以不使用group by,直接将所有信息显示出来。
        开窗函数适用于在每一行的最后一列添加聚合函数的结果。
常用开窗函数:
   1.为每条数据显示聚合信息.(聚合函数() over())
   2.为每条数据提供分组的聚合函数结果(聚合函数() over(partition by 字段) as 别名) 
         --按照字段分组,分组后进行计算
   3.与排名函数一起使用(row number() over(order by 字段) as 别名)
常用分析函数:(最常用的应该是1.2.3 的排序)
   1、row_number() over(partition by ... order by ...)
   2、rank() over(partition by ... order by ...)
   3、dense_rank() over(partition by ... order by ...)
   4、count() over(partition by ... order by ...)
   5、max() over(partition by ... order by ...)
   6、min() over(partition by ... order by ...)
   7、sum() over(partition by ... order by ...)
   8、avg() over(partition by ... order by ...)
   9、first_value() over(partition by ... order by ...)
   10、last_value() over(partition by ... order by ...)
   11、lag() over(partition by ... order by ...)
   12、lead() over(partition by ... order by ...)
lag 和lead 可以 获取结果集中,按一定排序所排列的当前行的上下相邻若干offset 的某个行的某个列(不用结果集的自关联);
lag ,lead 分别是向前,向后;
lag 和lead 有三个参数,第一个参数是列名,第二个参数是偏移的offset,第三个参数是 超出记录窗口时的默认值

猜你喜欢

转载自blog.csdn.net/liangzelei/article/details/80608302