Spark source code analysis: partitioning when reading local collections

1. The default number of partitions of an RDD created from a collection

// Source code analysis: the default number of partitions of an RDD created from a collection
val rdd = sc.parallelize(list)

// 1. Look at the source of parallelize: it takes two parameters, the collection and the number of slices, and returns an RDD
 def parallelize[T: ClassTag](
      seq: Seq[T],
      numSlices: Int = defaultParallelism): RDD[T]   // the number of slices defaults to defaultParallelism
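
// As a quick check of this default, a minimal sketch (assuming a SparkContext started with master local[4]; the app name and variable names are illustrative, not from the original post):
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("slices-demo").setMaster("local[4]"))
val rdd = sc.parallelize(List(1, 2, 3, 4, 5))   // numSlices not given, so defaultParallelism is used
println(sc.defaultParallelism)                  // expected: 4
println(rdd.getNumPartitions)                   // expected: 4 as well
sc.stop()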
	  
// 2. Look at the source of defaultParallelism: it delegates to defaultParallelism on the taskScheduler
  def defaultParallelism: Int = {
    assertNotStopped()
    taskScheduler.defaultParallelism
  }

// 3. Go into TaskScheduler and look at defaultParallelism: it is only declared in a trait, so look at the single implementation class, TaskSchedulerImpl
override def defaultParallelism(): Int = backend.defaultParallelism()

// 4. This in turn calls defaultParallelism on the backend, which is again only a trait method, so we need to look at its implementation classes
// There are two implementations: 1. LocalSchedulerBackend (local mode)
//			     2. CoarseGrainedSchedulerBackend (cluster mode)
def defaultParallelism(): Int

// 5. Look at the local-mode implementation of defaultParallelism
override def defaultParallelism(): Int =
    scheduler.conf.getInt("spark.default.parallelism", totalCores)
	
	// Going into getInt: the value is wrapped as an Option. If "spark.default.parallelism" is set, the configured value is used; if getOption returns None, the default value defaultValue is used, i.e. totalCores from step 5
	def getInt(key: String, defaultValue: Int): Int = catchIllegalValue(key) {
    getOption(key).map(_.toInt).getOrElse(defaultValue)
	}
	
		// Going into map confirms it: if the Option is empty it returns None, otherwise it wraps the result in Some
		final def map[B](f: A => B): Option[B] =
				if (isEmpty) None else Some(f(this.get))
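
	// Putting step 5 and the getOption(...).map(_.toInt).getOrElse(defaultValue) chain together, a hedged sketch
	// (assuming master local[4]): when spark.default.parallelism is set, getOption returns Some and the configured
	// value wins; when it is not set, getOption returns None and totalCores (4 here) is used.
	import org.apache.spark.{SparkConf, SparkContext}

	val conf = new SparkConf()
	  .setAppName("parallelism-demo")
	  .setMaster("local[4]")
	  .set("spark.default.parallelism", "2")

	val sc = new SparkContext(conf)
	println(sc.defaultParallelism)                      // 2, not 4
	println(sc.parallelize(1 to 10).getNumPartitions)   // 2
	sc.stop()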
	
// 6. totalCores is defined in the constructor of LocalSchedulerBackend, so check what is passed in when new LocalSchedulerBackend is created
private[spark] class LocalSchedulerBackend(
    conf: SparkConf,
    scheduler: TaskSchedulerImpl,
    val totalCores: Int)
	
	
// 7. In the SparkContext source, new LocalSchedulerBackend is created in the createTaskScheduler method, inside a pattern match on the master string

	// Case 1: the master string is "local"
	case "local" =>
        checkResourcesPerTask(clusterMode = false, Some(1))
        val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
        val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1)
        scheduler.initialize(backend)
        (backend, scheduler)
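	// So with a plain "local" master, totalCores is hard-coded to 1. A short sketch of the expected effect
	// (assumption: spark.default.parallelism is not set):
	val sc = new org.apache.spark.SparkContext(
	  new org.apache.spark.SparkConf().setAppName("local-demo").setMaster("local"))
	println(sc.parallelize(Seq(1, 2, 3)).getNumPartitions)   // expected: 1
	sc.stop()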

// 8. If the master is local[N], a regex match is used. The source shows that threadCount is what new LocalSchedulerBackend receives as totalCores:
	// if "*" was given, the total number of CPU cores (localCpuCount) is used,
	// otherwise the given N.toInt is used

	   case LOCAL_N_REGEX(threads) =>
        def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
        // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
        val threadCount = if (threads == "*") localCpuCount else threads.toInt
        if (threadCount <= 0) {
          throw new SparkException(s"Asked to run locally with $threadCount threads")
        }
        checkResourcesPerTask(clusterMode = false, Some(threadCount))
        val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
        val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
        scheduler.initialize(backend)
        (backend, scheduler)
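
	   // The threadCount decision can be reproduced with a tiny stand-alone helper
	   // (a hypothetical mirror of the logic above, not Spark code):
	   def localThreadCount(threads: String): Int = {
	     val localCpuCount = Runtime.getRuntime.availableProcessors()
	     if (threads == "*") localCpuCount else threads.toInt
	   }

	   println(localThreadCount("*"))   // the number of cores on this machine
	   println(localThreadCount("3"))   // 3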
	   
		// Go into LOCAL_N_REGEX, the regular expression that matches the parameter N
			private object SparkMasterRegex {
			  // Regular expression used for local[N] and local[*] master formats
			  val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
			  // Regular expression for local[N, maxRetries], used in tests with failing tasks
			  val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
			  // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
			  val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
			  // Regular expression for connecting to Spark deploy clusters
			  val SPARK_REGEX = """spark://(.*)""".r
			}
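
		// The same regex can be tried outside Spark to see what the match extracts
		// (a sketch that reuses the identical pattern string):
		val LocalN = """local\[([0-9]+|\*)\]""".r

		"local[4]" match { case LocalN(threads) => println(threads) }   // prints 4
		"local[*]" match { case LocalN(threads) => println(threads) }   // prints *
		"local" match {
		  case LocalN(_) => println("matched")
		  case _         => println("no match, plain \"local\" is handled by the earlier case")
		}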
	

2. How the data is partitioned

// Source code analysis: how the RDD actually partitions the data

// 1. Go into the parallelize source again
val rdd = sc.parallelize(list)

// 2. The parallelize method creates a ParallelCollectionRDD object, so go into ParallelCollectionRDD
def parallelize[T: ClassTag](
      seq: Seq[T],
      numSlices: Int = defaultParallelism): RDD[T] = withScope {
    assertNotStopped()
    new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
  }
  
  
// 3. ParallelCollectionRDD has a getPartitions method,
//    which calls slice (slice creates sub-collections in Scala, so look at the slice method)
override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }
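
// From the application side, the result of getPartitions/slice can be inspected with glom,
// which collects each partition into an array (assumption: sc is an existing local SparkContext):
val rdd = sc.parallelize(List(0, 1, 2, 3, 4), numSlices = 3)
rdd.glom().collect().foreach(a => println(a.mkString("[", ",", "]")))
// expected, given the positions formula analyzed below:
// [0]
// [1,2]
// [3,4]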
  

// 4. Analyze this slice method in detail

/**
  * Focus on the pattern match:
  *  1. the case where seq is a Range,
  *  2. the NumericRange case (a List is not a subtype of this type either),
  *  3. everything else goes to case _
  */

  def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    // if numSlices is less than 1, throw an exception
    if (numSlices < 1) {
      throw new IllegalArgumentException("Positive number of partitions required")
    }
    // Sequences need to be sliced at the same set of index positions for operations
    // like RDD.zip() to behave as expected
	
	/*
		positions produces an iterator of (start, end) tuples.
		For example, if the list is (0 1 2 3 4) and it is cut into 3 slices:
				0 -- 1
				1 -- 3
				3 -- 5
	*/
    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
      (0 until numSlices).iterator.map { i =>
        val start = ((i * length) / numSlices).toInt
        val end = (((i + 1) * length) / numSlices).toInt
        (start, end)
      }
    }
    seq match {
      case r: Range =>
        positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
          // If the range is inclusive, use inclusive range for the last slice
          if (r.isInclusive && index == numSlices - 1) {
            new Range.Inclusive(r.start + start * r.step, r.end, r.step)
          }
          else {
            new Range(r.start + start * r.step, r.start + end * r.step, r.step)
          }
        }.toSeq.asInstanceOf[Seq[Seq[T]]]
      case nr: NumericRange[_] =>
        // For ranges of Long, Double, BigInteger, etc
        val slices = new ArrayBuffer[Seq[T]](numSlices)
        var r = nr
        for ((start, end) <- positions(nr.length, numSlices)) {
          val sliceSize = end - start
          slices += r.take(sliceSize).asInstanceOf[Seq[T]]
          r = r.drop(sliceSize)
        }
        slices
		
      // first convert the sequence to an array, then call positions with the array length and the number of slices
      case _ =>
        val array = seq.toArray // To prevent O(n^2) operations for List etc
        positions(array.length, numSlices).map { case (start, end) =>
          array.slice(start, end).toSeq
        }.toSeq    // then convert the arrays back into sequences and return them
    }
  }
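
// For the case _ branch, the positions helper can be copied into a REPL to confirm the
// (start, end) pairs from the comment above (a stand-alone sketch, not Spark code):
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
  (0 until numSlices).iterator.map { i =>
    val start = ((i * length) / numSlices).toInt
    val end = (((i + 1) * length) / numSlices).toInt
    (start, end)
  }

val array = Array(0, 1, 2, 3, 4)
positions(array.length, 3).foreach { case (start, end) =>
  println(s"$start -- $end : " + array.slice(start, end).mkString(","))
}
// 0 -- 1 : 0
// 1 -- 3 : 1,2
// 3 -- 5 : 3,4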
  
// 5. After slice returns, execution continues with the rest of the method from step 3
  slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray 
  
		// indices is defined as the range of partition indices, i.e. [0, 1, 2, ..., length - 1]
			def indices: Range = 0 until length
		// .map creates a ParallelCollectionPartition for each index, wrapping the corresponding slice of data in it
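
		// A simplified stand-in (hypothetical DemoPartition class, not the real ParallelCollectionPartition)
		// shows how indices and the slices line up:
		case class DemoPartition(rddId: Int, index: Int, values: Seq[Int])

		val slices = Seq(Seq(0), Seq(1, 2), Seq(3, 4))   // what slice(List(0,1,2,3,4), 3) produces
		val partitions = slices.indices.map(i => DemoPartition(rddId = 0, index = i, values = slices(i)))
		partitions.foreach(println)
		// DemoPartition(0,0,List(0))
		// DemoPartition(0,1,List(1, 2))
		// DemoPartition(0,2,List(3, 4))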
			 


Origin blog.csdn.net/qq_38705144/article/details/113182486