Spark 2.3 RDD treeAggregate / treeReduce source code analysis

As mentioned earlier, plain reduce/aggregate operations are expensive when there are many partitions, because every partition result is merged on the driver in a single step; treeReduce/treeAggregate let you control the scale of each combine step by adjusting the depth.
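
For a concrete feel for the depth parameter, here is a minimal sketch (the object name, data and partition count are made up for illustration): the same reduction is run with the default depth of 2 and with depth 3, which lowers the fan-in of each combine level when there are many partitions.

import org.apache.spark.sql.SparkSession

object TreeDepthSketch extends App {
  val sc = SparkSession.builder()
    .appName("TreeDepthSketch")
    .master("local[4]")
    .getOrCreate()
    .sparkContext

  // 100 partitions of made-up data.
  val rdd = sc.parallelize(1 to 10000, 100)

  // depth = 2 (the default): at most one executor-side combine level.
  val sum1 = rdd.treeReduce(_ + _)
  // depth = 3: smaller fan-in per level, so more (but cheaper) intermediate levels.
  val sum2 = rdd.treeReduce(_ + _, depth = 3)
  println(s"$sum1 $sum2")
}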

treeReduce source code:

/**
   * Reduces the elements of this RDD in a multi-level tree pattern.
   *
   * @param depth suggested depth of the tree (default: 2)
   * @see [[org.apache.spark.rdd.RDD#reduce]]
   */
  def treeReduce(f: (T, T) => T, depth: Int = 2): T = withScope {
    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
    val cleanF = context.clean(f)
    val reducePartition: Iterator[T] => Option[T] = iter => {
      if (iter.hasNext) {
        Some(iter.reduceLeft(cleanF))
      } else {
        None
      }
    }
    // reduce each partition locally, producing a new RDD with one Option[T] per partition
    val partiallyReduced: RDD[Option[T]] = mapPartitions(it => Iterator(reducePartition(it)))
    // a (Option[T], Option[T]) => Option[T] merge function that skips None (empty partitions)
    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
      if (c.isDefined && x.isDefined) {
        Some(cleanF(c.get, x.get))
      } else if (c.isDefined) {
        c
      } else if (x.isDefined) {
        x
      } else {
        None
      }
    }
    partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
      .getOrElse(throw new UnsupportedOperationException("empty collection"))
  }
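
The partial results are wrapped in Option so that empty partitions do not break the reduction. A small local simulation of reducePartition and op (plain Scala, no Spark; the names mirror the snippet above, the data is made up) shows how an empty partition contributes None and is simply skipped:

object TreeReduceOptionSketch extends App {
  val f: (Int, Int) => Int = _ + _

  // Same shape as reducePartition above: an empty partition yields None.
  val reducePartition: Iterator[Int] => Option[Int] =
    it => if (it.hasNext) Some(it.reduceLeft(f)) else None

  // Same shape as op above: a None on either side is ignored.
  val op: (Option[Int], Option[Int]) => Option[Int] = {
    case (Some(a), Some(b)) => Some(f(a, b))
    case (a, None)          => a
    case (None, b)          => b
  }

  val partitions = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(10))
  val result = partitions.map(p => reducePartition(p.iterator)).reduce(op)
  println(result) // Some(16): the empty "partition" does not disturb the result
}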

treeAggregate source code:

/**
   * Aggregates the elements of this RDD in a multi-level tree pattern.
   * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]].
   *
   * @param depth suggested depth of the tree (default: 2)
   */
  def treeAggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U,
      depth: Int = 2): U = withScope {
    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
    if (partitions.length == 0) {
      Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
    } else {
      val cleanSeqOp = context.clean(seqOp)
      val cleanCombOp = context.clean(combOp)
      val aggregatePartition =
        (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
      var partiallyAggregated: RDD[U] = mapPartitions(it => Iterator(aggregatePartition(it)))
      var numPartitions = partiallyAggregated.partitions.length
      val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
      // If creating an extra level doesn't help reduce
      // the wall-clock time, we stop tree aggregation.

      // Don't trigger TreeAggregation when it doesn't save wall-clock time
      while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
        numPartitions /= scale
        val curNumPartitions = numPartitions
        partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
          (i, iter) => iter.map((i % curNumPartitions, _))
        }.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values
      }
      val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
      partiallyAggregated.fold(copiedZeroValue)(cleanCombOp)
    }
  }
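
To make the while-loop condition concrete, here is a small local sketch of the level-scheduling arithmetic only (plain Scala, no Spark; the partition count 64 and depth 2 are made-up numbers): scale = ceil(64^(1/2)) = 8, so one executor-side foldByKey level shrinks 64 partial results to 8, after which adding another level would no longer save wall-clock time and the driver folds just 8 values.

object TreeAggregateScheduleSketch extends App {
  var numPartitions = 64
  val depth = 2
  val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // 8

  // Same loop condition as in treeAggregate above: keep adding levels only
  // while doing so still shrinks the work left for the driver.
  while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
    numPartitions /= scale
    println(s"extra executor-side level -> $numPartitions partial results remain")
  }
  println(s"driver-side fold over $numPartitions partial results") // 8, not 64
}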

Comparison of treeAggregate and aggregate:

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object TreeAggregateTest extends App {

  val sparkConf = new SparkConf()
    .setAppName("TreeAggregateTest")
    .setMaster("local[6]")

  val spark = SparkSession
    .builder()
    .config(sparkConf)
    .getOrCreate()

  val value: RDD[Int] = spark.sparkContext.parallelize(List(1, 2, 3, 5, 8, 9), 3)

  val treeAggregateResult: Int = value.treeAggregate(4)((a, b) => {
    math.min(a, b)
  }, (a, b) => {
    println(a + "+" + b)
    a + b
  }, 2)
  println("treeAggregateResult:" + treeAggregateResult)

  val aggregateResult: Int = value.aggregate(4)((a, b) => {
    math.min(a, b)
  }, (a, b) => {
    println(a + "+" + b)
    a + b
  })
  println("aggregateResult:" + aggregateResult)


}

In the combOp phase, when the results of the individual partitions are merged, aggregate applies the initial value one extra time, whereas treeAggregate does not; for both operators, seqOp is called with the initial value on every partition.
My understanding is that aggregate ships all of the partially aggregated partition results to the driver side and runs the whole combOp phase there. When the number of partitions is large, treeAggregate instead first repartitions the partial results by i % curNumPartitions and performs combOp several times on the executors (the number of levels governed by depth), and only then returns the results to the driver side for the final combOp. This relieves pressure on the driver and reduces the risk of OOM.
In actual use, treeAggregate is therefore the more flexible choice. Compared with the treeAggregate of earlier releases, the Spark 2.3 version shown above uses foldByKey to merge the partition results at each intermediate level.
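
Note that with the 3-partition example above the tree never actually grows: scale = max(ceil(3^(1/2)), 2) = 2, and 3 > 2 + ceil(3/2) = 4 is false, so treeAggregate adds no executor-side level there. A hypothetical variation with many more partitions (names and numbers are mine, not from the original example) is needed to actually see the intermediate foldByKey level:

import org.apache.spark.sql.SparkSession

object TreeAggregateManyPartitions extends App {
  val spark = SparkSession.builder()
    .appName("TreeAggregateManyPartitions")
    .master("local[6]")
    .getOrCreate()

  // 100 partitions, depth = 2 -> scale = 10: one foldByKey level turns
  // 100 partial sums into 10, and the driver then folds only those 10 values.
  val rdd = spark.sparkContext.parallelize(1 to 10000, 100)
  val treeSum = rdd.treeAggregate(0)(_ + _, _ + _, depth = 2)
  println("treeSum:" + treeSum)
}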
