spark源码分析之dagscheduler原理篇

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/weixin_39478115/article/details/79328891

这里写图片描述
一个action操作触发runjob方法,然后一步一步的调用runjob,一直调用到dagScheduler的runjob方法
解释:
1、首先会创建一个HadoopRDD,然后将HadoopRDD转换成一个MappedRDD
2、创建一个FlatMappedRDD
3、创建一个MappedRDD
4、reduceByKey这个算子,首先会产生MapPartitionsRDD,然后是ShuffledRDD,再然后是MapPartitionsRDD

  • 第一个MapPartitionsRDD是本地数据聚合的rdd,也就是代表本地文件的rdd,使用HashPartitioner,对key进行部分整合,保存到多个partition中,也就是对应的不同文件中
  • ShuffledRDD将MapPartitionsRDD中相同的key保存到同一个partition中
  • MapPartitionsRDD对key进行聚合操作

5、执行到foreach这个action操作的时候,就会通过SparkContext的runJob()去触发job(DAGScheduler)

总结:
DAGScheduler划分stage的算法,会从触发action操作的最后一个rdd向前倒推:首先为最后一个rdd创建一个stage;倒推过程中,如果发现某个rdd存在宽依赖,就为被宽依赖的那个rdd创建一个新的stage,那个rdd就是新stage的最后一个rdd;依次类推,直到所有rdd都遍历完成。

源码分析:
第一步:点击runjob方法
源码位置:org/apache/spark/SparkContext.scala

  /**
   * Runs a job over every partition of `rdd` and returns the results in an array.
   *
   * Delegates to the overload that takes an explicit partition list, passing the
   * full index range `0 until rdd.partitions.length`.
   *
   * @param rdd  the RDD whose partitions are processed
   * @param func function applied to the element iterator of each partition
   */
  def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
    val everyPartition = 0 until rdd.partitions.length
    runJob(rdd, func, everyPartition)
  }

第二步:点击第一步runjob方法
源码位置:org/apache/spark/SparkContext.scala

  /**
   * Runs a job on the given partitions of `rdd`, accepting a function of type
   * `Iterator[T] => U` rather than `(TaskContext, Iterator[T]) => U`.
   *
   * The function is first passed through `clean`, then wrapped in a two-argument
   * function that ignores the task context, and finally handed to the
   * `(TaskContext, Iterator[T]) => U` overload.
   */
  def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: Iterator[T] => U,
      partitions: Seq[Int]): Array[U] = {
    val cleaned = clean(func)
    // Adapter: the task context is accepted but not used.
    val withContext = (ctx: TaskContext, it: Iterator[T]) => cleaned(it)
    runJob(rdd, withContext, partitions)
  }

第三步:点击第二步runjob方法
源码位置:org/apache/spark/SparkContext.scala

  /**
   * Runs `func` on the given partitions of `rdd` and returns the results as an array.
   *
   * An array with one slot per requested partition is allocated up front; the
   * handler passed to the four-argument overload stores each result at the index
   * it is called with.
   */
  def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int]): Array[U] = {
    val collected = new Array[U](partitions.size)
    val storeAt: (Int, U) => Unit = (slot, value) => collected(slot) = value
    runJob[T, U](rdd, func, partitions, storeAt)
    collected
  }

第四步:点击第三步runjob方法
源码位置:org/apache/spark/SparkContext.scala

  /**
   * Runs `func` on the given partitions of `rdd` and passes each result to
   * `resultHandler`. This is the main entry point for all actions in Spark.
   *
   * Throws IllegalStateException when the context has already been stopped.
   * After the DAGScheduler call returns, the progress bar (if any) is finished
   * and `rdd.doCheckpoint()` is invoked.
   */
  def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      resultHandler: (Int, U) => Unit): Unit = {
    if (stopped.get()) {
      throw new IllegalStateException("SparkContext has been shutdown")
    }
    val site = getCallSite
    val closure = clean(func)
    logInfo(s"Starting job: ${site.shortForm}")
    val lineageLoggingEnabled = conf.getBoolean("spark.logLineage", false)
    if (lineageLoggingEnabled) {
      logInfo(s"RDD's recursive dependencies:\n${rdd.toDebugString}")
    }

    // Delegate the actual scheduling of the job to the DAGScheduler.
    dagScheduler.runJob(rdd, closure, partitions, site, resultHandler, localProperties.get)

    progressBar.foreach(_.finishAll())
    rdd.doCheckpoint()
  }

第五步:点击第四步runjob方法
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Submits a job through `submitJob`, waits for its result and logs the outcome.
   *
   * On success the elapsed time is logged. On failure the elapsed time is logged,
   * the caller's stack trace is appended to the exception's own trace, and the
   * exception is rethrown.
   */
  def runJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      resultHandler: (Int, U) => Unit,
      properties: Properties): Unit = {
    val startNanos = System.nanoTime
    // Evaluated at the moment a log message is built, like the original inline expression.
    def elapsedSeconds: Double = (System.nanoTime - startNanos) / 1e9

    // Submit the job and obtain the waiter used to observe its completion.
    val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)

    // Inspect the outcome of the job.
    waiter.awaitResult() match {
      case JobSucceeded =>
        logInfo("Job %d finished: %s, took %f s".format(
          waiter.jobId, callSite.shortForm, elapsedSeconds))
      case JobFailed(exception: Exception) =>
        logInfo("Job %d failed: %s, took %f s".format(
          waiter.jobId, callSite.shortForm, elapsedSeconds))
        // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
        val userTrace = Thread.currentThread().getStackTrace.tail
        exception.setStackTrace(exception.getStackTrace ++ userTrace)
        throw exception
    }
  }

第六步:点击第五步submitJob方法
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Validates the requested partitions, allocates a job id and posts a
   * `JobSubmitted` event to the event loop.
   *
   * Throws IllegalArgumentException if any requested partition index is negative
   * or not smaller than the number of partitions of `rdd`. When `partitions` is
   * empty, a waiter for zero tasks is returned and no event is posted.
   *
   * @return the JobWaiter through which the caller observes the job
   */
  def submitJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      resultHandler: (Int, U) => Unit,
      properties: Properties): JobWaiter[U] = {
    // Check to make sure we are not launching a task on a partition that does not exist.
    val maxPartitions = rdd.partitions.length
    for (bad <- partitions.find(p => p < 0 || p >= maxPartitions)) {
      throw new IllegalArgumentException(
        s"Attempting to access a non-existent partition: $bad. " +
          s"Total number of partitions: $maxPartitions")
    }

    // The id is consumed even when the job turns out to have nothing to run.
    val jobId = nextJobId.getAndIncrement()
    if (partitions.isEmpty) {
      // Return immediately if the job is running 0 tasks
      new JobWaiter[U](this, jobId, 0, resultHandler)
    } else {
      val erasedFunc = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
      val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)

      // Post the event; the event loop's doOnReceive handles JobSubmitted.
      eventProcessLoop.post(JobSubmitted(
        jobId, rdd, erasedFunc, partitions.toArray, callSite, waiter,
        SerializationUtils.clone(properties)))
      waiter
    }
  }

第七步:点击第六步eventProcessLoop方法
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  // Event loop bound to this scheduler instance; JobSubmitted events are posted to it.
  private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)

第八步:点击第七步DAGSchedulerEventProcessLoop
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

// Dispatches a scheduler event by pattern matching on its type.
// (Excerpt: only the JobSubmitted case is shown.)
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {

      // A JobSubmitted event is forwarded to DAGScheduler.handleJobSubmitted.
    case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

第九步:点击第八步handleJobSubmitted
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Core entry point of job scheduling inside the DAGScheduler.
   *
   * Steps performed:
   *  1. Build the final ResultStage from the RDD that triggered the job. If that
   *     throws, the listener is told the job failed and nothing else happens.
   *  2. Create an ActiveJob around the final stage.
   *  3. Record the job in the scheduler's bookkeeping and post a job-start event.
   *  4. Submit the final stage, then submit waiting stages.
   *
   * Author's note (from the original article): understanding how an application is
   * split into jobs and stages, and which code belongs to which stage, is what makes
   * it possible to locate a slow or failing stage and tune or debug it.
   *
   * Stage-splitting summary: work backwards from the final stage; a wide (shuffle)
   * dependency marks a stage boundary; parent stages are submitted first, recursively.
   */
  private[scheduler] def handleJobSubmitted(jobId: Int,
      finalRDD: RDD[_],
      func: (TaskContext, Iterator[_]) => _,
      partitions: Array[Int],
      callSite: CallSite,
      listener: JobListener,
      properties: Properties): Unit = {

    // Step 1: create the final stage from the RDD that triggered the job.
    // New stage creation may throw an exception if, for example, jobs are run on a
    // HadoopRDD whose underlying HDFS files have been deleted.
    val finalStage: ResultStage =
      try {
        newResultStage(finalRDD, func, partitions, jobId, callSite)
      } catch {
        case e: Exception =>
          logWarning(s"Creating new stage failed due to exception - job: $jobId", e)
          listener.jobFailed(e)
          return
      }

    // Step 2: create the job whose last stage is finalStage.
    val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
    clearCacheLocs()
    logInfo("Got job %s (%s) with %d output partitions".format(
      job.jobId, callSite.shortForm, partitions.length))
    logInfo(s"Final stage: $finalStage (${finalStage.name})")
    logInfo(s"Parents of final stage: ${finalStage.parents}")
    logInfo(s"Missing parents: ${getMissingParentStages(finalStage)}")

    val jobSubmissionTime = clock.getTimeMillis()

    // Step 3: register the job in the in-memory bookkeeping structures.
    jobIdToActiveJob(jobId) = job
    activeJobs += job
    finalStage.setActiveJob(job)
    val stageInfos = jobIdToStageIds(jobId).toArray
      .flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
    listenerBus.post(
      SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))

    // Step 4: submit the final stage. submitStage recursively submits missing
    // parent stages first and parks stages that still have missing parents in
    // waitingStages.
    submitStage(finalStage)

    // Submit the stages that are waiting.
    submitWaitingStages()
  }

第十步:点击第九步newResultStage
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Creates a ResultStage associated with the provided jobId.
   *
   * The new stage is stored in `stageIdToStage` under its id and registered via
   * `updateJobIdStageIdMaps` before being returned.
   */
  private def newResultStage(
      rdd: RDD[_],
      func: (TaskContext, Iterator[_]) => _,
      partitions: Array[Int],
      jobId: Int,
      callSite: CallSite): ResultStage = {
    val (parents: List[Stage], stageId: Int) = getParentStagesAndId(rdd, jobId)
    // Build the ResultStage itself.
    val resultStage =
      new ResultStage(stageId, rdd, func, partitions, parents, jobId, callSite)
    stageIdToStage(stageId) = resultStage
    updateJobIdStageIdMaps(jobId, resultStage)
    resultStage
  }

第十一步:点击第十步ResultStage
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

/**
 * A Stage subclass that additionally exposes `func` and `partitions` as public vals.
 * `partitions.length` is what gets passed to the Stage constructor as its third
 * argument. (Excerpt: the class body is elided in this listing.)
 */
private[spark] class ResultStage(
    id: Int,
    rdd: RDD[_],
    val func: (TaskContext, Iterator[_]) => _,
    val partitions: Array[Int],
    parents: List[Stage],
    firstJobId: Int,
    callSite: CallSite)
  extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { 
  ......
}

第十二步:点击第九步submitStage
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Submits `stage`, but first recursively submits any missing parents.
   *
   * This is the entry point of the stage-splitting walk, which is carried out
   * jointly by this method and getMissingParentStages(): the recursion keeps
   * descending until it reaches a stage with no missing parents, whose tasks are
   * submitted, while every stage visited on the way is placed in `waitingStages`.
   *
   * A stage that is already waiting, running or failed is left untouched. A stage
   * with no active job is aborted.
   */
  private def submitStage(stage: Stage): Unit = {
    activeJobForStage(stage) match {
      case Some(jobId) =>
        logDebug(s"submitStage($stage)")
        val alreadyTracked =
          waitingStages(stage) || runningStages(stage) || failedStages(stage)
        if (!alreadyTracked) {
          // Parent stages of this stage that are still missing, in ascending id order.
          val missing = getMissingParentStages(stage).sortBy(_.id)
          logDebug(s"missing: $missing")

          if (missing.nonEmpty) {
            // Recurse into the parents first; this recursion drives the whole walk.
            missing.foreach(submitStage)
            // Then park the current stage until its parents are done.
            waitingStages += stage
          } else {
            logInfo(s"Submitting $stage (${stage.rdd}), which has no missing parents")
            // No missing parents: create and submit the tasks of this stage.
            submitMissingTasks(stage, jobId)
          }
        }
      case None =>
        abortStage(stage, s"No active job for stage ${stage.id}", None)
    }
  }

第十三步:点击第十二步getMissingParentStages
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Collects the parent stages of `stage` that are not yet available.
   *
   * Starting from the stage's RDD, the dependency graph is walked with an explicit
   * stack. Narrow dependencies are followed further back; each shuffle (wide)
   * dependency is resolved to its shuffle map stage via getShuffleMapStage, and
   * that stage is reported as missing when it is not available. Dependencies of an
   * RDD are only examined when `getCacheLocs(rdd)` contains an empty entry.
   */
  private def getMissingParentStages(stage: Stage): List[Stage] = {
    val missing = new HashSet[Stage]
    val seen = new HashSet[RDD[_]]
    // We are manually maintaining a stack here to prevent StackOverflowError
    // caused by recursively visiting
    val pending = new Stack[RDD[_]]

    def visit(rdd: RDD[_]): Unit = {
      // `add` returns false for an RDD that was visited before; the cache check is
      // only reached for RDDs seen for the first time.
      if (seen.add(rdd) && getCacheLocs(rdd).contains(Nil)) {
        rdd.dependencies.foreach {
          // Wide dependency: look up (or create) the map stage on the other side.
          case shufDep: ShuffleDependency[_, _, _] =>
            val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
            if (!mapStage.isAvailable) {
              missing += mapStage
            }

          // Narrow dependency: keep walking through the parent RDD.
          case narrowDep: NarrowDependency[_] =>
            pending.push(narrowDep.rdd)
        }
      }
    }

    // Seed the walk with the RDD of the stage itself.
    pending.push(stage.rdd)
    while (pending.nonEmpty) {
      visit(pending.pop())
    }
    missing.toList
  }

第十四步:点击第十三步getShuffleMapStage
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Gets or creates a shuffle map stage for the given shuffle dependency's map side.
   *
   * When no stage is registered yet for the dependency's shuffle id, stages for the
   * ancestor shuffle dependencies are registered first, and then the stage for
   * `shuffleDep` itself is created, registered and returned.
   */
  private def getShuffleMapStage(
      shuffleDep: ShuffleDependency[_, _, _],
      firstJobId: Int): ShuffleMapStage = {
    shuffleToMapStage.get(shuffleDep.shuffleId).getOrElse {
      // We are going to register ancestor shuffle dependencies
      getAncestorShuffleDependencies(shuffleDep.rdd).foreach { ancestor =>
        shuffleToMapStage(ancestor.shuffleId) = newOrUsedShuffleStage(ancestor, firstJobId)
      }

      // Then register current shuffleDep
      val created = newOrUsedShuffleStage(shuffleDep, firstJobId)
      shuffleToMapStage(shuffleDep.shuffleId) = created
      created
    }
  }

第十五步:点击第十四步newOrUsedShuffleStage
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Creates a shuffle map stage for `shuffleDep` and brings it in sync with the
   * map output tracker.
   *
   * If the tracker already knows the shuffle, the recorded (non-null) output
   * locations are copied into the new stage. Otherwise the shuffle is registered
   * with the tracker.
   */
  private def newOrUsedShuffleStage(
      shuffleDep: ShuffleDependency[_, _, _],
      firstJobId: Int): ShuffleMapStage = {
    val rdd = shuffleDep.rdd
    val numTasks = rdd.partitions.length
    val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite)
    val shuffleId = shuffleDep.shuffleId
    if (!mapOutputTracker.containsShuffle(shuffleId)) {
      // Kind of ugly: need to register RDDs with the cache and map output tracker here
      // since we can't do it in the RDD constructor because # of partitions is unknown
      logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite})")
      mapOutputTracker.registerShuffle(shuffleId, rdd.partitions.length)
    } else {
      val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleId)
      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
      // locs(i) will be null if missing
      for (i <- locs.indices if locs(i) ne null) {
        stage.addOutputLoc(i, locs(i))
      }
    }
    stage
  }

第十六步:点击第十五步newShuffleMapStage
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Creates a ShuffleMapStage for `rdd`, stores it in `stageIdToStage` under its
   * id and registers it via `updateJobIdStageIdMaps`.
   *
   * @param shuffleDep the shuffle dependency handed to the stage at construction time
   */
  private def newShuffleMapStage(
      rdd: RDD[_],
      numTasks: Int,
      shuffleDep: ShuffleDependency[_, _, _],
      firstJobId: Int,
      callSite: CallSite): ShuffleMapStage = {
    val (parents: List[Stage], stageId: Int) = getParentStagesAndId(rdd, firstJobId)
    val mapStage: ShuffleMapStage = new ShuffleMapStage(
      stageId, rdd, numTasks, parents, firstJobId, callSite, shuffleDep)

    stageIdToStage(stageId) = mapStage
    updateJobIdStageIdMaps(firstJobId, mapStage)
    mapStage
  }

第十七步:点击第十二步submitMissingTasks
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

/** Called when stage's parents are available and we can now do its task. */
  /**
   * Submits a stage by creating a batch of tasks for it, one task for each
   * partition that still has to be computed, and handing them to the task
   * scheduler as a TaskSet. Any failure while computing locations, serializing
   * the task binary or creating the tasks aborts the stage.
   */
  private def submitMissingTasks(stage: Stage, jobId: Int) {
    logDebug("submitMissingTasks(" + stage + ")")
    // Get our pending tasks and remember them in our pendingTasks entry
    stage.pendingPartitions.clear()

    // First figure out the indexes of partition ids to compute.
    // (These are the partitions the stage reports as missing.)
    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()

    // Create internal accumulators if the stage has no accumulators initialized.
    // Reset internal accumulators only if this stage is not partially submitted
    // Otherwise, we may override existing accumulator values from some tasks
    if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
      stage.resetInternalAccumulators()
    }

    // Use the scheduling pool, job group, description, etc. from an ActiveJob associated
    // with this Stage
    val properties = jobIdToActiveJob(jobId).properties

    // Record the stage as running.
    runningStages += stage
    // SparkListenerStageSubmitted should be posted before testing whether tasks are
    // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
    // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
    // event.
    stage match {
      case s: ShuffleMapStage =>
        outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
      case s: ResultStage =>
        outputCommitCoordinator.stageStart(
          stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
    }
    // Preferred locations for every partition id that is about to be computed.
    val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        case s: ResultStage =>
          // NOTE(review): `job` is not referenced below; the `.get` still throws when
          // no active job is set - confirm whether that check is intentional.
          val job = s.activeJob.get
          partitionsToCompute.map { id =>
            // For a result stage, `id` indexes into the job's partition list.
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch {
      case NonFatal(e) =>
        // Post the stage-submitted event before aborting the stage.
        stage.makeNewStageAttempt(partitionsToCompute.size)
        listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }

    stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
    listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

    // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
    // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
    // the serialized copy of the RDD and for each task we will deserialize it, which means each
    // task gets a different copy of the RDD. This provides stronger isolation between tasks that
    // might modify state of objects referenced in their closures. This is necessary in Hadoop
    // where the JobConf/Configuration object is not thread-safe.
    var taskBinary: Broadcast[Array[Byte]] = null
    try {
      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
      // For ResultTask, serialize and broadcast (rdd, func).
      val taskBinaryBytes: Array[Byte] = stage match {
        case stage: ShuffleMapStage =>
          closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
        case stage: ResultStage =>
          closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array()
      }

      taskBinary = sc.broadcast(taskBinaryBytes)
    } catch {
      // In the case of a failure during serialization, abort the stage.
      case e: NotSerializableException =>
        abortStage(stage, "Task not serializable: " + e.toString, Some(e))
        runningStages -= stage

        // Abort execution
        return
      case NonFatal(e) =>
        abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }

    // Create the tasks of this stage, one per partition to compute.
    // Each task is given the preferred locations computed above.
    val tasks: Seq[Task[_]] = try {
      stage match {
        case stage: ShuffleMapStage =>
          partitionsToCompute.map { id =>
            // One task per partition, each with its own preferred locations.
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            // A shuffle map stage produces ShuffleMapTasks.
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.internalAccumulators)
          }

          // A result stage produces ResultTasks.
        case stage: ResultStage =>
          // NOTE(review): `job` is not referenced below - confirm it is still needed.
          val job = stage.activeJob.get
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, stage.internalAccumulators)
          }
      }
    } catch {
      case NonFatal(e) =>
        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }

    if (tasks.size > 0) {
      logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
      stage.pendingPartitions ++= tasks.map(_.partitionId)
      logDebug("New pending partitions: " + stage.pendingPartitions)

      // Finally, wrap the tasks of the stage in a TaskSet and submit it through
      // the TaskScheduler's submitTasks method.
      taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
      stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
    } else {
      // Because we posted SparkListenerStageSubmitted earlier, we should mark
      // the stage as completed here in case there are no tasks to run
      markStageAsFinished(stage, None)

      val debugString = stage match {
        case stage: ShuffleMapStage =>
          s"Stage ${stage} is actually done; " +
            s"(available: ${stage.isAvailable}," +
            s"available outputs: ${stage.numAvailableOutputs}," +
            s"partitions: ${stage.numPartitions})"
        case stage : ResultStage =>
          s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
      }
      logDebug(debugString)
    }
  }

第十八步:点击第十七步getPreferredLocs
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

  /**
   * Returns the preferred locations for `partition` of `rdd` by starting a
   * getPreferredLocsInternal search with an empty visited set.
   */
  private[spark]
  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
    val visited = new HashSet[(RDD[_], Int)]
    getPreferredLocsInternal(rdd, partition, visited)
  }

第十九步:点击第十八步getPreferredLocsInternal
源码位置:org/apache/spark/scheduler/DAGScheduler.scala

 /**
   * Computes the preferred locations for the task that processes `partition` of `rdd`.
   *
   * Lookup order:
   *  1. the cache locations of the partition, if there are any;
   *  2. the RDD's own `preferredLocations` for the partition, if there are any
   *     (that method is documented as taking checkpointing into account);
   *  3. the first non-empty result obtained by recursing into the parent partitions
   *     of the RDD's narrow dependencies, in dependency order.
   *
   * Returns Nil when none of the above yields a location, and also when the
   * (rdd, partition) pair has been visited before.
   *
   * @param visited pairs already explored; avoids exponential path exploration (SPARK-695)
   */
  private def getPreferredLocsInternal(
      rdd: RDD[_],
      partition: Int,
      visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
    if (!visited.add((rdd, partition))) {
      // Already visited: Nil has already been returned for this partition.
      Nil
    } else {
      // If the partition is cached, use the cache locations.
      val cached = getCacheLocs(rdd)(partition)
      if (cached.nonEmpty) {
        cached
      } else {
        // If the RDD has some placement preferences (as is the case for input RDDs), use those.
        val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
        if (rddPrefs.nonEmpty) {
          rddPrefs.map(TaskLocation(_))
        } else {
          // If the RDD has narrow dependencies, pick the first partition of the first narrow
          // dependency that has any placement preferences. Ideally we would choose based on
          // transfer sizes, but this will do for now.
          // The iterators are lazy, so the search stops at the first non-empty answer.
          rdd.dependencies.iterator.flatMap {
            case n: NarrowDependency[_] =>
              n.getParents(partition).iterator.map(getPreferredLocsInternal(n.rdd, _, visited))
            case _ =>
              Iterator.empty
          }.find(_ != Nil).getOrElse(Nil)
        }
      }
    }
  }

第二十步:点击第十九步preferredLocations
源码位置:org/apache/spark/rdd/RDD.scala

  /**
   * Gets the preferred locations of a partition, taking into account whether the
   * RDD is checkpointed: when a checkpoint RDD is present its locations are used,
   * otherwise this RDD's own `getPreferredLocations` is consulted.
   */
  final def preferredLocations(split: Partition): Seq[String] = {
    checkpointRDD match {
      case Some(checkpointed) => checkpointed.getPreferredLocations(split)
      case None => getPreferredLocations(split)
    }
  }

猜你喜欢

转载自blog.csdn.net/weixin_39478115/article/details/79328891