SparkRDD之aggregate

Spark 文档中对 aggregate的函数定义如下:

def aggregate[U](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) 
=> U)(implicit arg0: ClassTag[U]): U

注释:

Aggregate the elements of each partition, and then the results for 
all the partitions, using given combine functions and a neutral 
"zero value". 
This function can return a different result type, U, 
than the type of this RDD, T. 
Thus, we need one operation for merging a T into an U 
and one operation for merging two U's, as in 
Scala.TraversableOnce. Both of these functions are allowed to 
modify and return their first argument instead of creating a new U 
to avoid memory allocation. 

aggregate函数首先对每个分区里面的元素进行聚合,然后用combine函数将每个分区的结果和初始值(zeroValue)进行combine操作。这个操作返回的类型不需要和RDD中元素类型一致,所以在使用 aggregate()时,需要提供我们期待的返回类型的初始值,然后通过一个函数把RDD中的元素累加起来??放入累加器?。考虑到每个节点是在本地进行累加的,最终还需要提供第二个函数来将累加器两两合并。

其中seqOp操作会聚合各分区中的元素,然后combOp操作会把所有分区的聚合结果再次聚合,两个操作的初始值都是zeroValueseqOp的操作是遍历分区中的所有元素(T),第一个T跟zeroValue做操作,结果再作为与第二个T做操作的zeroValue,直到遍历完整个分区。combOp操作是把各分区聚合的结果,再聚合。aggregate函数返回一个跟RDD不同类型的值。因此,需要一个操作seqOp来把分区中的元素T合并成一个U,另外一个操作combOp把所有U聚合。

aggregate聚合函数允许用户将两个不同的reduce函数应用于RDD。 在每个分区中应用第一个reduce函数,以将每个分区中的数据减少为单个结果。 第二减少函数用于将所有分区的不同减少的结果组合在一起以得到一个最终结果。 对于内部分区和跨分区减少具有两个单独的reduce函数的能力增加了很多灵活性。 例如,第一个reduce函数可以是max函数,第二个函数可以是sum函数。 用户还指定初始值。 这是一些重要的事实。

  • 初始值作用于两次reduce函数,在每个paritition内部reduce和跨partitionreduce
  • 两个reduce函数都必须是可交换的和关联的。
  • 不要假设分区计算或组合分区的任何执行顺序。

java示例如下:

package com.cb.spark.sparkrdd;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;

public class AggregateExample {
	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("Aggregate").setMaster("local");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		JavaRDD<Integer> rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 3);

		// 先max(),在sum()
		Integer aggregateValue = rdd.aggregate(5, new Function2<Integer, Integer, Integer>() {

			private static final long serialVersionUID = 1L;

			@Override
			public Integer call(Integer arg0, Integer arg1) throws Exception {
				return Math.max(arg0, arg1); 
			}
		}, new Function2<Integer, Integer, Integer>() {
			private static final long serialVersionUID = 1L;

			@Override
			public Integer call(Integer arg0, Integer arg1) throws Exception {
				return arg0 + arg1;
			}
		});
		System.out.println(aggregateValue);//21

		// 先sum(),在max()
		Integer aggregateValue1 = rdd.aggregate(5, (x1, x2) -> x1 + x2, (x1, x2) -> Math.max(x1, x2));
		System.out.println(aggregateValue1);//16
		jsc.stop();
	}
}

 scala示例代码:

package com.cb.spark.core

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import scala.collection.Iterator

object Aggregate {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("Aggregate")
      .setMaster("local")
    val sc = new SparkContext(conf)
    val z = sc.parallelize(List("a", "b", "c", "d", "e", "f"), 2)

    def myfunc(index: Int, iter: Iterator[(String)]): Iterator[String] = {
      iter.map(x => "[partID:" + index + ",val:" + x + "]")
    }

    z.mapPartitionsWithIndex(myfunc, true).foreach(println)

    z.aggregate("m")(_ + _, _ + _).foreach(print) //mmabcmdef

    val z1 = sc.parallelize(List("12", "23", "345", "4567"), 2)
    var result = z1.aggregate("12+")((x, y) => math.max(x.length(), y.length()).toString(), (x, y) => x + y)
    println(result) //"24"

    //第一个partition:
    //1.math.min("12+".length,"12".length()).toString()="2"
    //2.math.min("2".length(),"23".length()).toString()="1"
    //第二个partition:
    //3.math.min("12+".length,"345".length()).toString()="3"
    //4.math.min("3".length,"4567".length()).toString()="1"
    result = z1.aggregate("12+")((x, y) => math.min(x.length, y.length).toString, (x, y) => x + y)
    println(result) //"11"

    val z2 = sc.parallelize(List("12", "23", "345", ""), 2)
    result = z2.aggregate("12+")((x, y) => math.min(x.length(), y.length()).toString, (x, y) => x + y)
    println(result) //10
    sc.stop()
  }
}

猜你喜欢

转载自blog.csdn.net/u013230189/article/details/81629241