Spark: Task Internals and Source Code Analysis

In Spark, an application must go through the following steps before it can be executed:
(Figure: the submission pipeline an application goes through: job -> stage -> taskSet -> task)
As this pipeline shows, a job ultimately depends on tasks distributed across different nodes of the cluster, which run in parallel or concurrently to do the real work. In other words, the individual distributed tasks are Spark's true executors. The diagram of the overall task-running framework below gives a rough picture of how Spark runs a task.
(Figure: overall task execution framework: the Worker starts the Executor, the Executor registers back with the Driver, and the Driver sends LaunchTask messages)

Before a task runs, the Worker starts the Executor; the Executor then prepares the runtime environment and registers itself back with the Driver; finally the Driver sends a LaunchTask message to the Executor. From the moment the Executor receives the LaunchTask message, the task is off and running, doing all further work on a Java thread. Of course, some work still happens before the task formally starts, such as dividing stages with the stage-partitioning algorithm and finding each task's best location with the preferred-location algorithm (the first preference is for the data the task needs to be in the same process on the same node; the second, a different process on the same node; the third, a different node on the same rack; the fourth, a different rack). The goal is to reduce network communication overhead, save CPU resources, and improve system performance.
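
Spark models these four preferences as locality levels. Here is a minimal, self-contained sketch of the ordering; it mirrors org.apache.spark.scheduler.TaskLocality, but the object name and the isAllowed helper are illustrative additions, not Spark's actual API:

  // Sketch of Spark's locality levels, ordered from best to worst.
  // Enumeration assigns ascending ids, so a smaller id means better locality.
  object TaskLocalitySketch extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value

    // A candidate level is acceptable if it is no worse than the current constraint
    def isAllowed(constraint: Value, candidate: Value): Boolean = candidate <= constraint
  }

  object LocalityDemo extends App {
    import TaskLocalitySketch._
    println(isAllowed(NODE_LOCAL, PROCESS_LOCAL)) // true: better than required
    println(isAllowed(NODE_LOCAL, RACK_LOCAL))    // false: would be worse
  }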

A task's work comes down to the following steps:

  1. Pull the resources the task needs over the network and deserialize them (many tasks run in multiple Executors, in parallel or concurrently; all tasks of one stage process the same RDD, which reaches them through a broadcast variable)
  2. Obtain the shuffleManager, and from it a shuffleWriter (the shuffleWriter is used later to write the processed results to disk)
  3. Call rdd.iterator(), passing in the partition this task is to process (the custom operator or function runs against that partition of the RDD; the returned data is partitioned by the HashPartitioner [the default] and written through the ShuffleWriter created above into the corresponding bucket, which in practice means a disk file; see the partitioner sketch after this list)
  4. Wrap the result as a MapStatus and send it to the MapOutputTracker, for ResultTasks to fetch (MapStatus encapsulates information about the data the ShuffleMapTask computed, chiefly its storage location, i.e. the relevant BlockManager information; the BlockManager is Spark's low-level component for managing memory, data, and disk)
  5. A ResultTask fetches the ShuffleMapTask's output (the result of steps 2/3/4)
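
For reference, the default HashPartitioner mentioned in step 3 boils down to a non-negative modulo over the key's hashCode. A minimal stand-alone sketch (mirroring org.apache.spark.HashPartitioner, with Spark's nonNegativeMod helper inlined; the class name and demo are mine):

  // bucket = hash(key) mod numPartitions, corrected so negative hashCodes still land in [0, numPartitions)
  class HashPartitionerSketch(numPartitions: Int) {
    def getPartition(key: Any): Int = key match {
      case null => 0
      case _ =>
        val raw = key.hashCode % numPartitions
        if (raw < 0) raw + numPartitions else raw // Utils.nonNegativeMod, inlined
    }
  }

  object PartitionerDemo extends App {
    val p = new HashPartitionerSketch(4)
    println(p.getPartition("spark")) // a deterministic bucket in [0, 4)
    println(p.getPartition(-7))      // negative hashCode still maps into [0, 4)
  }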

To implement this flow, Task has two subclasses, ShuffleMapTask and ResultTask. The former transforms an RDD with the various map-side operators and custom functions; the latter is triggered by an action: it fetches the new RDD produced by the map stage and runs our custom function body on it, implementing the actual business logic.
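
As a concrete example (a hypothetical word count, runnable in local mode), the reduceByKey shuffle splits the job in two: everything before the shuffle runs as ShuffleMapTasks, and the collect action runs as ResultTasks that fetch the map output:

  import org.apache.spark.sql.SparkSession

  object StageSplitDemo extends App {
    val spark = SparkSession.builder().master("local[2]").appName("stage-split").getOrCreate()
    val sc = spark.sparkContext

    val counts = sc.parallelize(Seq("a b", "b c"))
      .flatMap(_.split(" "))
      .map(word => (word, 1)) // runs inside ShuffleMapTasks (stage 0)
      .reduceByKey(_ + _)     // shuffle boundary: map output is written via the ShuffleWriter
      .collect()              // runs as ResultTasks (stage 1) that fetch the map output

    counts.foreach(println)
    spark.stop()
  }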

Source code analysis:
Step 1: Receive the LaunchTask message sent by the Driver
Source: org.apache.spark.executor.Executor.scala

  def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
    // A TaskRunner (a thread) is created for every task
    // TaskRunner actually implements Java's Runnable interface
    val tr = new TaskRunner(context, taskDescription)
    // Put the TaskRunner into the in-memory cache; runningTasks maintains the list of running tasks
    runningTasks.put(taskDescription.taskId, tr)
    // Wrap the task in a thread (TaskRunner), submit it to the thread pool, and let it run
    // The thread pool implements queuing: if no pool thread is currently idle, submitted work waits in line
    threadPool.execute(tr)
  }
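
The threadPool here is a cached daemon thread pool (Spark builds it via ThreadUtils.newDaemonCachedThreadPool). A stand-alone sketch of the same pattern using the plain JDK equivalent, assuming nothing Spark-specific:

  import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

  object TaskRunnerPoolSketch extends App {
    // One Runnable per task, tracked in a concurrent map, executed on a cached pool
    // that reuses idle threads and spawns new ones under load
    val runningTasks = new ConcurrentHashMap[Long, Runnable]()
    val threadPool = Executors.newCachedThreadPool()

    def launchTask(taskId: Long): Unit = {
      val runner = new Runnable {
        override def run(): Unit =
          try println(s"task $taskId running on ${Thread.currentThread.getName}")
          finally runningTasks.remove(taskId) // mirrors the finally block in TaskRunner.run()
      }
      runningTasks.put(taskId, runner)
      threadPool.execute(runner)
    }

    (1L to 3L).foreach(launchTask)
    threadPool.shutdown()
    threadPool.awaitTermination(5, TimeUnit.SECONDS)
  }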

Step 2: TaskRunner runs the Task

  /**
   * How a task actually runs
   */
  class TaskRunner(
    execBackend: ExecutorBackend,
    private val taskDescription: TaskDescription)
    extends Runnable {

......

    /**
     * The thread executes the run() method
     */
    override def run(): Unit = {
      threadId = Thread.currentThread.getId
      Thread.currentThread.setName(threadName)
      val threadMXBean = ManagementFactory.getThreadMXBean
      // Create a memory manager for our task
      val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
      // Record the deserialization start time
      val deserializeStartTime = System.currentTimeMillis()
      val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      // A ClassLoader is needed to load concrete classes
      Thread.currentThread.setContextClassLoader(replClassLoader)
      // Create the closure serializer
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      // Call ExecutorBackend#statusUpdate to report the current state to the Driver
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      // Record run time and GC information
      var taskStartTime: Long = 0
      var taskStartCpu: Long = 0
      startGCTime = computeTotalGcTime()

      try {
        // Must be set before updateDependencies() is called, in case fetching dependencies
        // requires access to properties contained within (e.g. for access control).
        Executor.taskDeserializationProps.set(taskDescription.properties)
        // Copy the required files, resources and jars over the network
        updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
        // Deserialize the Task
        // Java's ClassLoader is used here because it can do many things: for example, dynamically load a class
        // via reflection and instantiate it, or load and read resources for a given context
        task = ser.deserialize[Task[Any]](
          taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
        task.localProperties = taskDescription.properties
        task.setTaskMemoryManager(taskMemoryManager)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        val killReason = reasonIfKilled
        if (killReason.isDefined) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException(killReason.get)
        }

        // The purpose of updating the epoch here is to invalidate executor map output status cache
        // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
        // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
        // we don't need to make any special calls here.
        if (!isLocal) {
          logDebug("Task " + taskId + "'s epoch is " + task.epoch)
          env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
        }

        // Run the actual task and measure its runtime.
        // Record the task start time
        taskStartTime = System.currentTimeMillis()
        taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        var threwException = true
        // Run the task via its run() method
        /**
         * For a ShuffleMapTask, value is a MapStatus, which encapsulates where the data computed by the ShuffleMapTask was written
         * If the next stage consists of ShuffleMapTasks as well, they contact the MapOutputTracker to get the previous ShuffleMapTasks' output locations and pull the data over the network
         * The same applies to a ResultTask
         */
        val value = Utils.tryWithSafeFinally {
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = taskDescription.attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          res
        } {
          // Clean up all allocated memory and pages, and check for memory leaks
          val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()

          if (freedMemory > 0 && !threwException) {
            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
              throw new SparkException(errMsg)
            } else {
              logWarning(errMsg)
            }
          }

          if (releasedLocks.nonEmpty && !threwException) {
            val errMsg =
              s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
                releasedLocks.mkString("[", ", ", "]")
            if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
              throw new SparkException(errMsg)
            } else {
              logInfo(errMsg)
            }
          }
        }
        task.context.fetchFailed.foreach { fetchFailure =>
          // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
          // other exceptions.  Its *possible* this is what the user meant to do (though highly
          // unlikely).  So we will log an error and keep going.
          logError(s"TID ${taskId} completed successfully though internally it encountered " +
            s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
            s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
        }
        // Record the task finish time
        val taskFinish = System.currentTimeMillis()
        val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L

        // If the task has been killed, let's fail it.
        task.context.killTaskIfInterrupted()
        
        // Here the MapStatus is serialized and wrapped, because it will be sent to the Driver over the network
        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()

        // Deserialization happens in two parts: first, we deserialize a Task object, which
        // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
        /**
         * Compute the task's statistics here: deserialization time, JVM GC time, and result-serialization time
         * All of these metrics are shown on the Spark UI
         */
        task.metrics.setExecutorDeserializeTime(
          (taskStartTime - deserializeStartTime) + task.executorDeserializeTime)
        task.metrics.setExecutorDeserializeCpuTime(
          (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
        // We need to subtract Task.run()'s deserialization time to avoid double-counting
        task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime)
        task.metrics.setExecutorCpuTime(
          (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
        task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
        task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)

   ......

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          // Check the size of the result object to decide how to send it back
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            // Larger than maxResultSize (1 GB by default): drop the result
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize > maxDirectResultSize) {
            // Result exceeds the direct-send threshold: store it in the BlockManager
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId,
              new ChunkedByteBuffer(serializedDirectResult.duplicate()),
              StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            // Return an IndirectTaskResult rather than the result bytes themselves
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            // Result is small enough: send it directly back to the Driver
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }

        setTaskFinishedAndClearInterruptStatus()
        // Call statusUpdate() on this executor's CoarseGrainedExecutorBackend
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

      } catch {
      
       ......
      
      } finally {
        runningTasks.remove(taskId)
      }
    }

Step 3: The updateDependencies() method

  private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) {
    // Get the Hadoop configuration
    lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
    /**
     * Synchronize concurrent access from multiple threads, using Java's synchronized.
     * Tasks actually run as Java threads, concurrently, inside a single CoarseGrainedExecutorBackend process,
     * so business logic that touches shared resources is open to thread-safety problems.
     * Spark therefore synchronizes here, because shared state such as currentFiles is accessed inside the block
     */
    synchronized {
      // Fetch missing dependencies
      // Iterate over the files that need to be fetched
      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        // Fetch file with useCache mode, close cache for local mode.
        // Fetch the file remotely over the network
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
          env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
        currentFiles(name) = timestamp
      }
      // Iterate over the jars to fetch
      for ((name, timestamp) <- newJars) {
        // Compare timestamps: fetch only if the locally known timestamp is older than the new one
        val localName = new URI(name).getPath.split("/").last
        val currentTimeStamp = currentJars.get(name)
          .orElse(currentJars.get(localName))
          .getOrElse(-1L)
        if (currentTimeStamp < timestamp) {
          logInfo("Fetching " + name + " with timestamp " + timestamp)
          // Fetch file with useCache mode, close cache for local mode.
          // Fetch the jar file
          Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
            env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
          currentJars(name) = timestamp
          // Add it to our class loader
          val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
          if (!urlClassLoader.getURLs().contains(url)) {
            logInfo("Adding " + url + " to class loader")
            urlClassLoader.addURL(url)
          }
        }
      }
    }
  }
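
The pattern above is a timestamp-guarded cache: fetch only what we have never seen, or what carries a newer timestamp than our local copy. A stand-alone sketch of just that logic (the network download that Utils.fetchFile performs is stubbed out):

  import scala.collection.mutable

  object DependencyCacheSketch extends App {
    // name -> timestamp of the version we already fetched
    val currentFiles = mutable.HashMap[String, Long]()

    // Stub standing in for Utils.fetchFile's real network download
    def fetchFile(name: String): Unit = println(s"fetching $name")

    def updateDependencies(newFiles: Map[String, Long]): Unit = synchronized {
      // Fetch only unseen entries, or entries whose timestamp moved forward
      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
        fetchFile(name)
        currentFiles(name) = timestamp
      }
    }

    updateDependencies(Map("app.jar" -> 100L)) // fetches
    updateDependencies(Map("app.jar" -> 100L)) // no-op: timestamp unchanged
    updateDependencies(Map("app.jar" -> 200L)) // fetches the newer version
  }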

Step 4: The task's run() method

  final def run(
      taskAttemptId: Long,
      attemptNumber: Int,
      metricsSystem: MetricsSystem): T = {
    SparkEnv.get.blockManager.registerTask(taskAttemptId)
    // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
    // the stage is barrier.
    /**
     * Create a TaskContext, the task's execution context; it records global data about the task,
     * e.g. how many times it has been retried, which stage it belongs to, and which partition of the RDD it processes
     */
    val taskContext = new TaskContextImpl(
      stageId,
      stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
      partitionId,
      taskAttemptId,
      attemptNumber,
      taskMemoryManager,
      localProperties,
      metricsSystem,
      metrics)

    context = if (isBarrier) {
      new BarrierTaskContext(taskContext)
    } else {
      taskContext
    }

    TaskContext.setTaskContext(context)
    taskThread = Thread.currentThread()

    if (_reasonIfKilled != null) {
      kill(interruptThread = false, _reasonIfKilled)
    }

    new CallerContext(
      "TASK",
      SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
      appId,
      appAttemptId,
      jobId,
      Option(stageId),
      Option(stageAttemptId),
      Option(taskAttemptId),
      Option(attemptNumber)).setCurrentContext()

    try {
      // Call the abstract method runTask()
      runTask(context)
    } catch {
      case e: Throwable =>
        // Catch all errors; run task failure callbacks, and rethrow the exception.
        try {
          context.markTaskFailed(e)
        } catch {
          case t: Throwable =>
            e.addSuppressed(t)
        }
        context.markTaskCompleted(Some(e))
        throw e
    } finally {
      try {
        // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
        // one is no-op.
        context.markTaskCompleted(None)
      } finally {
        try {
          Utils.tryLogNonFatalError {
            // Release memory used by this thread for unrolling blocks
            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
              MemoryMode.OFF_HEAP)
            // Notify any tasks waiting for execution memory to be freed to wake up and try to
            // acquire memory again. This makes impossible the scenario where a task sleeps forever
            // because there are no other tasks left to notify it. Since this is safe to do but may
            // not be strictly necessary, we should revisit whether we can remove this in the
            // future.
            val memoryManager = SparkEnv.get.memoryManager
            memoryManager.synchronized { memoryManager.notifyAll() }
          }
        } finally {
          // Though we unset the ThreadLocal here, the context member variable itself is still
          // queried directly in the TaskRunner to check for FetchFailedExceptions.
          TaskContext.unset()
        }
      }
    }
  }

Step 5: The runTask(context) method

  /**
   * runTask is abstract, meaning Task is just a template class: it only encapsulates properties and methods common to its subclasses and relies on them to implement the concrete behavior
   * As mentioned above, Task has two subclasses, ShuffleMapTask and ResultTask; it is they that run the operators and logic we define
   */
  def runTask(context: TaskContext): T
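
This is the classic template-method pattern: the parent class fixes the skeleton (run() registers the task, builds the TaskContext, and releases memory at the end), while each subclass fills in runTask(). A minimal sketch of the pattern with all Spark specifics stripped away:

  // run() owns the lifecycle; runTask() supplies the actual work
  abstract class TaskSketch[T] {
    final def run(): T = {
      println("setup: register task, create context")            // stands in for the TaskContext setup
      try runTask()
      finally println("teardown: release memory, unset context") // stands in for the finally chain
    }
    def runTask(): T // implemented by ShuffleMapTask / ResultTask
  }

  object TemplateDemo extends App {
    val t = new TaskSketch[Int] { def runTask(): Int = 42 }
    println(t.run())
  }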

Step 6: The ShuffleMapTask subclass's runTask() method
Source: org.apache.spark.scheduler.ShuffleMapTask.scala

/**
 * A ShuffleMapTask splits the elements of an RDD into multiple buckets,
 * based on the partitioner specified in the ShuffleDependency (HashPartitioner by default)
 */
private[spark] class ShuffleMapTask(
    stageId: Int,
    stageAttemptId: Int,
    taskBinary: Broadcast[Array[Byte]],
    partition: Partition,
    @transient private var locs: Seq[TaskLocation],
    localProperties: Properties,
    serializedTaskMetrics: Array[Byte],
    jobId: Option[Int] = None,
    appId: Option[String] = None,
    appAttemptId: Option[String] = None,
    isBarrier: Boolean = false)
  extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
    serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
  with Logging {

......

  /**
   * ShuffleMapTask's runTask returns a MapStatus
   */
  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    // Deserialize the data this task is to process
    /**
     * Multiple tasks run concurrently in executors, and the data may not all be on one machine; all tasks of a stage process the same RDD,
     * so how does a task get hold of the part it is supposed to process? Answer: through a broadcast variable
     */
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    var writer: ShuffleWriter[Any, Any] = null
    try {
      // Get the ShuffleManager
      val manager = SparkEnv.get.shuffleManager
      // Get a ShuffleWriter from the ShuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      /**
       * First call the rdd's iterator() method, passing in the partition this task is to process;
       * the core logic lives in rdd.iterator(): that is where our own operator or function runs against one partition of the RDD
       *
       * Once our operator or function has executed, the partition has been processed and data is returned;
       * the returned data is partitioned by the HashPartitioner and written through the ShuffleWriter into the task's corresponding bucket
       */
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      /**
       * Finally, return a MapStatus, which encapsulates the data computed by this ShuffleMapTask, namely where it is stored: the relevant BlockManager information
       * The BlockManager is Spark's low-level component for managing memory, data, and disk
       */
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }
......

}

Step 7: The rdd.iterator() method

  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      // Caching path: fetch the partition from storage or compute and cache it
      getOrCompute(split, context)
    } else {
      // Compute the RDD partition (or read it from a checkpoint)
      computeOrReadCheckpoint(split, context)
    }
  }
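
Which branch runs depends only on the RDD's storage level. A brief (hypothetical) usage example: after persist(), iterator() takes the getOrCompute path on later actions instead of recomputing the partition:

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.storage.StorageLevel

  object IteratorPathDemo extends App {
    val sc = SparkSession.builder().master("local[2]").appName("iterator-path")
      .getOrCreate().sparkContext

    val rdd = sc.parallelize(1 to 10).map { x => println(s"computing $x"); x * 2 }

    rdd.persist(StorageLevel.MEMORY_ONLY) // storageLevel != NONE => getOrCompute branch
    rdd.count() // first action: partitions are computed and cached
    rdd.count() // second action: served from the block store, no "computing" output
    sc.stop()
  }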

Step 8: The computeOrReadCheckpoint() method

  private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      firstParent[T].iterator(split, context)
    } else {
      // compute() is abstract; see a concrete implementation such as MapPartitionsRDD
      compute(split, context)
    }
  }
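
If the RDD has been checkpointed and materialized, firstParent is the CheckpointRDD that reads the saved partition back from the checkpoint directory, so nothing is recomputed. A brief (hypothetical) usage example:

  import org.apache.spark.sql.SparkSession

  object CheckpointDemo extends App {
    val sc = SparkSession.builder().master("local[2]").appName("checkpoint")
      .getOrCreate().sparkContext
    sc.setCheckpointDir("/tmp/spark-checkpoints") // hypothetical directory

    val rdd = sc.parallelize(1 to 100).map(_ * 2)
    rdd.checkpoint() // marked for checkpointing; materialized by the next action
    rdd.count()      // computes the lineage and writes the checkpoint files
    rdd.count()      // now isCheckpointedAndMaterialized: read via firstParent.iterator
    sc.stop()
  }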

Step 9: The compute(split, context) method of MapPartitionsRDD
Source: org.apache.spark.rdd.MapPartitionsRDD.scala

  /**
   * Here, the operator or function we defined for this RDD is executed against one of its partitions
   * f can be understood as our own operator or function; Spark wraps it internally and adds some extra logic
   * Reaching this point means the custom computation runs against the RDD's partition and returns the new RDD's partition data
   */
  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))
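
For reference, this is where f comes from: an operator such as RDD.map wraps the user function into exactly this (context, partitionIndex, iterator) shape when it constructs the MapPartitionsRDD (in Spark's RDD.scala, map essentially does new MapPartitionsRDD[U, T](this, (_, _, iter) => iter.map(cleanF))). A self-contained toy sketch of that wiring, with made-up Mini* names:

  object MapWiringDemo extends App {
    type PartitionFn[T, U] = (Int, Iterator[T]) => Iterator[U]

    class MiniRDD[T](val partitions: Seq[Seq[T]]) {
      // Same wrapping as RDD.map: the user function becomes a per-partition iterator transform
      def map[U](f: T => U): MiniMappedRDD[T, U] = new MiniMappedRDD(this, (_, iter) => iter.map(f))
    }

    class MiniMappedRDD[T, U](parent: MiniRDD[T], f: PartitionFn[T, U]) {
      // Mirrors MapPartitionsRDD.compute: apply f to the parent's iterator for this split
      def compute(split: Int): Iterator[U] = f(split, parent.partitions(split).iterator)
    }

    val rdd = new MiniRDD(Seq(Seq(1, 2), Seq(3)))
    println(rdd.map(_ * 10).compute(0).toList) // List(10, 20)
  }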

Step 10: ResultTask's runTask() method
Source: org.apache.spark.scheduler.ResultTask.scala

  override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    // Basic deserialization of the RDD and the user function
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    // Run the operator/function we defined over the data produced by rdd.iterator
    func(context, rdd.iterator(partition, context))
  }
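
The func deserialized here originates from the action on the driver side: SparkContext.runJob wraps the action's per-partition function into the (TaskContext, Iterator[T]) => U shape before serializing it into the ResultTask. A self-contained sketch of that wrapping (MiniContext and the wrap helper are made up for illustration):

  object ResultFuncDemo extends App {
    // Stand-in for TaskContext; the real one carries stage/partition/attempt info
    case class MiniContext(partitionId: Int)

    // An action supplies Iterator[T] => U; the scheduler-facing form takes (context, iterator),
    // which is exactly the shape ResultTask.runTask invokes
    def wrap[T, U](perPartition: Iterator[T] => U): (MiniContext, Iterator[T]) => U =
      (_, iter) => perPartition(iter)

    // collect()'s per-partition function is essentially iter => iter.toArray
    val collectFunc = wrap[Int, Array[Int]](_.toArray)
    println(collectFunc(MiniContext(0), Iterator(1, 2, 3)).toList) // List(1, 2, 3)
  }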

Step 11: The execBackend.statusUpdate call from Step 2 reports the task's state to the Driver, telling it the task has finished

// Calls statusUpdate() on this executor's CoarseGrainedExecutorBackend
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

Step 12: CoarseGrainedExecutorBackend's statusUpdate() method

  // Send a StatusUpdate message to the CoarseGrainedSchedulerBackend (on the Driver)
  override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
    val msg = StatusUpdate(executorId, taskId, state, data)
    driver match {
      case Some(driverRef) => driverRef.send(msg)
      case None => logWarning(s"Drop $msg because has not yet connected to driver")
    }
  }

Step 13: CoarseGrainedSchedulerBackend handles the StatusUpdate message
Source: org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.scala

      case StatusUpdate(executorId, taskId, state, data) =>
        // Delegate the update to TaskSchedulerImpl#statusUpdate
        scheduler.statusUpdate(taskId, state, data.value)
        // If the task is in a finished state
        if (TaskState.isFinished(state)) {
          // Look up the ExecutorData by executor id
          executorDataMap.get(executorId) match {
            // If the executor is known
            case Some(executorInfo) =>
              // Return the task's CPU cores to the executor's free pool
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              // Make new resource offers on this executor so more tasks can be launched
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                s"from unknown executor with ID $executorId")
          }
        }
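
Steps 12 and 13 are the two ends of one RPC round trip: the executor fires a StatusUpdate message, and the driver-side endpoint pattern-matches on it. A toy sketch of that shape (the message class mirrors CoarseGrainedClusterMessages.StatusUpdate in spirit; everything else is made up):

  object StatusUpdateRpcSketch extends App {
    case class StatusUpdate(executorId: String, taskId: Long, state: String)

    // Driver side: pattern-match on incoming messages, as the scheduler backend does
    def receive(msg: Any): Unit = msg match {
      case StatusUpdate(execId, taskId, state) =>
        println(s"driver: task $taskId on $execId is $state")
        if (state == "FINISHED") println(s"driver: freeing cores on $execId, making new offers")
    }

    // Executor side: fire-and-forget send, like driverRef.send(msg)
    receive(StatusUpdate("exec-1", 42L, "FINISHED"))
  }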

Step 14: The scheduler.statusUpdate method

  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
    var failedExecutor: Option[String] = None
    var reason: Option[ExecutorLossReason] = None
    synchronized {
      try {
        Option(taskIdToTaskSetManager.get(tid)) match {
          case Some(taskSet) =>
            // If the task is LOST; in practice you may see task lost fairly often, simply because execution failed for any of various reasons
            if (state == TaskState.LOST) {
              // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode,
              // where each executor corresponds to a single task, so mark the executor as failed.
              // Remove the executor and record it as failed
              val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException(
                "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)"))
              if (executorIdToRunningTaskIds.contains(execId)) {
                reason = Some(
                  SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
                removeExecutor(execId, reason.get)
                failedExecutor = Some(execId)
              }
            }
            if (TaskState.isFinished(state)) {
              // The task has finished: remove it from the in-memory caches
              cleanupTaskState(tid)
              taskSet.removeRunningTask(tid)
              // Handle successful completion (and failures) accordingly
              if (state == TaskState.FINISHED) {
                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
              }
            }
          case None =>
            logError(
              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
                "likely the result of receiving duplicate task finished status updates) or its " +
                "executor has been marked as failed.")
                .format(state, tid))
        }
      } catch {
        case e: Exception => logError("Exception in statusUpdate", e)
      }
    }
    // Update the DAGScheduler without holding a lock on this, since that can deadlock
    if (failedExecutor.isDefined) {
      assert(reason.isDefined)
      dagScheduler.executorLost(failedExecutor.get, reason.get)
      backend.reviveOffers()
    }
  }

To sum up: running a task is not a matter of directly calling some low-level run method straight down the job -> stage -> taskSet -> task chain; the work is organized through layering and division of labor. Task is subclassed by ShuffleMapTask and ResultTask, each with its own job: ShuffleMapTask applies the RDD transformations to the partition the task owns, while ResultTask is triggered by an action, fetches the output of the ShuffleMapTask stage, and applies the final operators and logic functions that genuinely process the data. The two stages are connected by the MapOutputTracker.

Reposted from blog.csdn.net/jiaojiao521765146514/article/details/85336328