Handling Data Imbalance in Spark ML Logistic Regression with setWeightCol

Preface

As the title says, this post records how to handle data imbalance in Spark ML logistic regression. Reference: Dealing with unbalanced datasets in Spark MLlib

1. Data imbalance

This refers to the case where the proportions of label == 1 and label == 0 records differ greatly, e.g. 80% vs. 20%. A model trained on such data is biased toward the majority class, so its accuracy is misleading.
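A quick way to see the imbalance is to count records per label value before training (a sketch; `df` is assumed to be your labeled DataFrame):

```scala
// Show how many records fall under each label value
df.groupBy("label").count().show()
// An 80/20 split would show roughly four times as many label = 0 rows as label = 1 rows
```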

2. setWeightCol: main code

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf
import org.apache.spark.ml.classification.LogisticRegression

val labelCol = "label"

def balanceDataset(dataset: DataFrame): DataFrame = {

  // Re-balancing (weighting) of records to be used in the logistic loss objective function.
  // balancingRatio is the fraction of positive (label = 1) records in the dataset.
  val numNegatives = dataset.filter(dataset(labelCol) === 0).count
  val datasetSize = dataset.count
  val balancingRatio = (datasetSize - numNegatives).toDouble / datasetSize

  val calculateWeights = udf { d: Double =>
    if (d == 0.0) {
      balancingRatio        // majority (negative) class gets the smaller weight
    } else {
      1.0 - balancingRatio  // minority (positive) class gets the larger weight
    }
  }

  dataset.withColumn("classWeightCol", calculateWeights(dataset(labelCol)))
}

val df_weighted = balanceDataset(df)

val lr = new LogisticRegression()
  .setLabelCol(labelCol)
  .setWeightCol("classWeightCol")

This solves the data-imbalance problem very conveniently.
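To see why these weights balance the loss, here is the arithmetic worked out in plain Scala for a hypothetical 80%/20% split (the helper `classWeight` mirrors the UDF above; the numbers are illustrative, not from the original post):

```scala
// Mirrors the UDF: balancingRatio is the fraction of positive records
def classWeight(label: Double, balancingRatio: Double): Double =
  if (label == 0.0) balancingRatio else 1.0 - balancingRatio

val datasetSize  = 100000L
val numNegatives = 80000L
val balancingRatio = (datasetSize - numNegatives).toDouble / datasetSize // 0.2

// Each class contributes the same total weight to the objective:
val negativeMass = numNegatives * classWeight(0.0, balancingRatio)                 // 80000 * 0.2
val positiveMass = (datasetSize - numNegatives) * classWeight(1.0, balancingRatio) // 20000 * 0.8
```

Both masses come out equal (16000), so the majority class can no longer dominate the logistic loss.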

3. Other approaches

Before I learned about setWeightCol, I solved the problem with the two methods below; I record them here for reference.

The following assumes there are more label = 0 records than label = 1 records.

/**
 * Randomly undersample the label = 0 records so that the number of
 * label = 0 records roughly matches the number of label = 1 records.
 */
def sample(df: DataFrame): DataFrame = {
  val df0 = df.where(s"$labelCol = 0")
  val df1 = df.where(s"$labelCol = 1")
  val y0 = df0.count()
  val y1 = df1.count()
  val fraction = 1.0 * y1 / y0
  // Fix the class imbalance by randomly sampling the negative (non-default) records
  val df00 = df0.sample(false, fraction)

  df00.union(df1)
}

/**
 * Oversample by duplicating the label = 1 records several times so that
 * the number of label = 1 records roughly matches the number of label = 0 records.
 */
def copy(df: DataFrame): DataFrame = {
  var df_res = df
  val df1 = df.where(s"$labelCol = 1")
  val y0 = df.where(s"$labelCol = 0").count()
  val y1 = df1.count()
  val num = (y0 / y1).toInt - 1 // number of extra copies of the positive class
  for (_ <- 1 to num) {
    df_res = df_res.union(df1)
  }
  df_res
}
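The copy loop can also be expressed with `sample(withReplacement = true, ...)`, which draws positive rows with replacement instead of unioning whole copies (a sketch; the result is only approximately balanced because the sampling is randomized):

```scala
import org.apache.spark.sql.DataFrame

// Oversample the positive class with replacement until it roughly matches the negatives
def oversample(df: DataFrame, labelCol: String = "label"): DataFrame = {
  val df0 = df.where(s"$labelCol = 0")
  val df1 = df.where(s"$labelCol = 1")
  val ratio = df0.count().toDouble / df1.count()
  df0.union(df1.sample(withReplacement = true, fraction = ratio, seed = 42L))
}
```

Note that oversampling (either way) duplicates information rather than adding it, so setWeightCol is usually the cleaner choice.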

Reposted from blog.csdn.net/fly_time2012/article/details/110368578