I have a table with 3 columns:
Table A:
+----+----+----------+
|col1|col2|row_number|
+----+----+----------+
| X| 1| 1|
| Y| 0| 2|
| Z| 2| 3|
| A| 1| 4|
| B| 0| 5|
| C| 0| 6|
| D| 2| 7|
| P| 1| 8|
| Q| 2| 9|
+----+----+----------+
I want to concatenate the strings in "col1" by grouping records based on the "col2" values. "col2" follows a pattern: a 1, followed by any number of 0s, followed by a 2. I want to group consecutive records whose "col2" values start with 1 and end with 2 (the order of the DataFrame must be maintained - you can use the row_number column for the order).
For example, the first 3 records can be grouped together because their "col2" values form "1-0-2". The next 4 records can be grouped together because their "col2" values form "1-0-0-2".
The concatenation itself can be done with "concat_ws" once the records are grouped. But any help on how to group these records based on the "1-0s-2" pattern?
Expected output:
+----------+
|output_col|
+----------+
| XYZ|
| ABCD|
| PQ|
+----------+
You can use the following code to create this sample data:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([StructField("col1", StringType())\
                    ,StructField("col2", IntegerType())\
                    ,StructField("row_number", IntegerType())])
data = [['X', 1, 1], ['Y', 0, 2], ['Z', 2, 3], ['A', 1, 4], ['B', 0, 5], ['C', 0, 6], ['D', 2, 7], ['P', 1, 8], ['Q', 2, 9]]
df = spark.createDataFrame(data, schema=schema)
df.show()
I would suggest you use window functions. First use a window ordered by row_number to get an incremental sum of col2. Because every group contributes exactly 1 + 0s + 2 = 3, the incremental sum hits a multiple of 3 precisely at the endpoint of each group. Replace those multiples of 3 with the lag over the same window, so each endpoint takes the value of its own group, giving you the desired partitions in incremental_sum. Now you can groupBy the incremental_sum column, collect_list the col1 values, and array_join (Spark 2.4+) the collected list to get your desired strings.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window().orderBy("row_number")

df.withColumn("incremental_sum", F.sum("col2").over(w))\
  .withColumn("lag", F.lag("incremental_sum").over(w))\
  .withColumn("incremental_sum",
              F.when(F.col("incremental_sum") % 3 == 0, F.col("lag"))
               .otherwise(F.col("incremental_sum")))\
  .groupBy("incremental_sum")\
  .agg(F.array_join(F.collect_list("col1"), "").alias("output_col"))\
  .drop("incremental_sum")\
  .show()
+----------+
|output_col|
+----------+
| XYZ|
| ABCD|
| PQ|
+----------+
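If it helps to see why the trick works without spinning up Spark, here is a plain-Python trace of the same logic on the sample data (a sketch only - it assumes, as in your data, that col2 always starts a group with 1, so the very first running sum is never a multiple of 3):

```python
from itertools import accumulate

col1 = ["X", "Y", "Z", "A", "B", "C", "D", "P", "Q"]
col2 = [1, 0, 2, 1, 0, 0, 2, 1, 2]

# Running (incremental) sum of col2 - each complete group contributes
# exactly 1 + 0*k + 2 = 3, so group endpoints land on multiples of 3.
sums = list(accumulate(col2))  # [1, 1, 3, 4, 4, 4, 6, 7, 9]

# Replace each multiple of 3 with its lag (the previous running sum),
# so every endpoint row takes the key of its own group.
keys = [sums[i - 1] if s % 3 == 0 else s for i, s in enumerate(sums)]
# keys == [1, 1, 1, 4, 4, 4, 4, 7, 7]

# Group consecutive rows by key and concatenate col1 (the groupBy +
# array_join(collect_list(...)) step in the Spark version).
groups = {}
for k, c in zip(keys, col1):
    groups.setdefault(k, []).append(c)
output = ["".join(v) for v in groups.values()]
print(output)  # ['XYZ', 'ABCD', 'PQ']
```

The intermediate keys list is exactly what the incremental_sum column looks like after the when/otherwise replacement in the Spark code.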