SparkML -- LightGBM On Spark 二分类LightGBMClassifier示例

MAVEN

<dependency>
     <groupId>com.microsoft.ml.spark</groupId>
     <artifactId>mmlspark_2.11</artifactId>
     <version>0.18.0</version>
 </dependency>
 <dependency>
     <groupId>com.microsoft.ml.lightgbm</groupId>
     <artifactId>lightgbmlib</artifactId>
     <version>2.2.350</version>
 </dependency>

测试数据

http://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip

hour.csv和day.csv都包含如下属性,只有hr属性例外:day.csv文件中没有hr属性

  • instant: 记录ID
  • dteday : 时间日期
  • season : 季节 (1:春季, 2:夏季, 3:秋季, 4:冬季)
  • yr : 年份 (0: 2011, 1:2012)
  • mnth : 月份 ( 1 to 12)
  • hr : 当天时刻 (0 to 23)
  • holiday : 当天是否是节假日(extracted from http://dchr.dc.gov/page/holiday-schedule)
  • weekday : 周几
  • workingday : 工作日 is 1, 其他 is 0.
  • weathersit : 天气
  • 1: Clear, Few clouds, Partly cloudy, Partly cloudy
  • 2: Mist + Cloudy, Mist + Broken clouds, Mist + Few clouds, Mist
  • 3: Light Snow, Light Rain + Thunderstorm + Scattered clouds, Light Rain + Scattered clouds
  • 4: Heavy Rain + Ice Pellets + Thunderstorm + Mist, Snow + Fog
  • temp : 气温 Normalized temperature in Celsius. The values are divided to 41 (max)
  • atemp: 体感温度 Normalized feeling temperature in Celsius. The values are divided to 50 (max)
  • hum: 湿度 Normalized humidity. The values are divided to 100 (max)
  • windspeed: 风速Normalized wind speed. The values are divided to 67 (max)
  • casual: 临时用户数count of casual users
  • registered: 注册用户数count of registered users
  • cnt: 目标变量,每小时的自行车的租用量,包括临时用户和注册用户count of total rental bikes including both casual and registered

代码示例

package com.bigblue.lightgbm

import java.io.FileOutputStream

import com.bigblue.utils.LightGBMUtils
import com.microsoft.ml.spark.lightgbm.{LightGBMClassificationModel, LightGBMClassifier}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.jpmml.lightgbm.GBDT
import org.jpmml.model.MetroJAXBUtil

/**
 * Binary-classification example for LightGBM on Spark (MMLSpark `LightGBMClassifier`).
 *
 * Reads a CSV file with a header row, casts the feature columns to `Double` and the
 * label column to `Int`, assembles a feature vector, trains a LightGBM binary
 * classifier that predicts `workingday`, and prints the evaluator metric computed
 * on the held-out split.
 *
 * Created By TheBigBlue on 2020/3/6
 */
object LightGBMClassificationTest {

  /** Input file used when no path is supplied on the command line. */
  private val DefaultDataPath: String =
    "file:///D:/Cache/ProgramCache/TestData/dataSource/lightgbm/hour.csv"

  /**
   * Runs the end-to-end example.
   *
   * @param args optional; `args(0)` overrides the CSV location (e.g. "/home/hdfs/hour.csv").
   *             When absent, [[DefaultDataPath]] is used.
   */
  def main(args: Array[String]): Unit = {

    val spark: SparkSession = SparkSession.builder().appName("test-lightgbm").master("local[2]").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    // Always release the session, even if reading or training fails.
    try {
      val dataPath = args.headOption.getOrElse(DefaultDataPath)

      val rawData: DataFrame = spark.read
        .option("header", "true") // first line provides the column names
        .option("inferSchema", "true") // let Spark infer the column types
        .csv(dataPath)

      val labelCol = "workingday"
      // categorical (discrete) columns
      val cateCols = Array("season", "yr", "mnth", "hr")
      // continuous columns
      val conCols: Array[String] = Array("temp", "atemp", "hum", "casual", "cnt")
      // all columns that go into the feature vector
      val vecCols = conCols ++ cateCols

      // Cast every feature column to Double without mutating a shared variable.
      val withDoubleFeatures = vecCols.foldLeft(rawData) { (df, name) =>
        df.withColumn(name, df(name).cast(DoubleType))
      }
      val prepared = withDoubleFeatures.withColumn(labelCol, withDoubleFeatures(labelCol).cast(IntegerType))

      val assembler = new VectorAssembler().setInputCols(vecCols).setOutputCol("features")

      val classifier: LightGBMClassifier = new LightGBMClassifier().setNumIterations(100).setNumLeaves(31)
        .setBoostFromAverage(false).setFeatureFraction(1.0).setMaxDepth(-1).setMaxBin(255)
        .setLearningRate(0.1).setMinSumHessianInLeaf(0.001).setLambdaL1(0.0).setLambdaL2(0.0)
        .setBaggingFraction(1.0).setBaggingFreq(0).setBaggingSeed(1).setObjective("binary")
        .setLabelCol(labelCol).setCategoricalSlotNames(cateCols).setFeaturesCol("features")
        .setBoostingType("gbdt") // alternatives: rf, dart, goss

      val pipeline: Pipeline = new Pipeline().setStages(Array(assembler, classifier))

      // 70% train / 30% test. The original weights (0.7, .03) were normalised by Spark
      // to roughly 96% / 4%, leaving almost nothing to evaluate on.
      val Array(tr, te) = prepared.randomSplit(Array(0.7, 0.3), 666L)
      val model = pipeline.fit(tr)
      val modelDF = model.transform(te)

      // Score on the raw (continuous) prediction column; feeding the hard 0/1
      // "prediction" column would reduce the ROC curve to a single operating point.
      val evaluator = new BinaryClassificationEvaluator()
        .setLabelCol(labelCol)
        .setRawPredictionCol("rawPrediction")
      println(evaluator.evaluate(modelDF))
    } finally {
      spark.stop()
    }
  }
}

结果

在这里插入图片描述

发布了73 篇原创文章 · 获赞 18 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/Aeve_imp/article/details/105048534