spark 窗口函数(Window)实战详解

项目github地址:bitcarmanlee easy-algorithm-interview-and-practice
经常有同学私信或留言询问相关问题,V号bitcarmanlee。github上star的同学,在我能力与时间允许范围内,尽可能帮大家解答相关问题,一起进步。

1.为什么需要窗口函数

在1.4以前,Spark SQL支持两种类型的函数用来计算单个的返回值。第一种是内置函数或者UDF函数,他们将单个行中的值作为输入,并且他们为每个输入行生成单个返回值。另外一种是聚合函数,典型的是SUM, MAX, AVG这种,是对一组行数据进行操作,并且为每个组计算一个返回值。

上面提到的两种函数,实际当中使用非常广泛,但是仍然存在大量无法单独使用这些类型的函数来表达的操作。最常见的一种场景就是,很多时候需要对一组行进行操作,而仍然为每个输入行返回一个值,上面的两种方法就无能为力。例如对于计算移动平均值,计算累计和或访问出现在当前行之前的行的值等,就显得非常困难。幸运的是,在1.4以后的版本,Spark SQL就提供了窗口函数来弥补上面的不足。

窗口函数的核心是“Frame”,或者我们直接称呼其为帧,帧就是一系列的多行数据,或者说许多分组。然后我们可以基本这些分组来满足上面普通函数无法完成的功能。为了看清楚其具体的应用,我们直接看例子。Talk is cheap, Show me the code.

2.构造数据集

为了方便测试,我们首先构造数据集

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._

  def test() = {
    val sparkConf = new SparkConf().setMaster("local[2]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    val data = Array(("lili", "ml", 90),
      ("lucy", "ml", 85),
      ("cherry", "ml", 80),
      ("terry", "ml", 85),
      ("tracy", "cs", 82),
      ("tony", "cs", 86),
      ("tom", "cs", 75))

    val schemas = Seq("name", "subject", "score")
    val df = spark.createDataFrame(data).toDF(schemas: _*)

    df.show()
 }

将上面的test方法本地run起来以后,输出如下

+------+-------+-----+
|  name|subject|score|
+------+-------+-----+
|  lili|     ml|   90|
|  lucy|     ml|   85|
|cherry|     ml|   80|
| terry|     ml|   85|
| tracy|     cs|   82|
|  tony|     cs|   86|
|   tom|     cs|   75|
+------+-------+-----+

数据构造完毕

3.分组查看排名

经常用到的一个场景是:需要查看每个专业学生的排名,这就是一个典型的分组问题,就是窗口函数大显身手的时候。

一个窗口需要定义三个部分:

1.分组问题,如何将行分组?在选取窗口数据时,只对组内数据生效
2.排序问题,按何种方式进行排序?选取窗口数据时,会首先按指定方式排序
3.帧(frame)选取,以当前行为基准,如何选取周围行?

对照上面的三个部分,窗口函数的语法一般为:

window_func(args) OVER ( [PARTITION BY col_name, col_name, ...] [ORDER BY col_name, col_name, ...] [ROWS | RANGE BETWEEN (CURRENT ROW | (UNBOUNDED |[num]) PRECEDING) AND (CURRENT ROW | ( UNBOUNDED | [num]) FOLLOWING)] )

其中
window_func就是窗口函数
over表示这是个窗口函数
partition by对应的就是分组,即按照什么列分组
order by对应的是排序,按什么列排序
rows则对应的帧选取。

spark中的window_func包括下面三类:
1.排名函数(ranking function) 包括rank,dense_rank, row_number,percent_rank, ntile等,后面我们结合例子来看。
2.分析函数 (analytic functions) 包括cume_dist,lag等。
3.聚合函数(aggregate functions),就是我们常用的max, min, sum, avg等。

回到上面的需求,查看每个专业学生的排名

  def test() = {
    val sparkConf = new SparkConf().setMaster("local[2]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    val sqlContext = spark.sqlContext


    val data = Array(("lili", "ml", 90),
      ("lucy", "ml", 85),
      ("cherry", "ml", 80),
      ("terry", "ml", 85),
      ("tracy", "cs", 82),
      ("tony", "cs", 86),
      ("tom", "cs", 75))

    val schemas = Seq("name", "subject", "score")
    val df = spark.createDataFrame(data).toDF(schemas: _*)
    df.createOrReplaceTempView("person_subject_score")

    val sqltext = "select name, subject, score, rank() over (partition by subject order by score desc) as rank from person_subject_score";
    val ret = sqlContext.sql(sqltext)
    ret.show()
  }

上面的代码run起来,结果如下

+------+-------+-----+----+
|  name|subject|score|rank|
+------+-------+-----+----+
|  tony|     cs|   86|   1|
| tracy|     cs|   82|   2|
|   tom|     cs|   75|   3|
|  lili|     ml|   90|   1|
|  lucy|     ml|   85|   2|
| terry|     ml|   85|   2|
|cherry|     ml|   80|   4|
+------+-------+-----+----+

重点看下窗口部分:

rank() over (partition by subject order by score desc) as rank

rank()函数表示取每行在分组中的排名,partition by subject表示按subject分组,order by score desc表示按分数排序并且逆序,这样就可以得到每个学生在本专业中的排名!

row_number, dense_rank也都是排序有关的窗口函数,下面我们通过实例看看他们的区别:

  def test() = {
    val sparkConf = new SparkConf().setMaster("local[2]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    val sqlContext = spark.sqlContext


    val data = Array(("lili", "ml", 90),
      ("lucy", "ml", 85),
      ("cherry", "ml", 80),
      ("terry", "ml", 85),
      ("tracy", "cs", 82),
      ("tony", "cs", 86),
      ("tom", "cs", 75))

    val schemas = Seq("name", "subject", "score")
    val df = spark.createDataFrame(data).toDF(schemas: _*)
    df.createOrReplaceTempView("person_subject_score")

    val sqltext = "select name, subject, score, rank() over (partition by subject order by score desc) as rank from person_subject_score";
    val ret = sqlContext.sql(sqltext)
    ret.show()

    val sqltext2 = "select name, subject, score, row_number() over (partition by subject order by score desc) as row_number from person_subject_score";
    val ret2 = sqlContext.sql(sqltext2)
    ret2.show()

    val sqltext3 = "select name, subject, score, dense_rank() over (partition by subject order by score desc) as dense_rank from person_subject_score";
    val ret3 = sqlContext.sql(sqltext3)
    ret3.show()
  }
+------+-------+-----+----+
|  name|subject|score|rank|
+------+-------+-----+----+
|  tony|     cs|   86|   1|
| tracy|     cs|   82|   2|
|   tom|     cs|   75|   3|
|  lili|     ml|   90|   1|
|  lucy|     ml|   85|   2|
| terry|     ml|   85|   2|
|cherry|     ml|   80|   4|
+------+-------+-----+----+

+------+-------+-----+----------+
|  name|subject|score|row_number|
+------+-------+-----+----------+
|  tony|     cs|   86|         1|
| tracy|     cs|   82|         2|
|   tom|     cs|   75|         3|
|  lili|     ml|   90|         1|
|  lucy|     ml|   85|         2|
| terry|     ml|   85|         3|
|cherry|     ml|   80|         4|
+------+-------+-----+----------+

+------+-------+-----+----------+
|  name|subject|score|dense_rank|
+------+-------+-----+----------+
|  tony|     cs|   86|         1|
| tracy|     cs|   82|         2|
|   tom|     cs|   75|         3|
|  lili|     ml|   90|         1|
|  lucy|     ml|   85|         2|
| terry|     ml|   85|         2|
|cherry|     ml|   80|         3|
+------+-------+-----+----------+

通过上面的例子不难看出这三者的区别:
rank生成不连续的序号,上面的例子是1,2,2,4这种
dense_rank生成连续的序号,上面的例子是1,2,2,3这种
row_number顾名思义,生成的是行号,上面的例子是1,2,3,4这种。
不用去死抠函数的定义,看上面的例子就明白了!

4.查看分位数

下面再看个实例,我们想查看某个人在该专业的分位数,该怎么办?
这个时候就可以用到cume_dist函数了。
该函数的计算方式为:组内小于等于当前行值的行数/组内总行数

还是看代码

    val sqltext5 = "select name, subject, score, cume_dist() over (partition by subject order by score desc) as cumedist from person_subject_score";
    val ret5 = sqlContext.sql(sqltext5)
    ret5.show()

结合前面的数据初始化代码与上面的sql逻辑,最后的结果如下:

+------+-------+-----+------------------+
|  name|subject|score|          cumedist|
+------+-------+-----+------------------+
|  tony|     cs|   86|0.3333333333333333|
| tracy|     cs|   82|0.6666666666666666|
|   tom|     cs|   75|               1.0|
|  lili|     ml|   90|              0.25|
|  lucy|     ml|   85|              0.75|
| terry|     ml|   85|              0.75|
|cherry|     ml|   80|               1.0|
+------+-------+-----+------------------+

可以看到完美满足上面的需求。

5.使用DataFrame的API完成窗口查询

上面的例子使用的是SqlContext的API,在DataFrame中,也有对应的API可以完成查询,具体方式也很简单,使用DataFrame API在支持的函数调用over()方法即可,例如rank().over(…)

拿前面的需求为例,如果我们想查看学生在专业的排名,使用DataFrame的API如下:

  def test() = {
    val sparkConf = new SparkConf().setMaster("local[2]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    val data = Array(("lili", "ml", 90),
      ("lucy", "ml", 85),
      ("cherry", "ml", 80),
      ("terry", "ml", 85),
      ("tracy", "cs", 82),
      ("tony", "cs", 86),
      ("tom", "cs", 75))

    val schemas = Seq("name", "subject", "score")
    val df = spark.createDataFrame(data).toDF(schemas: _*)
    df.createOrReplaceTempView("person_subject_score")

    val window = Window.partitionBy("subject").orderBy(col("score").desc)
    val df2 = df.withColumn("rank", rank().over(window))
    df2.show()
  }

输出结果如下:

+------+-------+-----+----+
|  name|subject|score|rank|
+------+-------+-----+----+
|  tony|     cs|   86|   1|
| tracy|     cs|   82|   2|
|   tom|     cs|   75|   3|
|  lili|     ml|   90|   1|
|  lucy|     ml|   85|   2|
| terry|     ml|   85|   2|
|cherry|     ml|   80|   4|
+------+-------+-----+----+

猜你喜欢

转载自blog.csdn.net/bitcarmanlee/article/details/113617901