自定义Estimator机器学习类

由于没有将类写到包org.apache.spark.ml.feature里,所以很多spark源码里的方法不可以直接调用。如spark2.3以下就不可以直接继承sharedParmas里面的特质。

import org.apache.spark.ml.util._
import org.apache.spark.ml.param._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.Model
import org.apache.spark.sql.{ DataFrame, Dataset }
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkException

trait WOEBase extends Params {
    //从spark2.3开始,可以直接继承sharedParmas里面的相关特质
    final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name")
    final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")
    final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
    final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
    final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")
    def getLabelCol() = $(labelCol)
    def getInputCol() = $(inputCol)
    def getOutputCol() = $(outputCol)
    def getInputCols() = $(inputCols)
    def getOutputCols() = $(outputCols)

    final val delta: DoubleParam = new DoubleParam(this, "delta", "防止出现0值,造成除0溢出或对数无穷大,而增加的修正值")
    def getDelta() = $(delta)
    setDefault(delta -> 1, labelCol -> "label") //Params类的方法,设置参数默认值

    //此方法将
    protected def getInOutCols: (Array[String], Array[String]) = { 
        //require方法是scala.Predef对象下的预定义方法,判断条件,条件为false则抛出IllegalArgumentException异常
        require(
                //这里用isSet检查参数是否被set方法设置过,默认的参数(通过setDefault设置的参数)并不会返回True,而是False,
                //保证了我们可以给任意列设置默认参数,而此句并不需要修改
            (isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
                (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
            "WOE only supports setting either inputCol/outputCol or" +
                "inputCols/outputCols.")

        if (isSet(inputCol)) { //isSet:Params类的方法,检查是否设置了参数值
            (Array($(inputCol)), Array($(outputCol)))
        } else {
            require(
                $(inputCols).length == $(outputCols).length,
                "inputCols number do not match outputCols")
            ($(inputCols), $(outputCols))
        }
    }

    protected def validateAndTransformSchemas(schema: StructType): StructType = {
        //StructField包含字段名称、类型(例如StringType,IntegerType,ArrayType等)、能否为空、和metadata信息
        //StructType包含了多个StructField,一个schema就是一个StructType
        //Dataset支持的数据类型都在org.apache.spark.sql.types包下面,大多都是DataType的子类
        val labelColName = $(labelCol)
        val labelDataType = schema(labelColName).dataType
        require(
            labelDataType.isInstanceOf[NumericType],
            s"The label column $labelColName must be numeric type, " +
                s"but got $labelDataType.")

        val (inputColNames, outputColNames) = getInOutCols
        val existingFields = schema.fields
        var outputFields = existingFields
        inputColNames.zip(outputColNames).foreach {
            case (inputColName, outputColName) =>
                require(
                    existingFields.exists(_.name == inputColName),
                    s"Iutput column ${inputColName} not exists.")
                require(
                    existingFields.forall(_.name != outputColName),
                    s"Output column ${outputColName} already exists.")
                val attr = NominalAttribute.defaultAttr.withName(outputColName)
                outputFields :+= attr.toStructField()
        }
        StructType(outputFields)
    }
}

class WOE(override val uid: String)
    extends Estimator[WOEModel]
    with WOEBase with DefaultParamsWritable {

    def this() = this(Identifiable.randomUID("WOE")) //WOE_ 和一个随机数组成的标识符作为uid

    //set方法
    def setLabelCol(value: String): this.type = set(labelCol, value)
    def setInputCol(value: String): this.type = set(inputCol, value)
    def setOutputCol(value: String): this.type = set(outputCol, value)
    def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
    def setDelta(value: Double): this.type = set(delta, value)

    override def copy(extra: ParamMap): this.type = defaultCopy(extra)  //必须要实现的方法,调用默认defaultCopy方法即可

    override def fit(dataset: Dataset[_]): WOEModel = {  //必须要实现的方法,主要实现逻辑
        transformSchema(dataset.schema, true)  //PipelineStage类的方法,调用本类实现的transformSchema方法,另外布尔值参数决定是否将转换前后的schema信息用logDebug输出
        val delta_value = $(delta) //防止出现0值,而增加的修正
        val T = dataset.count
        //        val B = dataset.agg(sum("y")).first.getLong(0)
        val B = dataset.where($(labelCol) + " = 1").count()
        val G = T - B

        val woe_map_arr = new ArrayBuffer[Map[String, Double]]()
        val (inputColNames, outputColNames) = getInOutCols
        inputColNames.foreach {
            inputColName =>
                val gDs_t = dataset.groupBy(inputColName).agg(count($(labelCol)).as("T"), sum($(labelCol)).as("B"))
                val gDs = gDs_t.withColumn("G", gDs_t("T") - gDs_t("B"))

                val loger = udf { d: Double =>
                    math.log(d)
                }
                val woe_map = gDs.withColumn("woe", loger((gDs("B") + delta_value) / (B + delta_value) * (G + delta_value) / (gDs("G") + delta_value)))
                    .select(col(inputColName).cast(StringType), col("woe"))
                    .collect()
                    .map(r => (r.getString(0), r.getDouble(1)))
                    .toMap
                woe_map_arr += woe_map
        }

        copyValues(new WOEModel(uid, woe_map_arr.toSeq).setParent(this)) //copyValues:Params特质的方法,将parent的参数值拷贝给model(如果model有一样的参数)
    }

    override def transformSchema(schema: StructType): StructType = {  //必须要实现的方法,输出转换后的schema;这个方法如果不做任何事,fit里不掉用应该也可以,未测试
        validateAndTransformSchemas(schema)
    }
}

//save方法调用的就是write.save
//load方法调用的是read.load方法
object WOE extends DefaultParamsReadable[WOE] {
    override def load(path: String): WOE = super.load(path)  //必须要实现的方法,直接用DefaultParamsReadable的,实际上是调用了DefaultParamsReader的load方法
}

//Estimator学习输出Transformer实际上就是传递一个数据结构。
//fit方法会将这个学到的数据结果作为传给Transformer:直接作为构造参数传递或者用设置参数的形式传递都可以,
//这里采用构造参数传递。一般会用参数方法传递:就可以作为参数获取或者设置模型的规则了。
//而且这里简化逻辑没有区分单列转换还是多列转换(inputCol还是inputCols):单列和多列都当作多列来处理。同理,也要重写classs WOE的write方法
class WOEModel(override val uid: String, val woe_map_arr: Seq[Map[String, Double]])
    extends Model[WOEModel]
    with MLWritable with WOEBase {
    def this(woe_map_arr: Seq[Map[String, Double]]) = this(Identifiable.randomUID("WOE"), woe_map_arr)

    def setInputCol(value: String): this.type = set(inputCol, value)
    def setLabelCol(value: String): this.type = set(labelCol, value)
    def setOutputCol(value: String): this.type = set(outputCol, value)
    def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

//    def setDelta(value: Double): this.type = set(delta, value)  //模型不可以设置delta,因为delta只对学习有用,对转换没用

    override def copy(extra: ParamMap): WOEModel = {
        val copied = new WOEModel(uid, woe_map_arr)
        copyValues(copied, extra).setParent(parent) //copyValues方法能够拷贝参数
    }

    import WOEModel._
    override def write: WOEModelWriter = new WOEModelWriter(this)

    override def transform(dataset: Dataset[_]): DataFrame = {
        val (inputColNames, outputColNames) = getInOutCols
        transformSchema(dataset.schema)
        require(
            woe_map_arr.length == inputColNames.length,
            s"The number of input columns is not equal to the number of WOEModel model maps ")

        var df: DataFrame = dataset.toDF()
        woe_map_arr.zipWithIndex.map {
            case (woe_map, idx) =>
                val inputColName = inputColNames(idx)
                val outputColName = outputColNames(idx)
                val woer = udf { (feature: String) =>
                    woe_map.get(feature) match {
                        case Some(n: Double) => n
                        case None =>
                            //这里选择直接抛出异常,之前用return dataset会报错
                            throw new SparkException(s"Input column_${inputColName}'s value ${feature} does not exist in the WOEModel model map. " +
                                "Skip WOEModel.")
                    }
                }//.asNondeterministic() //spark 2.3支持此句
                df = df.withColumn(outputColName, woer(dataset(inputColName).cast(StringType)))
        }

        df
    }

    override def transformSchema(schema: StructType): StructType = {
        validateAndTransformSchemas(schema)
    }
}

object WOEModel extends MLReadable[WOEModel] {
    import org.apache.hadoop.fs.Path
    import org.json4s.JsonDSL._
    //render,compact方法都是这里面的
    import org.json4s.jackson.JsonMethods._
    import org.json4s.JsonAST._

    implicit val format = org.json4s.DefaultFormats

    //这里是自己实现了保存细节,spark源码部分有统一的实现,但是是private[ml]的。
    private[WOEModel] class WOEModelWriter(instance: WOEModel) extends MLWriter {
        private case class Data(woe_map_arr: Seq[Map[String, Double]])
        override protected def saveImpl(path: String): Unit = {
            //这里选择了对所有参数进行保存,因为我们的outputCol并没有设置默认值,所以没问题
            //如果outputCol有默认值,并且设置了inputCols和outputCols参数,保存的时候就要去掉outputCol的默认参数保存:
            //因为,一旦将其默认值也保存,再加载的时候会用set方法设置参数,而不是setDefault,然后调用transform的时候会检查独占参数会报错
            //详情参考SPARK-23377
            val metadataPath = new Path(path, "metadata").toString

            val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
            val jsonParams = render(
                params.map {
                    case ParamPair(p, v) =>
                        p.name -> parse(p.jsonEncode(v))
                }.toList)

            val basicMetadata = ("class" -> instance.getClass.getName) ~
                ("timestamp" -> System.currentTimeMillis()) ~
                ("sparkVersion" -> sc.version) ~
                ("uid" -> instance.uid) ~
                ("paramMap" -> jsonParams)

            val metadataJson = compact(render(basicMetadata))

            sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)

            val data = Data(instance.woe_map_arr)
            val dataPath = new Path(path, "data").toString
            sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
        }
    }

    private class WOEModelReader extends MLReader[WOEModel] {
        private val className = classOf[WOEModel].getName
        override def load(path: String): WOEModel = {
            val metadataPath = new Path(path, "metadata").toString
            val s = sc.textFile(metadataPath, 1).first()

            val metadata = parse(s)
            val clz = (metadata \ "class").extract[String]
            val uid = (metadata \ "uid").extract[String]

            require(className == clz, s"Error loading metadata: Expected class name" +
                s" className but found class name ${clz}")

            val dataPath = new Path(path, "data").toString
            val data = sparkSession.read.parquet(dataPath)
                .select("woe_map_arr")
                .head()
            val woe_map_arr = data.getAs[Seq[Map[String, Double]]](0)
            val instance = new WOEModel(uid, woe_map_arr)

            val params = metadata \ "paramMap"
            params match {
                case JObject(pairs) =>
                    pairs.foreach {
                        case (paramName, jsonValue) =>
                            val param = instance.getParam(paramName)
                            val value = param.jsonDecode(compact(render(jsonValue)))
                            instance.set(param, value)
                    }
                case _ =>
                    throw new IllegalArgumentException(
                        s"Cannot recognize JSON metadata: ${s}.")
            }

            instance
        }
    }

    override def read: MLReader[WOEModel] = new WOEModelReader

    override def load(path: String): WOEModel = super.load(path)
}

猜你喜欢

转载自blog.csdn.net/xuejianbest/article/details/80495795
今日推荐