Spark Logistic Regression (LogisticRegression)

1. Concepts

Logistic regression is a popular method for predicting a categorical response. It is a special case of generalized linear models that predicts the probability of the outcomes. In spark.ml, logistic regression can predict a binary outcome via binomial logistic regression, or a multiclass outcome via multinomial logistic regression. Use the family parameter to select between the two algorithms, or leave it unset (default "auto") and Spark will infer the correct variant. Multinomial logistic regression can also be used for binary classification by setting family to "multinomial"; it will then produce two sets of coefficients and two intercepts.

In classification problems we try to predict whether an outcome belongs to a certain class (e.g. true or false). Examples of classification problems: deciding whether an email is spam, or whether a financial transaction is fraudulent.
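As a rough sketch of the two families (my own pure-Scala illustration, not Spark's implementation): binomial logistic regression squashes a single margin through the logistic (sigmoid) function, while multinomial logistic regression applies a softmax over one margin per class.

```scala
object FamilySketch {
  // Binomial: P(y = 1 | x) = 1 / (1 + e^(-margin)), where margin = w·x + b
  def sigmoid(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

  // Multinomial: P(y = k | x) = e^(margin_k) / sum_j e^(margin_j)
  // (subtracting the max margin first, for numerical stability)
  def softmax(margins: Seq[Double]): Seq[Double] = {
    val exps = margins.map(m => math.exp(m - margins.max))
    exps.map(_ / exps.sum)
  }

  def main(args: Array[String]): Unit = {
    println(sigmoid(0.0))            // 0.5: a zero margin is a coin flip
    println(softmax(Seq(2.0, 2.0)))  // equal margins give equal probabilities
  }
}
```

With two classes, a softmax over margins (m, -m) reduces to sigmoid(2m), which is why the "multinomial" family on a binary problem yields two coefficient sets and two intercepts rather than one of each.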

2. Code (reference: https://github.com/asker124143222/spark-demo)

package com.home.spark.ml

import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
  * @Description: logistic regression, binomial classification prediction
  *
  **/
object Ex_BinomialLogisticRegression {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(true).setMaster("local[*]").setAppName("spark ml label")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    //Converting an RDD to a DataFrame/Dataset requires the implicit conversions of a SparkSession instance
    //Import the implicits; note that `spark` here is not a package name but the name of the SparkSession object
    import spark.implicits._

    val data = spark.sparkContext.textFile("input/iris.data.txt")
      .map(_.split(","))
      .map(a => Iris(
        Vectors.dense(a(0).toDouble, a(1).toDouble, a(2).toDouble, a(3).toDouble),
        a(4))
      ).toDF()
    data.show()

    data.createOrReplaceTempView("iris")

    val totalCount = spark.sql("select count(*) from iris")
    println("Record count: " + totalCount.collect().take(1).mkString)

    //Binomial prediction: the sample data contains three classes, so exclude Iris-setosa
    val df = spark.sql("select * from iris where label!='Iris-setosa'")
    df.map(r => r(1) + " : " + r(0)).collect().take(10).foreach(println)
    println("Record count after filtering: " + df.count())


    /* VectorIndexer
    Improves the classification performance of ML methods such as decision trees and random forests.
    VectorIndexer indexes the categorical (discrete-valued) features inside a dataset's feature vectors.
    It automatically decides which features are categorical and re-encodes them:
    given a maxCategories setting, any feature whose number of distinct values is at most maxCategories
    is re-indexed to 0..K-1 (K <= maxCategories); a feature with more distinct values than maxCategories
    is treated as continuous and left unchanged.
    For example, with maxCategories = 5, every feature column with at most 5 distinct values is re-indexed.
    For index stability, a feature value of 0 is always indexed as 0, which preserves vector sparsity.
    maxCategories defaults to 20.
    */
    //Index the label column and the feature column
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
    val featureIndexer = new VectorIndexer()
//      .setMaxCategories(5) //With 5, every feature column has more than 5 distinct values, so nothing would be transformed and the setting would be pointless
      .setInputCol("features").setOutputCol("indexedFeatures")
      .fit(df)


    //Split the original dataset into training data (70%) and test data (30%)
    val Array(trainingData, testData): Array[Dataset[Row]] = df.randomSplit(Array(0.7, 0.3))

    /**
      * Building the LR model
      * setMaxIter: maximum number of iterations (default 100); training may stop before reaching it
      * setTol: convergence tolerance (default 1e-6); each iteration computes an error that shrinks as iterations proceed, and training stops once it falls below the tolerance
      * setRegParam: regularization coefficient (default 0); regularization guards against overfitting, so if the dataset is small but high-dimensional, consider increasing it
      * setElasticNetParam: elastic-net mixing ratio (default 0) between the two regularization norms: L1 (lasso), which sparsifies features, and L2 (ridge), which curbs overfitting
      * setLabelCol: label column
      * setFeaturesCol: features column
      * setPredictionCol: prediction column
      * setThreshold: binary classification threshold
      */
    //Configure the logistic regression parameters
    val lr = new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setFamily("binomial")
      .setMaxIter(100).setRegParam(0.3).setElasticNetParam(0.8)

    //Converter that maps the predicted indices back to the original string labels
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predectionLabel")
      .setLabels(labelIndexer.labels)


    //Build the pipeline
    val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))

    //Fit the model
    val model = lrPipeline.fit(trainingData)

    //Predict
    val result = model.transform(testData)

    //Print the results
    result.show(200, false)

    //Model evaluation: prediction accuracy and error rate
    //MulticlassClassificationEvaluator defaults to the "f1" metric, so request "accuracy" explicitly
    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
    val lrAccuracy: Double = evaluator.evaluate(result)

    println("Test Error = " + (1.0 - lrAccuracy))

    spark.stop()
  }
}


case class Iris(features: Vector, label: String)
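The re-indexing rule that the VectorIndexer comment above describes can be illustrated with a small pure-Scala sketch (my own simplification for a single feature column, not Spark's actual implementation):

```scala
object VectorIndexerSketch {
  // Re-index one feature column following the rule described above:
  // - more than maxCategories distinct values: treat as continuous, leave unchanged
  // - otherwise map the distinct values to 0..K-1, always keeping the value 0.0
  //   at index 0 so that sparsity is preserved
  def indexColumn(values: Seq[Double], maxCategories: Int): Seq[Double] = {
    val distinct = values.distinct
    if (distinct.size > maxCategories) values
    else {
      val hasZero = distinct.contains(0.0)
      val nonZero = distinct.filter(_ != 0.0).sorted
      val mapping: Map[Double, Double] =
        nonZero.zipWithIndex.map { case (v, i) =>
          v -> (if (hasZero) i + 1.0 else i.toDouble)
        }.toMap ++ (if (hasZero) Map(0.0 -> 0.0) else Map.empty)
      values.map(mapping)
    }
  }

  def main(args: Array[String]): Unit = {
    // 3 distinct values <= maxCategories = 5, so the column is re-indexed
    println(indexColumn(Seq(0.0, 2.5, 0.0, 7.0), maxCategories = 5))
    // 3 distinct values > maxCategories = 2, so the column is left as-is
    println(indexColumn(Seq(1.0, 2.0, 3.0), maxCategories = 2))
  }
}
```

This also shows why the `setMaxCategories(5)` line is commented out in the code above: the iris feature columns all have more than 5 distinct values, so with that setting every column would take the "treat as continuous" branch.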

3. Result

+-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+
|features         |label          |indexedLabel|indexedFeatures    |rawPrediction                               |probability                             |prediction|predectionLabel|
+-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+
|[4.9,2.4,3.3,1.0]|Iris-versicolor|0.0         |[4.9,3.0,3.3,0.0]  |[1.0071037675553336,-1.0071037675553336]    |[0.7324529695042751,0.2675470304957249] |0.0       |Iris-versicolor|
|[5.0,2.0,3.5,1.0]|Iris-versicolor|0.0         |[5.0,0.0,3.5,0.0]  |[0.938177922699384,-0.938177922699384]      |[0.7187314594034615,0.2812685405965385] |0.0       |Iris-versicolor|
|[5.6,2.5,3.9,1.1]|Iris-versicolor|0.0         |[5.6,4.0,3.9,1.0]  |[0.7107814076350716,-0.7107814076350716]    |[0.6705737993354417,0.3294262006645583] |0.0       |Iris-versicolor|
|[5.6,2.9,3.6,1.3]|Iris-versicolor|0.0         |[5.6,8.0,3.6,3.0]  |[0.6350805242141693,-0.6350805242141693]    |[0.6536405613705153,0.3463594386294846] |0.0       |Iris-versicolor|
|[5.8,2.7,4.1,1.0]|Iris-versicolor|0.0         |[5.8,6.0,4.1,0.0]  |[0.7314003881315354,-0.7314003881315354]    |[0.6751125028597408,0.32488749714025916]|0.0       |Iris-versicolor|
|[6.1,2.8,4.7,1.2]|Iris-versicolor|0.0         |[6.1,7.0,4.7,2.0]  |[0.34553320285886,-0.34553320285886]        |[0.5855339747983552,0.41446602520164466]|0.0       |Iris-versicolor|
|[6.2,2.2,4.5,1.5]|Iris-versicolor|0.0         |[6.2,1.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
|[6.4,2.9,4.3,1.3]|Iris-versicolor|0.0         |[6.4,8.0,4.3,3.0]  |[0.39384006721834597,-0.39384006721834597]  |[0.597206774507057,0.40279322549294305] |0.0       |Iris-versicolor|
|[6.6,3.0,4.4,1.4]|Iris-versicolor|0.0         |[6.6,9.0,4.4,4.0]  |[0.2698323194379575,-0.2698323194379575]    |[0.5670517391689078,0.43294826083109217]|0.0       |Iris-versicolor|
|[6.7,3.0,5.0,1.7]|Iris-versicolor|0.0         |[6.7,9.0,5.0,7.0]  |[-0.20557969118713126,0.20557969118713126]  |[0.44878532413929256,0.5512146758607075]|1.0       |Iris-virginica |
|[6.7,3.1,4.4,1.4]|Iris-versicolor|0.0         |[6.7,10.0,4.4,4.0] |[0.2698323194379575,-0.2698323194379575]    |[0.5670517391689078,0.43294826083109217]|0.0       |Iris-versicolor|
|[7.0,3.2,4.7,1.4]|Iris-versicolor|0.0         |[7.0,11.0,4.7,4.0] |[0.16644355215403328,-0.16644355215403328]  |[0.5415150896404186,0.4584849103595813] |0.0       |Iris-versicolor|
|[4.9,2.5,4.5,1.7]|Iris-virginica |1.0         |[4.9,4.0,4.5,7.0]  |[-0.033265079047257284,0.033265079047257284]|[0.49168449702809164,0.5083155029719083]|1.0       |Iris-virginica |
|[5.4,3.0,4.5,1.5]|Iris-versicolor|0.0         |[5.4,9.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
|[5.6,2.8,4.9,2.0]|Iris-virginica |1.0         |[5.6,7.0,4.9,10.0] |[-0.43975124481639627,0.43975124481639627]  |[0.39180024423019144,0.6081997557698086]|1.0       |Iris-virginica |
|[5.6,3.0,4.1,1.3]|Iris-versicolor|0.0         |[5.6,9.0,4.1,3.0]  |[0.4627659120742955,-0.4627659120742955]    |[0.6136701219061476,0.38632987809385244]|0.0       |Iris-versicolor|
|[5.8,2.7,3.9,1.2]|Iris-versicolor|0.0         |[5.8,6.0,3.9,2.0]  |[0.6212365822826582,-0.6212365822826582]    |[0.6504997376392441,0.34950026236075604]|0.0       |Iris-versicolor|
|[5.8,2.7,5.1,1.9]|Iris-virginica |1.0         |[5.8,6.0,5.1,9.0]  |[-0.419132264319932,0.419132264319932]      |[0.3967244102962335,0.6032755897037665] |1.0       |Iris-virginica |
|[5.9,3.0,5.1,1.8]|Iris-virginica |1.0         |[5.9,9.0,5.1,8.0]  |[-0.32958743896751885,0.32958743896751885]  |[0.4183410089972438,0.5816589910027563] |1.0       |Iris-virginica |
|[6.0,2.9,4.5,1.5]|Iris-versicolor|0.0         |[6.0,8.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
|[6.1,3.0,4.6,1.4]|Iris-versicolor|0.0         |[6.1,9.0,4.6,4.0]  |[0.20090647458200817,-0.20090647458200817]  |[0.5500583546439539,0.4499416453560461] |0.0       |Iris-versicolor|
|[6.2,3.4,5.4,2.3]|Iris-virginica |1.0         |[6.2,13.0,5.4,13.0]|[-0.8807003330135101,0.8807003330135101]    |[0.29303267372325625,0.7069673262767437]|1.0       |Iris-virginica |
|[6.7,3.1,4.7,1.5]|Iris-versicolor|0.0         |[6.7,10.0,4.7,5.0] |[0.07689872680162013,-0.07689872680162013]  |[0.5192152136737482,0.48078478632625177]|0.0       |Iris-versicolor|
|[6.7,3.3,5.7,2.5]|Iris-virginica |1.0         |[6.7,12.0,5.7,15.0]|[-1.163178751002261,1.163178751002261]      |[0.23809016943453823,0.7619098305654617]|1.0       |Iris-virginica |
|[6.8,3.0,5.5,2.1]|Iris-virginica |1.0         |[6.8,9.0,5.5,11.0] |[-0.7360736047366578,0.7360736047366578]    |[0.32386333429517283,0.6761366657048272]|1.0       |Iris-virginica |
|[6.9,3.1,5.4,2.1]|Iris-virginica |1.0         |[6.9,10.0,5.4,11.0]|[-0.7016106823086834,0.7016106823086834]    |[0.33145521561995817,0.6685447843800418]|1.0       |Iris-virginica |
|[7.2,3.6,6.1,2.5]|Iris-virginica |1.0         |[7.2,14.0,6.1,15.0]|[-1.3010304407141597,1.3010304407141597]    |[0.21399164655179387,0.7860083534482062]|1.0       |Iris-virginica |
|[7.7,2.8,6.7,2.0]|Iris-virginica |1.0         |[7.7,7.0,6.7,10.0] |[-1.0600838485199424,1.0600838485199424]    |[0.2572934314622856,0.7427065685377143] |1.0       |Iris-virginica |
|[7.7,3.0,6.1,2.3]|Iris-virginica |1.0         |[7.7,9.0,6.1,13.0] |[-1.1219407900093334,1.1219407900093334]    |[0.24565146441425778,0.7543485355857422]|1.0       |Iris-virginica |
|[7.9,3.8,6.4,2.0]|Iris-virginica |1.0         |[7.9,15.0,6.4,10.0]|[-0.9566950812360182,0.9566950812360182]    |[0.2775403823663211,0.7224596176336789] |1.0       |Iris-virginica |
+-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+

Test Error = 0.03314285714285714
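For the binomial family, the probability column in the table above is simply the logistic (sigmoid) of the corresponding rawPrediction entry. This can be verified against the first output row (values copied from the table):

```scala
object ProbabilityCheck {
  def sigmoid(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

  def main(args: Array[String]): Unit = {
    // First row above: rawPrediction(0) and probability(0) for [4.9,2.4,3.3,1.0]
    val rawPrediction = 1.0071037675553336
    val probability = 0.7324529695042751
    println(sigmoid(rawPrediction)) // matches the probability column
    assert(math.abs(sigmoid(rawPrediction) - probability) < 1e-9)
  }
}
```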

Reposted from www.cnblogs.com/asker009/p/12176982.html