Grouping records based on a pattern of row values using PySpark

Surya Murali :

I have a table with 3 columns:

Table A:

+----+----+----------+
|col1|col2|row_number|
+----+----+----------+
|   X|   1|         1|
|   Y|   0|         2|
|   Z|   2|         3|
|   A|   1|         4|
|   B|   0|         5|
|   C|   0|         6|
|   D|   2|         7|
|   P|   1|         8|
|   Q|   2|         9|
+----+----+----------+

I want to concatenate the strings in "col1" by grouping records based on the "col2" values. "col2" follows a pattern of a 1, followed by any number of 0s, followed by a 2. I want to group the records whose "col2" values start with 1 and end with 2. (The order of the DataFrame must be maintained; you can use the row_number column for the order.)

For example, the first 3 records can be grouped together because their "col2" values are "1-0-2". The next 4 records can be grouped together because their "col2" values are "1-0-0-2", and the last 2 records because theirs are "1-2".

The concatenating part can be done using "concat_ws" after I group these records. But how can I group these records based on the "1-0s-2" pattern?

Expected output:

+----------+
|output_col|
+----------+
|       XYZ|
|      ABCD|
|        PQ|
+----------+
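To pin down the grouping rule before translating it into Spark, here is a minimal plain-Python sketch that reproduces the expected output on the sample data. It is only a reference for checking results on small inputs, not a Spark solution; the helper name `group_by_pattern` is hypothetical, and it assumes the data strictly follows the 1-0s-2 pattern.

```python
# Hypothetical helper: plain-Python reference for the 1-0s-2 grouping rule,
# handy for checking a Spark result on small samples.
def group_by_pattern(rows):
    """rows: (col1, col2, row_number) records.

    Assumes every group is exactly one 1, any number of 0s, then one 2.
    """
    groups = []
    current = []
    # Process rows in row_number order so the original order is maintained
    for col1, col2, _ in sorted(rows, key=lambda r: r[2]):
        if col2 == 1:
            current = [col1]       # a 1 opens a new group
        else:
            current.append(col1)   # 0s and the closing 2 extend the open group
        if col2 == 2:
            groups.append("".join(current))  # a 2 closes the group
            current = []
    return groups


data = [['X', 1, 1], ['Y', 0, 2], ['Z', 2, 3], ['A', 1, 4], ['B', 0, 5],
        ['C', 0, 6], ['D', 2, 7], ['P', 1, 8], ['Q', 2, 9]]
print(group_by_pattern(data))  # ['XYZ', 'ABCD', 'PQ']
```

Because the rows are sorted by row_number first, the result does not depend on the order the records arrive in.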

You can use the following code to create this sample data:

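from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()  # reuses the active session if there is one

schema = StructType([StructField("col1", StringType()),
                     StructField("col2", IntegerType()),
                     StructField("row_number", IntegerType())])

data = [['X', 1, 1], ['Y', 0, 2], ['Z', 2, 3], ['A', 1, 4], ['B', 0, 5],
        ['C', 0, 6], ['D', 2, 7], ['P', 1, 8], ['Q', 2, 9]]

df = spark.createDataFrame(data, schema=schema)
df.show()
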
Mohammad Murtaza Hashmi :

I would suggest using window functions. First, use a window ordered by row_number to compute a running (incremental) sum of col2. Because every 1-0s-2 group adds exactly 3 to that sum, the running sum is a multiple of 3 exactly at the last row of each group. Replace those values with the lag of the running sum over the same window, so that every row of a group carries the same incremental_sum. Then you can groupBy the incremental_sum column, collect_list col1, and use array_join (available since Spark 2.4) on the collected list to get the desired strings.
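To see why the multiples of 3 mark the end of each group, here is the same arithmetic traced on the sample col2 values in plain Python. It is only an illustration: `itertools.accumulate` stands in for the windowed `F.sum`, and a shifted list stands in for `F.lag`.

```python
from itertools import accumulate

col2 = [1, 0, 2, 1, 0, 0, 2, 1, 2]  # sample col2 values, ordered by row_number

# Running sum, playing the role of F.sum("col2").over(w)
incremental_sum = list(accumulate(col2))  # [1, 1, 3, 4, 4, 4, 6, 7, 9]

# Previous row's running sum, playing the role of F.lag("incremental_sum").over(w)
lag = [None] + incremental_sum[:-1]

# Each 1-0s-2 group adds exactly 3, so a multiple of 3 marks a group's last row;
# replacing it with the lag gives the whole group the value it started with
group_key = [prev if s % 3 == 0 else s for s, prev in zip(incremental_sum, lag)]
print(group_key)  # [1, 1, 1, 4, 4, 4, 4, 7, 7]
```

The first row always has a running sum of 1 (it starts a group), so its `None` lag is never used.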

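from pyspark.sql import functions as F
from pyspark.sql.window import Window

# No partitionBy: Spark moves all rows to a single partition to apply this global ordering
w = Window.orderBy("row_number")
# Running sum of col2; rows where it is a multiple of 3 (group ends) take the previous row's sum instead.
# collect_list gives no order guarantee after groupBy, so (row_number, col1) structs are sorted before joining.
df.withColumn("incremental_sum", F.sum("col2").over(w))\
  .withColumn("lag", F.lag("incremental_sum").over(w))\
  .withColumn("incremental_sum", F.when(F.col("incremental_sum") % 3 == 0, F.col("lag")).otherwise(F.col("incremental_sum")))\
  .groupBy("incremental_sum")\
  .agg(F.array_join(F.sort_array(F.collect_list(F.struct("row_number", "col1"))).getField("col1"), "").alias("output_col"))\
  .orderBy("incremental_sum").drop("incremental_sum").show()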
+----------+
|output_col|
+----------+
|       XYZ|
|      ABCD|
|        PQ|
+----------+
