How to load a custom transformer in Spark 2.4

gmds :

I'm trying to create a custom transformer in Spark 2.4.0. Saving it works fine. However, when I try to load it, I get the following error:

java.lang.NoSuchMethodException: TestTransformer.<init>(java.lang.String)
  at java.lang.Class.getConstructor0(Class.java:3082)
  at java.lang.Class.getConstructor(Class.java:1825)
  at org.apache.spark.ml.util.DefaultParamsReader.load(ReadWrite.scala:496)
  at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:380)
  at TestTransformer$.load(<console>:40)
  ... 31 elided

This suggests that it can't find my transformer's constructor, which doesn't make sense to me, because TestTransformer clearly has a constructor that takes a String.

MCVE:

import org.apache.spark.sql.{Dataset, DataFrame}
import org.apache.spark.sql.types.{StructType}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}

class TestTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {

    def this() = this(Identifiable.randomUID("TestTransformer"))

    override def transform(df: Dataset[_]): DataFrame = {
        val columns = df.columns
        df.select(columns.head, columns.tail: _*)
    }

    override def transformSchema(schema: StructType): StructType = {
        schema
    }

    override def copy(extra: ParamMap): TestTransformer = defaultCopy[TestTransformer](extra)
}

object TestTransformer extends DefaultParamsReadable[TestTransformer] {

    override def load(path: String): TestTransformer = super.load(path)

}

val transformer = new TestTransformer("test")

transformer.write.overwrite().save("test_transformer")
TestTransformer.load("test_transformer")

Running this (I'm using a Jupyter notebook) produces the error above. I've also tried compiling it and running it as a .jar file, with the same result.

What puzzles me is that the equivalent PySpark code works fine:

from pyspark.sql import SparkSession, DataFrame
from pyspark.ml import Transformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

class TestTransformer(Transformer, DefaultParamsWritable, DefaultParamsReadable):

    def transform(self, df: DataFrame) -> DataFrame:
        return df

TestTransformer().save('test_transformer')
TestTransformer.load('test_transformer')

How can I make a custom Spark transformer that can be saved and loaded?

fonkap :

I can reproduce your problem in spark-shell.

To find the source of the problem, I looked into the DefaultParamsReadable and DefaultParamsReader sources and saw that they use Java reflection:

https://github.com/apache/spark/blob/v2.4.0/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

Lines 495-496:

val instance =
    cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]

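I think the Scala REPL and Java reflection aren't good friends. The REPL wraps the code you enter in generated wrapper classes (the $iw in the output below), so TestTransformer ends up as an inner class, and the JVM constructors of an inner class take the enclosing instance as an extra first parameter.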

If you run this snippet (after yours):

new TestTransformer().getClass.getConstructors

you'll get the following output:

res1: Array[java.lang.reflect.Constructor[_]] = Array(public TestTransformer($iw), public TestTransformer($iw,java.lang.String))

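It is true! There is no TestTransformer.<init>(java.lang.String): both constructors take the REPL wrapper ($iw) as their first argument, so getConstructor(classOf[String]) fails with exactly the NoSuchMethodException from the question.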
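To see the same effect without Spark, here is a minimal, stdlib-only sketch. Wrapper is a hypothetical stand-in for the REPL's $iw wrapper class, Inner plays the role of a transformer typed into the REPL, and TopLevel plays the role of one compiled normally. hasStringConstructor mirrors the getConstructor(classOf[String]) lookup from lines 495-496 above. All of these names are made up for illustration.

```scala
// Hypothetical stand-in for the REPL's generated wrapper class ($iw).
class Wrapper {
  val prefix = "repl-"

  // A class defined inside another class is an inner class. The compiler
  // adds the enclosing Wrapper instance as a hidden first constructor
  // parameter, giving the shape Inner(Wrapper, String), just like
  // TestTransformer($iw, String) in spark-shell.
  class Inner(val uid: String) {
    // Uses the enclosing instance, so the outer reference is clearly needed.
    def label: String = prefix + uid
  }
}

// A top-level class, like one compiled into a jar.
class TopLevel(val uid: String)

object ConstructorDemo {
  // Mirrors the reflective lookup in DefaultParamsReader.load:
  // is there a public constructor whose only parameter is a String?
  def hasStringConstructor(cls: Class[_]): Boolean =
    try {
      cls.getConstructor(classOf[String])
      true
    } catch {
      case _: NoSuchMethodException => false
    }

  // Human-readable public constructor signatures, e.g. "(Wrapper,String)".
  def constructorSignatures(cls: Class[_]): Seq[String] =
    cls.getConstructors.toSeq.map { c =>
      c.getParameterTypes.map(_.getSimpleName).mkString("(", ",", ")")
    }

  def main(args: Array[String]): Unit = {
    val inner = classOf[Wrapper#Inner]
    val top = classOf[TopLevel]
    println(s"Inner: ${constructorSignatures(inner).mkString(" ")}, String ctor: ${hasStringConstructor(inner)}")
    println(s"TopLevel: ${constructorSignatures(top).mkString(" ")}, String ctor: ${hasStringConstructor(top)}")
  }
}
```

When this is compiled as an ordinary source file and ConstructorDemo.main is run, Inner should report only a (Wrapper,String) constructor and false, while TopLevel should report (String) and true. That is the same shape as the spark-shell output above.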

I found two workarounds:

  1. Compiling your code into a jar with sbt and then loading it in spark-shell with :require worked for me. (You mentioned you tried a jar, but I don't know how you used it.)

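  2. Pasting the code into spark-shell with :paste -raw worked fine as well. As I understand it, -raw compiles the pasted code as an ordinary source file instead of wrapping it, which keeps the REPL from doing shenanigans to your classes. That also means the pasted block should contain only the class and its companion object; the save/load lines go in afterwards at the normal prompt. See: https://docs.scala-lang.org/overviews/repl/overview.html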
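For workaround 1, a minimal build.sbt could look like the sketch below. The project name is hypothetical, and the versions are assumptions chosen to match a Spark 2.4.0 build on Scala 2.11 (the Scala version is printed in spark-shell's startup banner); adjust them to your installation. The TestTransformer class and companion object from the question would go in a source file under src/main/scala/, without the val/save/load lines.

```scala
// build.sbt: minimal sketch, not a definitive setup.
// Hypothetical project name; versions assume Spark 2.4.0 built for Scala 2.11.
name := "test-transformer"

scalaVersion := "2.11.12"

// "provided": spark-shell already has Spark on its classpath,
// so the jar only needs to contain your own classes.
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.4.0" % "provided"
```

After sbt package, the jar is written under target/scala-2.11/. It can then be loaded in spark-shell with :require followed by the jar's path (or passed with --jars when starting the shell). Because the class is then compiled as an ordinary top-level class, it has the (String) constructor that DefaultParamsReader looks up.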

I'm not sure how to adapt either of these to Jupyter, but I hope this info is useful for you.

NOTE: I actually used spark-shell from Spark 2.4.1.
