class pyspark.sql.GroupedData类是由groupby分组聚合之后产生的,groupby大家应该都很熟悉,广泛存在于sql和pandas中,下面说说其在pyspark中的用法。
agg(*exprs)
在pyspark.sql.DataFrame中讲过agg(),本节中的agg()是分组聚合下的用法。
聚合函数主要分为两种。
- 内置聚合函数,例如avg,max,min,sum,count
- 使用pyspark.sql.functions.pandas_udf()创建的组聚合pandas UDF函数
- 内置聚合函数和组聚合pandas UDF不能在对此函数的单个调用中混合使用。
>>> gdf = df.groupBy(df.name)
>>> sorted(gdf.agg({"*": "count"}).collect())
[Row(name='Alice', count(1)=1), Row(name='Bob', count(1)=1)]
>>> from pyspark.sql import functions as F
>>> sorted(gdf.agg(F.min(df.age)).collect())
[Row(name='Alice', min(age)=2), Row(name='Bob', min(age)=5)]
pandas UDF在下一节中会详细讲,这里暂且看看吧
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('int', PandasUDFType.GROUPED_AGG)
def min_udf(v):
return v.min()
sorted(gdf.agg(min_udf(df.age)).collect())
#[Row(name='Alice', min_udf(age)=2), Row(name='Bob', min_udf(age)=5)]
apply(udf)
使用pandas udf映射当前DataFrame的每个组,并将结果作为DataFrame返回。用户定义的函数应该使用pandas.DataFrame并返回另一个pandas.DataFrame。 对于每个组,所有列作为pandas.DataFrame一起传递给用户函数,返回的pandas.DataFrame组合为DataFrame。返回的pandas.DataFrame可以是任意长度,其模式必须与pandas udf的returnType匹配。
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> :pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").apply(normalize).show()
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
avg()、count()等
avg()、count()分别是计算平均数和计数,类似其他的还有max()、min()、sum()等,比较简单,这里就举个简例说明一下,其他的就不赘述了。
>>> df.groupBy().avg('age').collect()
[Row(avg(age)=3.5)]
>>> sorted(df.groupBy(df.age).count().collect())
[Row(age=2, count=1), Row(age=5, count=1)]
pivot(pivot_col, values=None)
熟悉pandas的人应该知道,pandas中有pivot()和pivot_table(),制作透视表的,pyspark中也有相应的函数pivot().
旋转当前DataFrame的列并执行指定的聚合。 有两个版本的pivot函数:一个需要调用者指定要透视的不同值列表,另一个不需要。 后者更简洁但效率更低,因为Spark需要首先在内部计算不同值的列表。
- 数据长这样。
>>>import pyspark.sql.functions as fn
>>>df=spark.createDataFrame([("month1","task1",50),("month1","task2",40),
("month2","task2",80),("month1","task1",45),
("month2","task3",25),("month3","task2",55),
("month2","task2",35)],["month","work","earning"])
>>>df.show()
+------+-----+-------+
| month| work|earning|
+------+-----+-------+
|month1|task1| 50|
|month1|task2| 40|
|month2|task2| 80|
|month1|task1| 45|
|month2|task3| 25|
|month3|task2| 55|
|month2|task2| 35|
+------+-----+-------+
- 按第一种方式聚合,指定透视列的不同值列表。
>>>df.groupBy(df.month).pivot("work",["task1","task2"])\
.agg(fn.sum(df.earning)).fillna(0).show()
+------+-----+-----+
| month|task1|task2|
+------+-----+-----+
|month3| 0| 55|
|month1| 95| 40|
|month2| 0| 115|
+------+-----+-----+
- 按第二种方式聚合,不指定透视列的不同值列表。
>>>df.groupBy(df.month).pivot("work")\
.agg(fn.mean(df.earning)).fillna(0).show()
+------+-----+-----+-----+
| month|task1|task2|task3|
+------+-----+-----+-----+
|month3| 0.0| 55.0| 0.0|
|month1| 47.5| 40.0| 0.0|
|month2| 0.0| 57.5| 25.0|
+------+-----+-----+-----+
到这里,pyspark.sql.GroupedData的基本API函数就讲完了,有兴趣的同学可以将pyspark中pivot()和pandas中的pivot()做个对比,你会发现二者的差异性。