I have written Scala code in spark-shell to map one column of a dataframe to another. I am now attempting to convert it to Java but am having difficulties with the UDF I defined.
I am taking this data frame:
+------+-----+-----+
|acctId|vehId|count|
+------+-----+-----+
| 1| 777| 3|
| 2| 777| 1|
| 1| 666| 1|
| 1| 999| 3|
| 1| 888| 2|
| 3| 777| 4|
| 2| 999| 1|
| 3| 888| 2|
| 2| 888| 3|
+------+-----+-----+
And converting it to this:
+------+----------------------------------------+
|acctId|vehIdToCount |
+------+----------------------------------------+
|1 |[777 -> 3, 666 -> 1, 999 -> 3, 888 -> 2]|
|3 |[777 -> 4, 888 -> 2] |
|2 |[777 -> 1, 999 -> 1, 888 -> 3] |
+------+----------------------------------------+
I am doing this via the commands below. First, my UDF, which maps a sequence of rows to a map from the first column's values to the second's:
val listToMap = udf((input: Seq[Row]) => input.map(row => (row.getAs[Int](0), row.getAs[Long](1))).toMap)
Then I apply it in a double groupBy/aggregation:
val resultDF = testData.groupBy("acctId", "vehId")
  .agg(count("acctId").cast("long").as("count"))
  .groupBy("acctId")
  .agg(collect_list(struct("vehId", "count")).as("vehIdToCount"))
  .withColumn("vehIdToCount", listToMap($"vehIdToCount"))
My problem is in trying to write the listToMap UDF in Java. I am fairly new to both Scala and Java so I may just be missing something.
I was hoping I could do something as simple as:
UserDefinedFunction listToMap = udf(
    (Seq<Dataset<Row>> input) -> input.map(r -> (r.get("vehId"), r.get("count")))
);
But I cannot find a valid method to get each of these columns, even after looking fairly extensively through the documentation. I have also tried a plain SELECT, but that does not work either.
Any help is much appreciated. For reference, this is how I'm generating my test data in spark-shell:
val testData = Seq(
(1, 999),
(1, 999),
(2, 999),
(1, 888),
(2, 888),
(3, 888),
(2, 888),
(2, 888),
(1, 888),
(1, 777),
(1, 666),
(3, 888),
(1, 777),
(3, 777),
(2, 777),
(3, 777),
(3, 777),
(1, 999),
(3, 777),
(1, 777)
).toDF("acctId", "vehId")
I can't help you write the UDF, but I can show you how to avoid it using Spark's built-in map_from_entries function (available since Spark 2.4). UDFs should be a last resort, both to keep your codebase simple and because Spark cannot optimize them. The example below is in Scala but should be trivial to translate:
scala> val testData = Seq(
| (1, 999),
| (1, 999),
| (2, 999),
| (1, 888),
| (2, 888),
| (3, 888),
| (2, 888),
| (2, 888),
| (1, 888),
| (1, 777),
| (1, 666),
| (3, 888),
| (1, 777),
| (3, 777),
| (2, 777),
| (3, 777),
| (3, 777),
| (1, 999),
| (3, 777),
| (1, 777)
| ).toDF("acctId", "vehId")
testData: org.apache.spark.sql.DataFrame = [acctId: int, vehId: int]
scala>
scala> val withMap = testData.groupBy('acctId, 'vehId).
| count.
| select('acctId, struct('vehId, 'count).as("entries")).
| groupBy('acctId).
| agg(map_from_entries(collect_list('entries)).as("myMap"))
withMap: org.apache.spark.sql.DataFrame = [acctId: int, myMap: map<int,bigint>]
scala>
scala> withMap.show(false)
+------+----------------------------------------+
|acctId|myMap |
+------+----------------------------------------+
|1 |[777 -> 3, 666 -> 1, 999 -> 3, 888 -> 2]|
|3 |[777 -> 4, 888 -> 2] |
|2 |[777 -> 1, 999 -> 1, 888 -> 3] |
+------+----------------------------------------+
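As for the translation, here is a minimal Java sketch of the same pipeline. It assumes Spark 2.4+ and that testData is an equivalent Dataset&lt;Row&gt; with columns acctId and vehId; the variable name withMap is just illustrative:

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_list;
import static org.apache.spark.sql.functions.map_from_entries;
import static org.apache.spark.sql.functions.struct;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// testData: a Dataset<Row> with columns acctId (int) and vehId (int), as in the question
Dataset<Row> withMap = testData
    // one row per (acctId, vehId) pair; count() adds a "count" column
    .groupBy(col("acctId"), col("vehId"))
    .count()
    // pack each (vehId, count) pair into a struct so it can serve as a map entry
    .select(col("acctId"), struct(col("vehId"), col("count")).as("entries"))
    // gather all entries per account and build the map in a single expression
    .groupBy(col("acctId"))
    .agg(map_from_entries(collect_list(col("entries"))).as("myMap"));

withMap.show(false);

Because everything here is a built-in Column expression, Catalyst can optimize the whole job, which is exactly what the UDF version gives up.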