机器学习实战(五)03-Spark-SVM

Spark MLlib 官方文档示例（SVMWithSGD）

package com.netcloud.bigdata.mllib.com.svm.example

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Linear SVM binary classification with MLlib's `SVMWithSGD`, following the Spark MLlib
 * official documentation example.
 *
 * Loads a LIBSVM-format dataset, splits it 60/40 into training and test sets, trains a
 * linear SVM with 200 SGD iterations, and prints the area under the ROC curve measured on
 * the test set.
 *
 * Author: yangshaojun
 * Date: 2020/02/17 20:57
 * Version: 1.0
 */
object SVMWithSGDExample {

  /** Input used when no path is passed on the command line. */
  private val DefaultDataPath = "data/mllib/sample_libsvm_data.txt"

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the LIBSVM input path
   *             (defaults to `data/mllib/sample_libsvm_data.txt`).
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SVMWithSGDExample").setMaster("local")
    val sc = new SparkContext(conf)
    // Release the SparkContext even if loading, training or evaluation throws.
    try {
      val dataPath = args.headOption.getOrElse(DefaultDataPath)
      // Load training data in LIBSVM format.
      val data = MLUtils.loadLibSVMFile(sc, dataPath)
      // Fixed seed so the 60/40 split is reproducible across runs.
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      // The training set is cached because SGD makes many passes over it.
      val training = splits(0).cache()
      val test = splits(1)

      val numIterations = 200
      val model = SVMWithSGD.train(training, numIterations)
      // Clear the threshold so predict() returns raw scores instead of 0/1 labels;
      // the ROC curve is traced from continuous scores.
      model.clearThreshold()

      val scoreAndLabels = test.map { point =>
        (model.predict(point.features), point.label)
      }
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()
      println(s"Area under ROC = $auROC")
    } finally {
      sc.stop()
    }
  }
}

自定义 SVMWithSGD 参数（直接配置底层梯度下降优化器）

package com.netcloud.bigdata.mllib.com.svm.action

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.optimization.{HingeGradient, SquaredL2Updater}
import org.apache.spark.mllib.util.MLUtils

/**
 * Linear SVM trained with a hand-configured `SVMWithSGD` instance.
 *
 * Unlike the `SVMWithSGD.train` convenience method, this configures the underlying
 * gradient-descent optimizer directly. The settings cover iteration count, step size,
 * regularization, mini-batch fraction, convergence tolerance, gradient and updater. The
 * program then prints the area under the ROC curve on a held-out test set.
 */
object SVMAction {

  /** Input used when no path is passed on the command line. */
  private val DefaultDataPath = "data/mllib/sample_libsvm_data.txt"

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the LIBSVM input path
   *             (defaults to `data/mllib/sample_libsvm_data.txt`).
   */
  def main(args: Array[String]): Unit = {
    // App name previously copy-pasted as "SVMWithSGDExample"; name this job after itself.
    val conf = new SparkConf().setAppName("SVMAction").setMaster("local")
    val sc = new SparkContext(conf)
    // Release the SparkContext even if loading, training or evaluation throws.
    try {
      val dataPath = args.headOption.getOrElse(DefaultDataPath)
      // Load training data in LIBSVM format.
      val data = MLUtils.loadLibSVMFile(sc, dataPath)
      // Fixed seed so the 60/40 split is reproducible across runs.
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      val training = splits(0).cache()
      val test = splits(1)

      val svm = new SVMWithSGD()
      svm.setIntercept(false)
      // Optimizer setters return this.type, so they chain.
      // HingeGradient (hinge loss) is what makes this an SVM. Swapping in LeastSquaresGradient
      // or LogisticGradient changes the objective being minimized, although the result is still
      // wrapped in an SVMModel. SquaredL2Updater applies L2 regularization; L1Updater can be
      // used instead for L1 regularization.
      svm.optimizer
        .setNumIterations(1000)
        .setStepSize(1.0)
        .setRegParam(0.01)
        .setMiniBatchFraction(1.0)
        .setConvergenceTol(0.001)
        .setGradient(new HingeGradient())
        .setUpdater(new SquaredL2Updater())

      val model = svm.run(training)
      // Clear the threshold so predict() returns raw scores instead of 0/1 labels;
      // the ROC curve is traced from continuous scores.
      model.clearThreshold()

      val scoreAndLabels = test.map { point =>
        (model.predict(point.features), point.label)
      }
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()
      println(s"Area under ROC = $auROC")
    } finally {
      sc.stop()
    }
  }

}
转载自 blog.csdn.net/yangshaojun1992/article/details/104360777