Preface
As the title says, this post records how to deal with imbalanced data in Spark ML logistic regression. Reference: Dealing with unbalanced datasets in Spark MLlib
1. Imbalanced data
This means the proportion of label == 1 vs. label == 0 records is heavily skewed, e.g. 80% vs. 20%. A model trained on such data tends to favor the majority class, so its results are skewed and plain accuracy becomes a misleading metric.
2. setWeightCol: main code
```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf

val labelCol = "label"

def balanceDataset(dataset: DataFrame): DataFrame = {
  // Re-balancing (weighting) of records to be used in the logistic loss objective function
  val numNegatives = dataset.filter(dataset(labelCol) === 0).count
  val datasetSize = dataset.count
  val balancingRatio = (datasetSize - numNegatives).toDouble / datasetSize

  val calculateWeights = udf { d: Double =>
    if (d == 0.0) {
      1 * balancingRatio
    } else {
      1 * (1.0 - balancingRatio)
    }
  }

  val weightedDataset = dataset.withColumn("classWeightCol", calculateWeights(dataset(labelCol)))
  weightedDataset
}

val df_weighted = balanceDataset(df)

val lr = new LogisticRegression()
  .setLabelCol(labelCol)
  .setWeightCol("classWeightCol")
```
This conveniently solves the imbalance problem: each negative example is down-weighted and each positive example is up-weighted in the logistic loss, so neither class dominates training.
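For completeness, here is a minimal end-to-end sketch of how the weighted DataFrame might be used, assuming `df` already has a `features` vector column alongside the binary `label` column. The train/test split, seed, and the use of `BinaryClassificationEvaluator` are my additions, not from the original post; areaUnderROC is a more honest metric than accuracy on skewed data:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Split first, then weight only the training side, so the test
// metric reflects the raw (unweighted) distribution
val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 42)
val trainWeighted = balanceDataset(train)

val lr = new LogisticRegression()
  .setLabelCol(labelCol)
  .setWeightCol("classWeightCol")

val model = lr.fit(trainWeighted)
val predictions = model.transform(test)

// Evaluate with AUC rather than accuracy
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol(labelCol)
  .setMetricName("areaUnderROC")
println(evaluator.evaluate(predictions))
```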
3. Other approaches
Before I knew about setWeightCol, I solved the problem with the methods below; I am recording them here for reference.
The following assumes there are more label = 0 records than label = 1 records.
```scala
/**
 * Randomly undersample the label = 0 records so that their count is
 * roughly the same as the number of label = 1 records.
 */
def sample(df: DataFrame): DataFrame = {
  val df0 = df.where(s"${labelCol}=0")
  val df1 = df.where(s"${labelCol}=1")
  val y0 = df0.count()
  val y1 = df1.count()
  val num = 1.0 * y1 / y0
  // Balance the classes by randomly sampling the majority (non-default) class
  val df00 = df0.sample(false, num)
  df00.union(df1)
}
```
or
```scala
/**
 * Duplicate the label = 1 records several times so that their count is
 * roughly the same as the number of label = 0 records.
 */
def copy(df: DataFrame): DataFrame = {
  var df_res = df
  val df1 = df.where(s"${labelCol}=1")
  val y0 = df.where(s"${labelCol}=0").count()
  val y1 = df1.count()
  // Integer division, so the resulting balance is approximate
  val num = (y0 / y1).toInt - 1
  for (a <- 1 to num) {
    df_res = df_res.union(df1)
  }
  df_res
}
```
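The union loop above can also be replaced by sampling the minority class with replacement. This variant is my own sketch (the `oversample` name and seed are assumptions, not from the original post); it avoids rounding the ratio down to an integer, since `sample` accepts a fraction greater than 1.0 when sampling with replacement:

```scala
import org.apache.spark.sql.DataFrame

/**
 * Oversample the label = 1 records with replacement so that their
 * expected count matches the number of label = 0 records.
 */
def oversample(df: DataFrame): DataFrame = {
  val df0 = df.where(s"${labelCol}=0")
  val df1 = df.where(s"${labelCol}=1")
  // Fraction may exceed 1.0 because we sample with replacement
  val ratio = df0.count().toDouble / df1.count()
  df0.union(df1.sample(withReplacement = true, fraction = ratio, seed = 42))
}
```

Note that both oversampling variants duplicate minority rows, which can make the model overfit those examples; `setWeightCol` avoids this by leaving the data untouched and adjusting the loss instead.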