Spark Execution Flow

Based on the Spark source code and the book 《图解Spark核心技术与案例实战》, this article gives a simple walkthrough of how a Spark job is executed.
The example analyzed below is a simple word count program.

public static void main( String[] args )
    {
        String logFile = "wordcount.txt"; 
        SparkConf conf = new SparkConf().setAppName("Simple Application").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        JavaRDD<String> logData = sc.textFile(logFile).cache();

        JavaRDD<String> words = logData.flatMap(new FlatMapFunction<String, String>() {
            @Override
            public Iterable<String> call(String s) throws Exception {
                return Arrays.asList(s.split(" "));
            }
        });
        JavaPairRDD<String,Integer> wordPair = words.mapToPair(new PairFunction<String, String, Integer>() {
            @Override
            public Tuple2<String, Integer> call(String s) throws Exception {
                return new Tuple2<>(s,1);
            }
        });

        JavaPairRDD<String,Integer> wordCount = wordPair.reduceByKey(new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception {
                return integer+integer2;
            }
        });

        // action: triggers the actual computation
        List<Tuple2<String,Integer>> counts= wordCount.collect();
        for (Tuple2<String, Integer> stringIntegerTuple2 : counts) {
            System.out.println(stringIntegerTuple2._1+":"+stringIntegerTuple2._2);
        }
        sc.stop();
    }
// wordcount.txt
one
two two
three three three
four four four four
five five five five five
six six six six six six

Key Classes

Let's first introduce a few important classes so that the exploration that follows has clear targets. The introduction only covers what matters for the runtime analysis and does not go into detail.

RDD

Computation only truly starts once an action is called on an RDD -- the typical lazy-evaluation pattern. In the code above, execution actually begins inside the RDD's collect() call; the earlier flatMap, mapToPair() and similar operations only build up the DAG.
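
To make the lazy-evaluation point concrete, here is a minimal Scala sketch (equivalent to the Java example above, not taken from it): the transformations only record lineage, and nothing reads wordcount.txt until collect() is called.

import org.apache.spark.{SparkConf, SparkContext}

object LazyDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lazy-demo").setMaster("local[*]"))

    // Transformations: only the DAG (lineage) is built here, no file is read yet.
    val counts = sc.textFile("wordcount.txt")
      .flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKey(_ + _)

    println("no job has run yet")

    // Action: collect() calls SparkContext.runJob, and only now is the DAG
    // scheduled and executed.
    counts.collect().foreach { case (w, n) => println(s"$w:$n") }
    sc.stop()
  }
}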

SparkContext

SparkContext is the core of the Driver. Calls made on RDDs eventually end up in SparkContext, which uses the DAGScheduler and TaskScheduler to resolve the DAG.

DAGScheduler

The DAGScheduler parses the DAG and splits it into scheduling stages. A stage can be seen as a group of tasks that can run in parallel. In this example, flatMap and the following mapToPair() fall into one stage, while reduceByKey belongs to another. Stages are split at wide (shuffle) dependencies. After splitting, the DAGScheduler wraps each stage into tasks and sends them to the TaskScheduler.

TaskScheduler

The TaskScheduler's job is to schedule tasks: after receiving tasks from the DAGScheduler, it distributes them to Executors for execution.

ResultStage and ShuffleMapStage

ResultStage and ShuffleMapStage are the two kinds of stages produced when the DAG is resolved. A ResultStage is the stage that produces the final RDD; a ShuffleMapStage is a stage that ends at a shuffle. Because the operations on the two sides of a shuffle cannot run in the same stage (the dependency is wide, so they cannot run in parallel), stages are split at shuffles. A job with no shuffle has only a single ResultStage; a job with shuffles has at least one ShuffleMapStage plus one ResultStage.
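
A cheap way to see this split for the word-count job is RDD.toDebugString, which prints the lineage and indents at each shuffle boundary (it also exists on JavaRDDLike, so the wordCount JavaPairRDD above would work too). A one-line sketch, reusing the counts RDD from the earlier sketch; the exact output format differs between Spark versions:

// Everything below the ShuffledRDD introduced by reduceByKey belongs to the
// ShuffleMapStage; the ShuffledRDD itself is computed in the ResultStage.
println(counts.toDebugString)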

The Execution Process

The Submission Phase

Submitting the DAG to the DAGScheduler

In the example above, calling collect() is what triggers execution. The first phase is job submission; the internal calls made by collect() are as follows.

RDD.scala
  /**
   * Return an array that contains all of the elements in this RDD.
   * It simply delegates to SparkContext's runJob method.
   */
  def collect(): Array[T] = withScope {
    val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
    Array.concat(results: _*)
  }

SparkContext.scala

	/**
   * After a few intermediate overloads inside SparkContext, execution reaches this method, which mainly delegates to dagScheduler.runJob.
   * Run a function on a given set of partitions in an RDD and pass the results to the given
   * handler function. This is the main entry point for all actions in Spark. The allowLocal
   * flag specifies whether the scheduler can run the computation on the driver rather than
   * shipping it out to the cluster, for short actions like first().
   */
  def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit) {
    if (stopped.get()) {
      throw new IllegalStateException("SparkContext has been shutdown")
    }
    val callSite = getCallSite
    val cleanedFunc = clean(func)
    logInfo("Starting job: " + callSite.shortForm)
    if (conf.getBoolean("spark.logLineage", false)) {
      logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
    }
    dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
      resultHandler, localProperties.get)
    progressBar.foreach(_.finishAll())
    rdd.doCheckpoint()
  }

DAGScheduler.scala
A thin wrapper around the call that only adds log output; it does not affect the actual execution.

def runJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit,
      properties: Properties): Unit = {
    val start = System.nanoTime
    // submitJob posts the job onto the event queue
    val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
    waiter.awaitResult() match {
      case JobSucceeded =>
        logInfo("Job %d finished: %s, took %f s".format
          (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      case JobFailed(exception: Exception) =>
        logInfo("Job %d failed: %s, took %f s".format
          (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
        throw exception
    }
  }

DAGScheduler.scala
Here an event-loop helper class is used: submitJob posts the job onto the event loop.

/**
   * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
   * can be used to block until the the job finishes executing or can be used to cancel the job.
   */
  def submitJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit,
      properties: Properties): JobWaiter[U] = {
    // Check to make sure we are not launching a task on a partition that does not exist.
    val maxPartitions = rdd.partitions.length
    partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
      throw new IllegalArgumentException(
        "Attempting to access a non-existent partition: " + p + ". " +
          "Total number of partitions: " + maxPartitions)
    }

    val jobId = nextJobId.getAndIncrement()
    if (partitions.size == 0) {
      return new JobWaiter[U](this, jobId, 0, resultHandler)
    }

    assert(partitions.size > 0)
    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
    val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
    eventProcessLoop.post(JobSubmitted(
      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter,
      SerializationUtils.clone(properties)))
    waiter
  }

DAGScheduler.scala
This is where the event is received. The pattern is a producer/consumer queue: the eventProcessLoop above is backed by a blocking queue, and a different action is taken depending on which event is pulled off the queue. This function gives a good overview of what the DAGScheduler does at each point in a job's life cycle (a stripped-down sketch of the event-loop pattern follows the code below).

/**
   * The main event loop of the DAG scheduler.
   */
  override def onReceive(event: DAGSchedulerEvent): Unit = event match {
    case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
        listener, properties)

    case StageCancelled(stageId) =>
      dagScheduler.handleStageCancellation(stageId)

    case JobCancelled(jobId) =>
      dagScheduler.handleJobCancellation(jobId)

    case JobGroupCancelled(groupId) =>
      dagScheduler.handleJobGroupCancelled(groupId)

    case AllJobsCancelled =>
      dagScheduler.doCancelAllJobs()

    case ExecutorAdded(execId, host) =>
      dagScheduler.handleExecutorAdded(execId, host)

    case ExecutorLost(execId) =>
      dagScheduler.handleExecutorLost(execId, fetchFailed = false)

    case BeginEvent(task, taskInfo) =>
      dagScheduler.handleBeginEvent(task, taskInfo)

    case GettingResultEvent(taskInfo) =>
      dagScheduler.handleGetTaskResult(taskInfo)

    case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
      dagScheduler.handleTaskCompletion(completion)

    case TaskSetFailed(taskSet, reason) =>
      dagScheduler.handleTaskSetFailed(taskSet, reason)

    case ResubmitFailedStages =>
      dagScheduler.resubmitFailedStages()
  }
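
The structure above boils down to a single dispatcher thread draining a blocking queue. Here is a stripped-down, self-contained sketch of the same pattern; the names are illustrative only, this is not Spark's DAGSchedulerEventProcessLoop.

import java.util.concurrent.LinkedBlockingDeque

object ToyEventLoopSketch extends App {
  sealed trait Event
  case class JobSubmitted(jobId: Int) extends Event
  case object Stop extends Event

  class ToyEventLoop(onReceive: Event => Unit) {
    private val eventQueue = new LinkedBlockingDeque[Event]()
    private val eventThread = new Thread("toy-event-loop") {
      override def run(): Unit = {
        var running = true
        while (running) {
          eventQueue.take() match {
            case Stop  => running = false
            case event => onReceive(event)   // consumer side: dispatch the event
          }
        }
      }
    }
    def start(): Unit = eventThread.start()
    def post(event: Event): Unit = eventQueue.put(event)   // producer side
    def stop(): Unit = post(Stop)
  }

  val loop = new ToyEventLoop(e => println(s"handling $e"))
  loop.start()
  loop.post(JobSubmitted(0))   // like eventProcessLoop.post(JobSubmitted(...))
  loop.stop()
}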

The DAGScheduler parses the DAG into stages

The conversion from RDDs to stages is completed inside handleJobSubmitted, mainly through the newResultStage call it wraps.

  /**
   * Create a ResultStage -- either directly for use as a result stage, or as part of the
   * (re)-creation of a shuffle map stage in newOrUsedShuffleStage.  The stage will be associated
   * with the provided jobId.
   */
  private def newResultStage(
      rdd: RDD[_],
      numTasks: Int,
      jobId: Int,
      callSite: CallSite): ResultStage = {
    val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
    val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite)

    stageIdToStage(id) = stage
    // this resolves and registers the ancestor stages
    updateJobIdStageIdMaps(jobId, stage)
    stage
  }

Stages are resolved back to front: starting from the final ResultStage and working backwards, the scheduler looks for shuffle operations and creates a ShuffleMapStage for each one it finds.

  /**
   * Registers the given jobId among the jobs that need the given stage and
   * all of that stage's ancestors.
   */
  private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = {
    def updateJobIdStageIdMapsList(stages: List[Stage]) {
      if (stages.nonEmpty) {
        val s = stages.head
        s.jobIds += jobId
        jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
        // getParentStages implements the core stage-resolution algorithm
        val parents: List[Stage] = getParentStages(s.rdd, jobId)
        // keep only parent stages not yet associated with this jobId
        val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
        updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
      }
    }
    updateJobIdStageIdMapsList(List(stage))
  }

updateJobIdStageIdMapsList is a local recursive function; the getParentStages() method it calls is where stages are actually resolved:

  /**
   * Get or create the list of parent stages for a given RDD. The stages will be assigned the
   * provided jobId if they haven't already been created with a lower jobId.
   */
  private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = {
    val parents = new HashSet[Stage]
    val visited = new HashSet[RDD[_]]
    // We are manually maintaining a stack here to prevent StackOverflowError
    // caused by recursively visiting
    val waitingForVisit = new Stack[RDD[_]]
    // walk the dependency graph (using the explicit stack above) and build stages
    def visit(r: RDD[_]) {
      if (!visited(r)) {
        visited += r
        // Kind of ugly: need to register RDDs with the cache here since
        // we can't do it in its constructor because # of partitions is unknown
        // iterate over all of this RDD's dependencies
        for (dep <- r.dependencies) {
          dep match {
          // a shuffle dependency is a stage boundary: add a ShuffleMapStage to parents
            case shufDep: ShuffleDependency[_, _, _] =>
              parents += getShuffleMapStage(shufDep, jobId)
          // otherwise push the parent RDD onto the stack to be visited later
            case _ =>
              waitingForVisit.push(dep.rdd)
          }
        }
      }
    }
    waitingForVisit.push(rdd)
    while (waitingForVisit.nonEmpty) {
      visit(waitingForVisit.pop())
    }
    parents.toList
  }

The code above traverses the DAG (the RDD dependency graph) using an explicit stack, creating a stage whenever it meets a shuffle dependency. Note that it collects every ShuffleMapStage that the ResultStage depends on through narrow dependencies, not just the immediately preceding one. How are each parent stage's own dependencies handled? By calling getParentStages() again, recursively, through the stage-creation path shown below.
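
Before following that path, here is a toy, self-contained model of the traversal itself. All the types are simplified stand-ins, none of them Spark's:

import scala.collection.mutable

object ParentStagesSketch extends App {
  sealed trait Dep { def rdd: ToyRdd }
  case class NarrowDep(rdd: ToyRdd) extends Dep
  case class ShuffleDep(rdd: ToyRdd) extends Dep
  case class ToyRdd(name: String, deps: Seq[Dep])

  // Walk the dependency graph with an explicit stack; every shuffle dependency
  // marks a parent stage, narrow dependencies are followed further back.
  def parentShuffles(rdd: ToyRdd): List[ToyRdd] = {
    val parents = mutable.ListBuffer[ToyRdd]()
    val visited = mutable.Set[ToyRdd]()
    val waitingForVisit = mutable.Stack[ToyRdd](rdd)
    while (waitingForVisit.nonEmpty) {
      val r = waitingForVisit.pop()
      if (visited.add(r)) {
        r.deps.foreach {
          case ShuffleDep(parent) => parents += parent             // stage boundary
          case NarrowDep(parent)  => waitingForVisit.push(parent)  // keep walking
        }
      }
    }
    parents.toList
  }

  // The word-count lineage: textFile -> flatMap -> mapToPair (narrow) -> reduceByKey (shuffle)
  val lines  = ToyRdd("textFile", Nil)
  val words  = ToyRdd("flatMap", Seq(NarrowDep(lines)))
  val pairs  = ToyRdd("mapToPair", Seq(NarrowDep(words)))
  val counts = ToyRdd("reduceByKey", Seq(ShuffleDep(pairs)))

  println(parentShuffles(counts).map(_.name))   // List(mapToPair): the ShuffleMapStage's last RDD
}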

  /**
   * Get or create a shuffle map stage for the given shuffle dependency's map side.
   * The jobId value passed in will be used if the stage doesn't already exist with
   * a lower jobId (jobId always increases across jobs.)
   * getParentStages() calls this when it needs a stage for a shuffle dependency.
   */
  private def getShuffleMapStage(
      shuffleDep: ShuffleDependency[_, _, _],
      jobId: Int): ShuffleMapStage = {
    shuffleToMapStage.get(shuffleDep.shuffleId) match {
      case Some(stage) => stage
      case None =>
        // We are going to register ancestor shuffle dependencies
        registerShuffleDependencies(shuffleDep, jobId)
        // Then register current shuffleDep
        // create the ShuffleMapStage (reusing map outputs if they already exist)
        val stage = newOrUsedShuffleStage(shuffleDep, jobId)
        shuffleToMapStage(shuffleDep.shuffleId) = stage

        stage
    }
  }

  /**
   * Create a shuffle map Stage for the given RDD.  The stage will also be associated with the
   * provided jobId.  If a stage for the shuffleId existed previously so that the shuffleId is
   * present in the MapOutputTracker, then the number and location of available outputs are
   * recovered from the MapOutputTracker
   */
  private def newOrUsedShuffleStage(
      shuffleDep: ShuffleDependency[_, _, _],
      jobId: Int): ShuffleMapStage = {
    val rdd = shuffleDep.rdd
    val numTasks = rdd.partitions.size
    // create a new ShuffleMapStage
    val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite)
    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
      val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
      for (i <- 0 until locs.size) {
        stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing
      }
      stage.numAvailableOutputs = locs.count(_ != null)
    } else {
      // Kind of ugly: need to register RDDs with the cache and map output tracker here
      // since we can't do it in the RDD constructor because # of partitions is unknown
      logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
      mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
    }
    stage
  }

  /**
  * This method ends up calling updateJobIdStageIdMaps() again, so stage resolution proceeds recursively.
  */
  private def newShuffleMapStage(
      rdd: RDD[_],
      numTasks: Int,
      shuffleDep: ShuffleDependency[_, _, _],
      jobId: Int,
      callSite: CallSite): ShuffleMapStage = {
    val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
    val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages,
      jobId, callSite, shuffleDep)

    stageIdToStage(id) = stage
    updateJobIdStageIdMaps(jobId, stage)
    stage
  }

At this point stage resolution is complete; what remains is to turn the stages into tasks and hand them to the TaskScheduler.

  /** Submits stage, but first recursively submits any missing parents. */
  private def submitStage(stage: Stage) {
    val jobId = activeJobForStage(stage)
    if (jobId.isDefined) {
      logDebug("submitStage(" + stage + ")")
      if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
        val missing = getMissingParentStages(stage).sortBy(_.id)
        logDebug("missing: " + missing)
        // check for unfinished parent stages; parents must run before this stage
        if (missing.isEmpty) {
          logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
          // no missing parents: submit this stage's tasks
          submitMissingTasks(stage, jobId.get)
        } else {
          for (parent <- missing) {
            submitStage(parent)
          }
          waitingStages += stage
        }
      }
    } else {
      abortStage(stage, "No active job for stage " + stage.id)
    }
  }

The above is a simple depth-first submission: before submitting a stage, check whether it still has unfinished parent stages and, if so, submit those first. The method is also invoked again after a stage finishes, so that newly runnable stages get submitted.
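
A toy model of this recursion, again with made-up stand-in types rather than Spark's:

import scala.collection.mutable

object SubmitStageSketch extends App {
  // `done` plays the role of "all outputs available" in the real scheduler.
  case class ToyStage(id: Int, parents: Seq[ToyStage], done: Boolean = false)

  val waitingStages = mutable.ListBuffer[ToyStage]()

  def submitStage(stage: ToyStage): Unit = {
    val missing = stage.parents.filterNot(_.done).sortBy(_.id)
    if (missing.isEmpty) {
      // corresponds to submitMissingTasks(stage, jobId)
      println(s"submitting tasks for stage ${stage.id}")
    } else {
      missing.foreach(submitStage)   // submit parents first
      waitingStages += stage         // retried once a parent finishes
    }
  }

  val shuffleMapStage = ToyStage(0, Nil)
  val resultStage     = ToyStage(1, Seq(shuffleMapStage))

  submitStage(resultStage)
  println(s"waiting: ${waitingStages.map(_.id)}")   // stage 1 waits for stage 0
}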

// Excerpt from DAGScheduler.handleTaskCompletion: when a task finishes, newly runnable stages are submitted
       case smt: ShuffleMapTask =>
            val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
            updateAccumulators(event)
            val status = event.result.asInstanceOf[MapStatus]
            val execId = status.location.executorId
            logDebug("ShuffleMapTask finished on " + execId)
            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
            } else {
              shuffleStage.addOutputLoc(smt.partitionId, status)
            }
            if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
              markStageAsFinished(shuffleStage)
              logInfo("looking for newly runnable stages")
              logInfo("running: " + runningStages)
              logInfo("waiting: " + waitingStages)
              logInfo("failed: " + failedStages)

              // We supply true to increment the epoch number here in case this is a
              // recomputation of the map outputs. In that case, some nodes may have cached
              // locations with holes (from when we detected the error) and will need the
              // epoch incremented to refetch them.
              // TODO: Only increment the epoch number if this is not the first time
              //       we registered these map outputs.
              mapOutputTracker.registerMapOutputs(
                shuffleStage.shuffleDep.shuffleId,
                shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
                changeEpoch = true)

              clearCacheLocs()
              if (shuffleStage.outputLocs.contains(Nil)) {
                // Some tasks had failed; let's resubmit this shuffleStage
                // TODO: Lower-level scheduler should also deal with this
                logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
                  ") because some of its tasks had failed: " +
                  shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty)
                      .map(_._2).mkString(", "))
                // some outputs are missing: resubmit this stage
                submitStage(shuffleStage)
              }else {
                val newlyRunnable = new ArrayBuffer[Stage]
                for (shuffleStage <- waitingStages) {
                  logInfo("Missing parents for " + shuffleStage + ": " +
                    getMissingParentStages(shuffleStage))
                }
                for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty)
                {
                  newlyRunnable += shuffleStage
                }
                waitingStages --= newlyRunnable
                runningStages ++= newlyRunnable
                for {
                  shuffleStage <- newlyRunnable.sortBy(_.id)
                  jobId <- activeJobForStage(shuffleStage)
                } {
                  logInfo("Submitting " + shuffleStage + " (" +
                    shuffleStage.rdd + "), which is now runnable")
                  // submit waiting stages that are now runnable
                  submitMissingTasks(shuffleStage, jobId)
                }
              }

When a stage completes, scheduling of new stages is triggered: the finished stage is filtered out inside getMissingParentStages, so stages that were waiting on it become runnable and can be submitted. The main logic for turning a stage into tasks lives in submitMissingTasks(stage, jobId.get):

/** Called when stage's parents are available and we can now do its task. */
  private def submitMissingTasks(stage: Stage, jobId: Int) {
    ……
    // turn the stage into tasks: one task per partition to compute
    val tasks: Seq[Task[_]] = stage match {
      case stage: ShuffleMapStage =>
        partitionsToCompute.map { id =>
          // preferred locations (machines) for this partition
          val locs = getPreferredLocs(stage.rdd, id)
          // the partition object itself
          val part = stage.rdd.partitions(id)
          new ShuffleMapTask(stage.id, taskBinary, part, locs)
        }

      case stage: ResultStage =>
        val job = stage.resultOfJob.get
        partitionsToCompute.map { id =>
          val p: Int = job.partitions(id)
          val part = stage.rdd.partitions(p)
          val locs = getPreferredLocs(stage.rdd, p)
          new ResultTask(stage.id, taskBinary, part, locs, id)
        }
    }

    if (tasks.size > 0) {
      logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
      stage.pendingTasks ++= tasks
      logDebug("New pending tasks: " + stage.pendingTasks)
      // a stage's tasks are submitted together as one TaskSet
      taskScheduler.submitTasks(
        new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
      stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
    } else {
      // Because we posted SparkListenerStageSubmitted earlier, we should mark
      // the stage as completed here in case there are no tasks to run
      markStageAsFinished(stage, None)

      val debugString = stage match {
        case stage: ShuffleMapStage =>
          s"Stage ${stage} is actually done; " +
            s"(available: ${stage.isAvailable}," +
            s"available outputs: ${stage.numAvailableOutputs}," +
            s"partitions: ${stage.numPartitions})"
        case stage : ResultStage =>
          s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
      }
      logDebug(debugString)
    }
  }
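
The one-task-per-partition rule above is easy to observe from user code: the partition count of the final RDD determines how many ResultTasks the ResultStage contains. A tiny check, again reusing the counts RDD from the earlier sketch (partitions is a public method on RDD):

// each partition of the final RDD becomes one ResultTask when collect() runs
println(counts.partitions.length)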

The next call is in TaskSchedulerImpl.scala:

override def submitTasks(taskSet: TaskSet) {
    val tasks = taskSet.tasks
    logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
    this.synchronized {
      // create a TaskSetManager, which is responsible for running this TaskSet
      val manager = createTaskSetManager(taskSet, maxTaskFailures)
      activeTaskSets(taskSet.id) = manager
      // add the manager to the scheduling pool so it can be scheduled
      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

      if (!isLocal && !hasReceivedTask) {
        starvationTimer.scheduleAtFixedRate(new TimerTask() {
          override def run() {
            if (!hasLaunchedTask) {
              logWarning("Initial job has not accepted any resources; " +
                "check your cluster UI to ensure that workers are registered " +
                "and have sufficient resources")
            } else {
              this.cancel()
            }
          }
        }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
      }
      hasReceivedTask = true
    }
    // ask the backend to offer resources and launch the tasks
    backend.reviveOffers()
  }

Resource Allocation

The backend.reviveOffers() call ends up in the implementation inside CoarseGrainedSchedulerBackend, which sends a ReviveOffers message to the DriverEndpoint; on receiving the message, the DriverEndpoint calls makeOffers().

  override def reviveOffers() {
    driverEndpoint.send(ReviveOffers)
  }
override def receive: PartialFunction[Any, Unit] = {
      case StatusUpdate(executorId, taskId, state, data) =>
        scheduler.statusUpdate(taskId, state, data.value)
        if (TaskState.isFinished(state)) {
          executorDataMap.get(executorId) match {
            case Some(executorInfo) =>
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                "from unknown executor $sender with ID $executorId")
          }
        }

      case ReviveOffers =>
        makeOffers()

      case KillTask(taskId, executorId, interruptThread) =>
        executorDataMap.get(executorId) match {
          case Some(executorInfo) =>
            executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
          case None =>
            // Ignoring the task kill since the executor is not registered.
            logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
        }
    }

makeOffers() allocates resources to the pending task sets (via TaskSchedulerImpl.resourceOffers) and launches the resulting tasks.

// Make fake resource offers on all executors
    def makeOffers() {
    // build a WorkerOffer for each executor's free resources
      launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) =>
        new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
      }.toSeq))
    }

The scheduler's resourceOffers() method does the actual resource allocation:

  /**
   * Called by cluster manager to offer resources on slaves. We respond by asking our active task
   * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
   * that tasks are balanced across the cluster.
   */
  def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
    // Mark each slave as alive and remember its hostname
    // Also track if new executor is added
    var newExecAvail = false
    // record the available executors
    for (o <- offers) {
      executorIdToHost(o.executorId) = o.host
      activeExecutorIds += o.executorId
      if (!executorsByHost.contains(o.host)) {
        executorsByHost(o.host) = new HashSet[String]()
        executorAdded(o.executorId, o.host)
        newExecAvail = true
      }
      for (rack <- getRackForHost(o.host)) {
        hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
      }
    }

    // Randomly shuffle offers to avoid always placing tasks on the same set of workers.
    val shuffledOffers = Random.shuffle(offers)
    // Build a list of tasks to assign to each worker.
    val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
    val availableCpus = shuffledOffers.map(o => o.cores).toArray
    // get the TaskSets sorted by the scheduling policy (FIFO or FAIR)
    val sortedTaskSets = rootPool.getSortedTaskSetQueue
    for (taskSet <- sortedTaskSets) {
      logDebug("parentName: %s, name: %s, runningTasks: %s".format(
        taskSet.parent.name, taskSet.name, taskSet.runningTasks))
        // if a new executor joined, recompute locality preferences
      if (newExecAvail) {
        taskSet.executorAdded()
      }
    }

    // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
    // of locality levels so that it gets a chance to launch local tasks on all of them.
    // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
    var launchedTask = false
    for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
      do {
        launchedTask = resourceOfferSingleTaskSet(
            taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
      } while (launchedTask)
    }

    if (tasks.size > 0) {
      hasLaunchedTask = true
    }
    return tasks
    }

The code above involves the notion of data locality: how conveniently the data can be reached by the code that processes it. There are five levels, corresponding to the five preference levels above:
PROCESS_LOCAL: code and data are in the same process, i.e. the same executor; the task runs in the executor whose BlockManager holds the data. Best performance.
NODE_LOCAL: code and data are on the same node; for example the data is an HDFS block on that node while the task runs in some executor there, or the data and the task sit in different executors on the same node, so the data has to cross process boundaries.
NO_PREF: for this task it does not matter where the data comes from, for example data read from a database.
RACK_LOCAL: data and task are on two nodes within the same rack, so the data has to travel over the network between nodes.
ANY: data and task may be anywhere in the cluster, not even in the same rack. Worst performance.
The "recomputing locality" mentioned above refers to the fact that when the set of executors changes, new data-local placements may become possible; the waits that govern falling back from one level to the next are configurable, as sketched below.
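
How long the scheduler waits at one locality level before falling back to the next is controlled by the spark.locality.wait settings. A small sketch of the relevant keys (these are standard Spark configuration keys; the values shown are just the common defaults):

import org.apache.spark.SparkConf

object LocalityWaitSketch {
  // spark.locality.wait is the base wait; the per-level keys override it.
  val conf = new SparkConf()
    .setAppName("locality-demo")
    .set("spark.locality.wait", "3s")          // base wait between locality levels
    .set("spark.locality.wait.process", "3s")  // PROCESS_LOCAL -> NODE_LOCAL
    .set("spark.locality.wait.node", "3s")     // NODE_LOCAL    -> RACK_LOCAL
    .set("spark.locality.wait.rack", "3s")     // RACK_LOCAL    -> ANY
}
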
Once offers have been assigned to the task sets in priority order, launchTasks submits the tasks for execution:

// Launch tasks returned by a set of resource offers
    def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
      for (task <- tasks.flatten) {
        val ser = SparkEnv.get.closureSerializer.newInstance()
        val serializedTask = ser.serialize(task)
        if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
            try {
              var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                "spark.akka.frameSize or using broadcast variables for large values."
              msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                AkkaUtils.reservedSizeBytes)
              taskSet.abort(msg)
            } catch {
              case e: Exception => logError("Exception in error callback", e)
            }
          }
        }
        else {
          val executorData = executorDataMap(task.executorId)
          executorData.freeCores -= scheduler.CPUS_PER_TASK
          executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
        }
      }
    }

The code above does two things: (1) it serializes each task, and (2) it checks whether the serialized task exceeds the maximum frame size Akka can send; if it does, the task set is aborted with an explanatory message, otherwise the serialized task is sent to the executor's endpoint. Note that the tasks of a task set are iterated over and sent one by one.
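
To illustrate the size check, here is a small sketch that serializes an object with plain Java serialization and applies the same arithmetic. The 128 MB and 200 KB figures mirror the frame size and reserved overhead discussed in this article; they are hard-coded assumptions here, not read from any Spark API.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object FrameSizeSketch extends App {
  // Serialize an object and measure its size, roughly the way a task payload's
  // size is measured before it is shipped to an executor.
  def serializedSize(obj: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(obj)
    out.close()
    bytes.size()
  }

  val frameSize     = 128 * 1024 * 1024   // assumed frame size, in bytes
  val reservedBytes = 200 * 1024          // assumed reserved overhead

  val payload = new Array[Byte](1024)     // stand-in for a serialized task
  val size = serializedSize(payload)
  if (size >= frameSize - reservedBytes)
    println(s"$size bytes: too large for one frame; the task set would be aborted")
  else
    println(s"$size bytes: fits in one frame and would be sent to the executor")
}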

Task Execution

When an Executor receives a task, it calls launchTask() to run it:

 def launchTask(
      context: ExecutorBackend,
      taskId: Long,
      attemptNumber: Int,
      taskName: String,
      serializedTask: ByteBuffer): Unit = {
      // wrap the task in a TaskRunner
    val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
      serializedTask)
    runningTasks.put(taskId, tr)
    // run the TaskRunner on the executor's thread pool
    threadPool.execute(tr)
  }

Since the task is run on a thread pool, the core logic lives in TaskRunner's run method:

    override def run(): Unit = {
      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
      val deserializeStartTime = System.currentTimeMillis()
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      startGCTime = computeTotalGcTime()

      try {
      // first deserialize the files, jars and task bytes needed to run the task
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        updateDependencies(taskFiles, taskJars)
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
        task.setTaskMemoryManager(taskMemoryManager)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        val value = try {
        // call task.run() to execute the actual task
          task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
        } finally {
          // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
          // when changing this, make sure to update both copies.
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
          if (freedMemory > 0) {
            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
              throw new SparkException(errMsg)
            } else {
              logError(errMsg)
            }
          }
        }
        val taskFinish = System.currentTimeMillis()

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }

        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()

        for (m <- task.metrics) {
          // Deserialization happens in two parts: first, we deserialize a Task object, which
          // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
          m.setExecutorDeserializeTime(
            (taskStart - deserializeStartTime) + task.executorDeserializeTime)
          // We need to subtract Task.run()'s deserialization time to avoid double-counting
          m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
          m.setJvmGCTime(computeTotalGcTime() - startGCTime)
          m.setResultSerializationTime(afterSerialization - beforeSerialization)
        }

        val accumUpdates = Accumulators.values
        val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }

        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

      } catch {
        case ffe: FetchFailedException =>
          val reason = ffe.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

        case _: TaskKilledException | _: InterruptedException if task.killed =>
          logInfo(s"Executor killed $taskName (TID $taskId)")
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))

        case cDE: CommitDeniedException =>
          val reason = cDE.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

        case t: Throwable =>
          // Attempt to exit cleanly by informing the driver of our failure.
          // If anything goes wrong (or this was a fatal exception), we will delegate to
          // the default uncaught exception handler, which will terminate the Executor.
          logError(s"Exception in $taskName (TID $taskId)", t)

          val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
            task.metrics.map { m =>
              m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
              m.setJvmGCTime(computeTotalGcTime() - startGCTime)
              m
            }
          }
          val taskEndReason = new ExceptionFailure(t, metrics)
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))

          // Don't forcibly exit unless the exception was inherently fatal, to avoid
          // stopping other tasks unnecessarily.
          if (Utils.isFatalError(t)) {
            SparkUncaughtExceptionHandler.uncaughtException(t)
          }

      } finally {
        // Release memory used by this thread for shuffles
        env.shuffleMemoryManager.releaseMemoryForThisThread()
        // Release memory used by this thread for unrolling blocks
        env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
        // Release memory used by this thread for accumulators
        Accumulators.clear()
        runningTasks.remove(taskId)
      }
    }
  }

The run method above is long, but the only line that actually executes the task is task.run(); everything else handles the various situations that can arise during execution. Let's look directly at task.run():

  final def run(taskAttemptId: Long, attemptNumber: Int): T = {
    context = new TaskContextImpl(
      stageId = stageId,
      partitionId = partitionId,
      taskAttemptId = taskAttemptId,
      attemptNumber = attemptNumber,
      taskMemoryManager = taskMemoryManager,
      runningLocally = false)
    TaskContext.setTaskContext(context)
    context.taskMetrics.setHostname(Utils.localHostName())
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    try {
      runTask(context)
    } finally {
      context.markTaskCompleted()
      TaskContext.unset()
    }
  }

This is another thin wrapper: it calls the task's runTask method. Note that runTask is not implemented in Task itself but in its subclasses, ShuffleMapTask and ResultTask. First, ShuffleMapTask's implementation:

  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // deserialize the RDD and its shuffle dependency
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    // set up task metrics
    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      // run the computation and write the shuffle output
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

Here the task is actually executed: a ShuffleMapTask writes its output through the shuffle writer (into the BlockManager) and returns a MapStatus object, which records where the output is stored rather than the result itself.
Now the runTask() method of ResultTask:

override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    metrics = Some(context.taskMetrics)
    // return the computed result directly
    func(context, rdd.iterator(partition, context))
  }

Here the computed result is returned directly instead of being written out to shuffle files first.

After Execution

Once the task has run, TaskRunner's run method applies a different handling strategy depending on the size of the result:

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }
  1. If the result is larger than spark.driver.maxResultSize (1 GB by default), it is dropped; only a placeholder IndirectTaskResult carrying the size is sent back.
  2. If the result is below that limit but at least as large as the Akka frame size minus the reserved bytes (roughly 128 MB - 200 KB), it is stored in the BlockManager and only the blockId is sent back to the driver as an IndirectTaskResult, via the statusUpdate() implementation in CoarseGrainedExecutorBackend.
  3. Otherwise the serialized result is sent directly back to the driver.
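
The 1 GB cutoff in case 1 is the spark.driver.maxResultSize setting. A small sketch of raising it for a job that genuinely needs to collect a larger result (the key is standard Spark configuration; "2g" is just an example value):

import org.apache.spark.SparkConf

object MaxResultSizeSketch {
  val conf = new SparkConf()
    .setAppName("max-result-size-demo")
    .set("spark.driver.maxResultSize", "2g")   // default is 1g; "0" removes the limit
}
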
Still inside this method, statusUpdate is used to send a message to the DriverEndpoint describing how the task ended; when the DriverEndpoint receives it, it calls TaskSchedulerImpl's statusUpdate method, which reacts according to the task's final state:
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
    var failedExecutor: Option[String] = None
    synchronized {
      try {
        if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
          // We lost this entire executor, so remember that it's gone
          val execId = taskIdToExecutorId(tid)
          if (activeExecutorIds.contains(execId)) {
            removeExecutor(execId)
            failedExecutor = Some(execId)
          }
        }
        taskIdToTaskSetId.get(tid) match {
          case Some(taskSetId) =>
            if (TaskState.isFinished(state)) {
              taskIdToTaskSetId.remove(tid)
              taskIdToExecutorId.remove(tid)
            }
            activeTaskSets.get(taskSetId).foreach { taskSet =>
              if (state == TaskState.FINISHED) {
                taskSet.removeRunningTask(tid)
                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
                taskSet.removeRunningTask(tid)
                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
              }
            }
          case None =>
            logError(
              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
               "likely the result of receiving duplicate task finished status updates)")
              .format(state, tid))
        }
      } catch {
        case e: Exception => logError("Exception in statusUpdate", e)
      }
    }
    // Update the DAGScheduler without holding a lock on this, since that can deadlock
    if (failedExecutor.isDefined) {
      dagScheduler.executorLost(failedExecutor.get)
      backend.reviveOffers()
    }
  }

If the state is LOST and the task's executor is still registered, the executor is marked as lost and recorded as failed.
If the state is FINISHED, the result is handled by taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData):

case directResult: DirectTaskResult[_] =>
             if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
               return
             }
             // deserialize "value" without holding any lock so that it won't block other threads.
             // We should call it here, so that when it's called again in
             // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
             directResult.value()
             (directResult, serializedData.limit())
           case IndirectTaskResult(blockId, size) =>
             if (!taskSetManager.canFetchMoreResults(size)) {
               // dropped by executor if size is larger than maxResultSize
               sparkEnv.blockManager.master.removeBlock(blockId)
               return
             }
             logDebug("Fetching indirect task result for TID %s".format(tid))
             scheduler.handleTaskGettingResult(taskSetManager, tid)
             val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
             if (!serializedTaskResult.isDefined) {
               /* We won't be able to get the task result if the machine that ran the task failed
                * between when the task ended and when we tried to fetch the result, or if the
                * block manager had to flush the result. */
               scheduler.handleFailedTask(
                 taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
               return
             }

This method reads the result according to its type; for an IndirectTaskResult the data is fetched through the BlockManager.
If the task ended in FAILED, KILLED or LOST, it is handled by taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData).
Finally, tasks that were on a lost executor are rescheduled:

// Update the DAGScheduler without holding a lock on this, since that can deadlock
    if (failedExecutor.isDefined) {
      dagScheduler.executorLost(failedExecutor.get)
      backend.reviveOffers()
    }

When a ShuffleMapTask completes, the downstream stage has to be notified so that this task's output becomes the downstream stage's input. This is done through
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
inside taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData).
(The original post illustrates the call chain here with a figure, "Spark task completion".)

The calls here work much like submitJob: scheduling goes through the same producer/consumer event loop. The final handling happens in handleTaskCompletion, which treats the two task types differently. For a ResultTask:

task match {
          case rt: ResultTask[_, _] =>
            // Cast to ResultStage here because it's part of the ResultTask
            // TODO Refactor this out to a function that accepts a ResultStage
            val resultStage = stage.asInstanceOf[ResultStage]
            resultStage.resultOfJob match {
              case Some(job) =>
                if (!job.finished(rt.outputId)) {
                  updateAccumulators(event)
                  job.finished(rt.outputId) = true
                  job.numFinished += 1
                  // If the whole job has finished, remove it
                  if (job.numFinished == job.numPartitions) {
                    markStageAsFinished(resultStage)
                    cleanupStateForJobAndIndependentStages(job)
                    listenerBus.post(
                      SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
                  }

                  // taskSucceeded runs some user code that might throw an exception. Make sure
                  // we are resilient against that.
                  try {
                    job.listener.taskSucceeded(rt.outputId, event.result)
                  } catch {
                    case e: Exception =>
                      // TODO: Perhaps we want to mark the resultStage as failed?
                      job.listener.jobFailed(new SparkDriverExecutionException(e))
                  }
                }
              case None =>
                logInfo("Ignoring result from " + rt + " because its job has finished")
            }

The three most important lines in the code above are:

markStageAsFinished(resultStage)
cleanupStateForJobAndIndependentStages(job)
listenerBus.post(
  SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))

The code first checks whether the job finished successfully; if so, these three lines run: mark the ResultStage as finished, clean up the job's state and its independent stages, and post a SparkListenerJobEnd message to the listener bus to announce that the job is complete.
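
The SparkListenerJobEnd event posted here is also what application code can observe. A minimal sketch of a listener that logs job completions (assuming an existing SparkContext named sc):

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}

class JobEndLogger extends SparkListener {
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
    println(s"job ${jobEnd.jobId} ended with result ${jobEnd.jobResult}")
}

// registration, on an existing context:
// sc.addSparkListener(new JobEndLogger())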

For a ShuffleMapTask:

          case smt: ShuffleMapTask =>
            val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
            updateAccumulators(event)
            val status = event.result.asInstanceOf[MapStatus]
            val execId = status.location.executorId
            logDebug("ShuffleMapTask finished on " + execId)
            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
            } else {
              shuffleStage.addOutputLoc(smt.partitionId, status)
            }
            if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
              markStageAsFinished(shuffleStage)
              logInfo("looking for newly runnable stages")
              logInfo("running: " + runningStages)
              logInfo("waiting: " + waitingStages)
              logInfo("failed: " + failedStages)

              // We supply true to increment the epoch number here in case this is a
              // recomputation of the map outputs. In that case, some nodes may have cached
              // locations with holes (from when we detected the error) and will need the
              // epoch incremented to refetch them.
              // TODO: Only increment the epoch number if this is not the first time
              //       we registered these map outputs.
              mapOutputTracker.registerMapOutputs(
                shuffleStage.shuffleDep.shuffleId,
                shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
                changeEpoch = true)

              clearCacheLocs()
              if (shuffleStage.outputLocs.contains(Nil)) {
                // Some tasks had failed; let's resubmit this shuffleStage
                // TODO: Lower-level scheduler should also deal with this
                logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
                  ") because some of its tasks had failed: " +
                  shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty)
                      .map(_._2).mkString(", "))
                submitStage(shuffleStage)
              } else {
                val newlyRunnable = new ArrayBuffer[Stage]
                for (shuffleStage <- waitingStages) {
                  logInfo("Missing parents for " + shuffleStage + ": " +
                    getMissingParentStages(shuffleStage))
                }
                for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty)
                {
                  newlyRunnable += shuffleStage
                }
                waitingStages --= newlyRunnable
                runningStages ++= newlyRunnable
                for {
                  shuffleStage <- newlyRunnable.sortBy(_.id)
                  jobId <- activeJobForStage(shuffleStage)
                } {
                  logInfo("Submitting " + shuffleStage + " (" +
                    shuffleStage.rdd + "), which is now runnable")
                  submitMissingTasks(shuffleStage, jobId)
                }
              }
            }
          }

For a ShuffleMapTask the key lines are these:

val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
updateAccumulators(event)
val status = event.result.asInstanceOf[MapStatus]
shuffleStage.addOutputLoc(smt.partitionId, status)

This records the location of the output data in the ShuffleMapStage; the MapStatus is then registered with the MapOutputTrackerMaster, and the scheduling of this ShuffleMapStage is complete:

mapOutputTracker.registerMapOutputs(
  shuffleStage.shuffleDep.shuffleId,
  shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
  changeEpoch = true)

Reposted from blog.csdn.net/fate_killer_liu_jie/article/details/84944379