xgboost之spark上运行-scala接口

概述

xgboost可以在spark上运行,我用的xgboost的版本是0.7的版本,目前只支持spark2.0以上版本上运行,

编译好jar包,加载到maven仓库里面去:

 
  
  1. mvn install:install-file -Dfile=xgboost4j-spark-0.7-jar-with-dependencies.jar -DgroupId=ml.dmlc -DartifactId=xgboost4j-spark -Dversion=0.7 -Dpackaging=jar


添加依赖:

[html]  view plain  copy
  1. <dependency>  
  2.             <groupId>ml.dmlc</groupId>  
  3.             <artifactId>xgboost4j-spark</artifactId>  
  4.             <version>0.7</version>  
  5.         </dependency>  
  6.         <dependency>  
  7.             <groupId>org.apache.spark</groupId>  
  8.             <artifactId>spark-core_2.10</artifactId>  
  9.             <version>2.0.0</version>  
  10.         </dependency>  
  11.         <dependency>  
  12.             <groupId>org.apache.spark</groupId>  
  13.             <artifactId>spark-mllib_2.10</artifactId>  
  14.             <version>2.0.0</version>  
  15.         </dependency>  
  16.     </dependencies>  




RDD接口:


[python]  view plain  copy
  1. package com.meituan.spark_xgboost  
  2. import org.apache.log4j.{ Level, Logger }  
  3. import org.apache.spark.{ SparkConf, SparkContext }  
  4. import ml.dmlc.xgboost4j.scala.spark.XGBoost  
  5. import org.apache.spark.sql.{ SparkSession, Row }  
  6. import org.apache.spark.mllib.util.MLUtils  
  7. import org.apache.spark.ml.feature.LabeledPoint  
  8. import org.apache.spark.ml.linalg.Vectors  
  9. object XgboostR {  
  10.   
  11.   
  12.   def main(args: Array[String]): Unit = {  
  13.     Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)  
  14.     Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)  
  15.     val spark = SparkSession.builder.master("local").appName("example").  
  16.       config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").  
  17.       config("spark.sql.shuffle.partitions""20").getOrCreate()  
  18.     spark.conf.set("spark.serializer""org.apache.spark.serializer.KryoSerializer")  
  19.       val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"  
  20.   val trainString = "agaricus.txt.train"  
  21.   val testString = "agaricus.txt.test"  
  22.     val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)  
  23.     val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)  
  24.     val traindata = train.map { x =>  
  25.       val f = x.features.toArray  
  26.       val v = x.label  
  27.       LabeledPoint(v, Vectors.dense(f))  
  28.     }  
  29.     val testdata = test.map { x =>  
  30.       val f = x.features.toArray  
  31.       val v = x.label  
  32.        Vectors.dense(f)  
  33.     }  
  34.       
  35.   
  36.     val numRound = 15  
  37.       
  38.      //"objective" -> "reg:linear", //定义学习任务及相应的学习目标  
  39.       //"eval_metric" -> "rmse", //校验数据所需要的评价指标  用于做回归  
  40.       
  41.     val paramMap = List(  
  42.       "eta" -> 1f,  
  43.       "max_depth" ->5, //数的最大深度。缺省值为6 ,取值范围为:[1,∞]   
  44.       "silent" -> 1, //取0时表示打印出运行时信息,取1时表示以缄默方式运行,不打印运行时信息。缺省值为0   
  45.       "objective" -> "binary:logistic", //定义学习任务及相应的学习目标  
  46.       "lambda"->2.5,  
  47.       "nthread" -> 1 //XGBoost运行时的线程数。缺省值是当前系统可以获得的最大线程数  
  48.       ).toMap  
  49.     println(paramMap)  
  50.       
  51.   
  52.     val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null, useExternalMemory = false, Float.NaN)  
  53.     print("sucess")  
  54.    
  55.     val result=model.predict(testdata)  
  56.     result.take(10).foreach(println)  
  57.     spark.stop();  
  58.      
  59.   }  
  60.   
  61. }  


DataFrame接口:

[python]  view plain  copy
  1. package com.meituan.spark_xgboost  
  2. import org.apache.log4j.{ Level, Logger }  
  3. import org.apache.spark.{ SparkConf, SparkContext }  
  4. import ml.dmlc.xgboost4j.scala.spark.XGBoost  
  5. import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics  
  6. import org.apache.spark.sql.{ SparkSession, Row }  
  7. object XgboostD {  
  8.   def main(args: Array[String]): Unit = {  
  9.     Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)  
  10.     Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)  
  11.     val spark = SparkSession.builder.master("local").appName("example").  
  12.       config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").  
  13.       config("spark.sql.shuffle.partitions""20").getOrCreate()  
  14.     spark.conf.set("spark.serializer""org.apache.spark.serializer.KryoSerializer")  
  15.     val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"  
  16.     val trainString = "agaricus.txt.train"  
  17.     val testString = "agaricus.txt.test"  
  18.   
  19.     val train = spark.read.format("libsvm").load(path + trainString).toDF("label""feature")  
  20.   
  21.     val test = spark.read.format("libsvm").load(path + testString).toDF("label""feature")  
  22.   
  23.     val numRound = 15  
  24.   
  25.     //"objective" -> "reg:linear", //定义学习任务及相应的学习目标  
  26.     //"eval_metric" -> "rmse", //校验数据所需要的评价指标  用于做回归  
  27.   
  28.     val paramMap = List(  
  29.       "eta" -> 1f,  
  30.       "max_depth" -> 5, //数的最大深度。缺省值为6 ,取值范围为:[1,∞]   
  31.       "silent" -> 1, //取0时表示打印出运行时信息,取1时表示以缄默方式运行,不打印运行时信息。缺省值为0   
  32.       "objective" -> "binary:logistic", //定义学习任务及相应的学习目标  
  33.       "lambda" -> 2.5,  
  34.       "nthread" -> 1 //XGBoost运行时的线程数。缺省值是当前系统可以获得的最大线程数  
  35.       ).toMap  
  36.     val model = XGBoost.trainWithDataFrame(train, paramMap, numRound, 45, obj = null, eval = null, useExternalMemory = false, Float.NaN, "feature""label")  
  37.     val predict = model.transform(test)  
  38.   
  39.     val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)  
  40.       .rdd  
  41.       .map { case Row(score: Double, label: Double) => (score, label) }  
  42.   
  43.     //get the auc  
  44.     val metric = new BinaryClassificationMetrics(scoreAndLabels)  
  45.     val auc = metric.areaUnderROC()  
  46.     println("auc:" + auc)  
  47.   
  48.   }  
  49.   
  50. }  

猜你喜欢

转载自blog.csdn.net/u010159842/article/details/80198264
今日推荐