How to convert the Rows in a partition

BdEngineer :

I have a scenario in Spark: I have to partition a DataFrame, and the result should be processed one partition at a time.

List<String> data = Arrays.asList("con_dist_1", "con_dist_2", 
        "con_dist_3", "con_dist_4", "con_dist_5",
        "con_dist_6");
Dataset<Row> codes = sparkSession.createDataset(data, Encoders.STRING());
Dataset<Row> partitioned_codes = codes.repartition(col("codes"));

// I need to partition it due to a functional requirement
partitioned_codes.foreachPartition(itr -> {
    if (itr.hasNext()) {
        Row inrow = itr.next();
        System.out.println("inrow.length : " + inrow.length());
        System.out.println(inrow.toString());
        List<Object> objs = inrow.getList(0);
    }
});

I am getting this error:

Caused by: java.lang.ClassCastException: java.lang.String cannot be cast to scala.collection.Seq
    at org.apache.spark.sql.Row$class.getSeq(Row.scala:283)
    at org.apache.spark.sql.catalyst.expressions.GenericRow.getSeq(rows.scala:166)
    at org.apache.spark.sql.Row$class.getList(Row.scala:291)
    at org.apache.spark.sql.catalyst.expressions.GenericRow.getList(rows.scala:166)

Question: How should foreachPartition be handled here? Each itr holds a group of Rows, so how do I get all of those rows out of itr?

Test 1:

inrow.length: 0
[]
inrow.length: 0
[]
2020-03-02 05:22:14,179 [Executor task launch worker for task 615] ERROR org.apache.spark.executor.Executor - Exception in task 110.0 in stage 21.0 (TID 615)
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.String
    at org.apache.spark.sql.Row$class.getString(Row.scala:255)
    at org.apache.spark.sql.catalyst.expressions.GenericRow.getString(rows.scala:166)

Output-1 :

inrow.length: 0
[]
inrow.length: 0
[]
inrow.length: 1
[con_dist_1]
inrow.length: 1
[con_dist_2]
inrow.length: 1
[con_dist_5]
inrow.length: 1
[con_dist_6]
inrow.length: 1
[con_dist_4]
inrow.length: 1
[con_dist_3]

ernest_k :

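All the rows of the partition are in itr, so calling itr.next() once only gives you the first row. (The ClassCastException is a separate problem: column 0 holds a String, not an array, so inrow.getList(0) fails; read it with inrow.getString(0) instead.) If you need all the rows, you can use a while loop, or you can convert the iterator to a list with something like this (I suspect this is what you wanted to get to):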

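partitioned_codes.foreachPartition(itr -> {
    // itr iterates over every Row in this partition; wrapping it in a
    // one-shot Iterable lets StreamSupport turn it into a Stream
    Iterable<Row> rowIt = () -> itr;
    List<String> objs = StreamSupport.stream(rowIt.spliterator(), false)
            .map(row -> row.getString(0))   // column 0 is a String, so use getString
            .collect(Collectors.toList());

    System.out.println("inrow.length: " + objs.size());
    System.out.println(objs);
});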
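The while-loop alternative can be sketched in plain Java without a Spark dependency. This is a minimal sketch that assumes a java.util.Iterator<String> stands in for the Iterator<Row> that foreachPartition passes in; inside the real callback the loop body would call row.getString(0) on each Row. The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class DrainPartition {
    // Calling next() once returns only the first element of the partition.
    static String firstOnly(Iterator<String> itr) {
        return itr.hasNext() ? itr.next() : null;
    }

    // Looping until hasNext() is false visits every element of the partition.
    // (With a real Iterator<Row>, this is where row.getString(0) would go.)
    static List<String> drain(Iterator<String> itr) {
        List<String> out = new ArrayList<>();
        while (itr.hasNext()) {
            out.add(itr.next());
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical partition contents standing in for a partition's Rows
        List<String> partition = Arrays.asList("con_dist_1", "con_dist_2", "con_dist_3");
        System.out.println(firstOnly(partition.iterator()));
        System.out.println(drain(partition.iterator()));
    }
}
```

Each call gets a fresh iterator because an iterator can only be consumed once; the same holds for the `() -> itr` Iterable in the answer, which can be streamed only one time.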

The example code you posted didn't compile for me, so here's the version I tested with:

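import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Assumes sparkSession is a SparkSession and sc is a JavaSparkContext,
// e.g. new JavaSparkContext(sparkSession.sparkContext())
List<String> data = Arrays.asList("con_dist_1", "con_dist_2", 
        "con_dist_3", "con_dist_4", "con_dist_5",
        "con_dist_6");
StructType struct = new StructType()
        .add(DataTypes.createStructField("codes", DataTypes.StringType, true));
Dataset<Row> codes = sparkSession.createDataFrame(sc.parallelize(data, 2)
                        .map(s -> RowFactory.create(s)), struct);
Dataset<Row> partitioned_codes = codes.repartition(org.apache.spark.sql.functions.col("codes"));

partitioned_codes.foreachPartition(itr -> {
    Iterable<Row> rowIt = () -> itr;
    List<String> objs = StreamSupport.stream(rowIt.spliterator(), false)
            .map(row -> row.getString(0))
            .collect(Collectors.toList());

    System.out.println("inrow.length: " + objs.size());
    System.out.println(objs);
});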
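The empty entries in Output-1 (`inrow.length: 0` followed by `[]`) come from the partition count: repartition(col("codes")) without an explicit number hash-partitions into spark.sql.shuffle.partitions partitions (200 by default), and foreachPartition runs the callback for every partition, including the ones that received no rows. The plain Java sketch below mimics that bucketing under a loud assumption: it uses String.hashCode() as a stand-in for the Murmur3 hash Spark actually applies, so which bucket each code lands in will differ from Spark's. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class HashPartitionSketch {
    // Put each value into bucket hash(value) mod numPartitions.
    // Stand-in only: Spark uses a Murmur3 hash, not String.hashCode().
    static List<List<String>> hashPartition(List<String> values, int numPartitions) {
        List<List<String>> partitions = new ArrayList<>();
        for (int i = 0; i < numPartitions; i++) {
            partitions.add(new ArrayList<>());
        }
        for (String v : values) {
            int p = Math.floorMod(v.hashCode(), numPartitions); // never negative
            partitions.get(p).add(v);
        }
        return partitions;
    }

    // Number of partitions that received no values at all.
    static long countEmpty(List<List<String>> partitions) {
        return partitions.stream().filter(List::isEmpty).count();
    }

    public static void main(String[] args) {
        List<String> data = Arrays.asList("con_dist_1", "con_dist_2", "con_dist_3",
                "con_dist_4", "con_dist_5", "con_dist_6");
        // 200 mirrors the default value of spark.sql.shuffle.partitions
        List<List<String>> partitions = hashPartition(data, 200);

        // Like foreachPartition, every partition is visited; only non-empty ones are printed here
        for (List<String> part : partitions) {
            if (!part.isEmpty()) {
                System.out.println("inrow.length: " + part.size() + " " + part);
            }
        }
        System.out.println("empty partitions: " + countEmpty(partitions));
    }
}
```

With six values and 200 buckets, at least 194 buckets are necessarily empty. If fewer, denser partitions are wanted, Dataset also has a repartition(int numPartitions, Column... partitionExprs) overload, e.g. codes.repartition(6, col("codes")), although hashing can still place two codes in the same partition and leave others empty.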
