Spark源码解读之Executor以及Task工作原理剖析

前一篇文章中主要讲述了TaskScheduler发送TaskSet中的task到executor中执行,那么,本篇文章接着上文的讲述看看executor的工作原理以及task是如何执行的。

首先来看看executor的工作流程:

executor会启动一个后台进程CoarseGrainedExecutorBackend,首先它会向driver发送RegisterExecutor消息注册executor,注册成功之后,driver返回RegisterExecutor消息,CoarseGrainedExecutorBackend接收到消息之后会创建一个executor对象,然后调用executor的launchTask方法开始执行task。

源码如下:

private[spark] class CoarseGrainedExecutorBackend(
    driverUrl: String,
    executorId: String,
    hostPort: String,
    cores: Int,
    userClassPath: Seq[URL],
    env: SparkEnv)
  extends Actor with ActorLogReceive with ExecutorBackend with Logging {

  Utils.checkHostPort(hostPort, "Expected hostport")

  var executor: Executor = null
  var driver: ActorSelection = null

  override def preStart() {
    logInfo("Connecting to driver: " + driverUrl)
    driver = context.actorSelection(driverUrl)
    //向Driver注册executor
    driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls)
    context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
  }
    ......

  override def receiveWithLogging = {
    //返回RegisterExecutor消息并且启动executor
    case RegisteredExecutor =>
      logInfo("Successfully registered with driver")
      val (hostname, _) = Utils.parseHostPort(hostPort)
      executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)

    case RegisterExecutorFailed(message) =>
      logError("Slave registration failed: " + message)
      System.exit(1)

      //启动任务
    case LaunchTask(data) =>
      if (executor == null) {
        logError("Received LaunchTask command but executor was null")
        System.exit(1)
      } else {
        val ser = env.closureSerializer.newInstance()
        val taskDesc = ser.deserialize[TaskDescription](data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
          taskDesc.name, taskDesc.serializedTask)
      }
    .......
  }

在CoarseGrainedExecutorBackend中的LaunchTask中实际上调用了executor的launchTask方法,在这个方法中,实际上回创建一个TaskRunner(实际上是一个实现了java Runnable接口的一个线程),然后将这个taskRunner放入java线程池中进行调度,其源码如下:

 def launchTask(
      context: ExecutorBackend,
      taskId: Long,
      attemptNumber: Int,
      taskName: String,
      serializedTask: ByteBuffer) {
    //创建TaskRunner
    val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
      serializedTask)
    //将当前线程放入缓冲池
    runningTasks.put(taskId, tr)
    //放入线程池
    threadPool.execute(tr)
  }

因此在executor中,针对每一个task会创建一个taskRunner,然后放入线程池中使用,那么,我们可以进行taskRunner中看看task是怎么运行的。

TaskRunner是Executor的一个内部类,它实际上一个java 的线程类,那么在它的run方法中主要是task的运行原理,源码如下:

override def run() {
      val deserializeStartTime = System.currentTimeMillis()
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      startGCTime = gcTime

      try {
        //反序列化任务文件以及所需要的jar,并且上传
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        updateDependencies(taskFiles, taskJars)
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        attemptedTask = Some(task)
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        //使用mapOutputTracker跟踪器追踪stage的map中间结果的输出位置
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
        val taskFinish = System.currentTimeMillis()
        ......

run方法中调用了task类的的run方法,这个run方法中就是对于当前task的执行:

 /**
   * Called by Executor to run this task.
   *
   * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
   * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
   * @return the result of the task
   */
  final def run(taskAttemptId: Long, attemptNumber: Int): T = {
    //创建TaskContext,task的上下文
    //里面记录了task的一些全局性数据,比如task重试了几次,task属于哪个stage,task要处理的是哪个rdd的哪个partition等等
    context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
      taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
    TaskContextHelper.setTaskContext(context)
    context.taskMetrics.setHostname(Utils.localHostName())
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    try {
      //调用抽象方法
      runTask(context)
    } finally {
      context.markTaskCompleted()
      TaskContextHelper.unset()
    }
  }
   /**
    * 调用到了抽象方法,意味着这个类只是一个模板类或者抽象父类
    * 仅仅封装一些子类通用的数据和操作
    * 而关键的操作,都要依赖于子类的实现
    * Task的子类,shuffleMapTask,ResultTask
    * @param context
    * @return
    */
  def runTask(context: TaskContext): T

在前面的文章中的stage划分算法剖析中,我们知道了对于中间的stage会创建一个shuffleMapTask,只有最后一个stage才会创建一个ResultTask,那么在上面的源代码中调用了抽象方法,因此会分别调用task的这两子类中的runTask方法,shuffleMapTask会按照ShuffleDependency指定的partitioner将rdd分割为多个bucket。

实际上在shuffleMapTask中的run方法中会使用shuffleManager的shuffleWriter将数据分区之后写入对应的分区中,在所有的操作完了之后会返回一个MapStatus给DAGScheduler。其源码如下:

 override def runTask(context: TaskContext): MapStatus = {
    //使用广播变量反序列化RDD
    // Deserialize the RDD using the broadcast variable.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      //获取ShuffleManager,从ShuffleManager中获取ShuffleWriter
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      //首先调用rdd的iterator方法,并且传入了当前task要处理那个partition,然后执行我们定义的函数
      //处理返回的数据都是通过ShuffleWriter,经过HashPartitioner进行分区之后,写入了自己对应的bucket
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      //最后返回结果,MapStatus
      //MapStatus里面封装了ShffleMapTask计算后的数据,存储在哪里,其实就是BlockManager的信息
      //BlockManager是spark底层内存,数据,磁盘数据管理的组件
      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

而相比较ShuffleMapTask,ResultTask比较简单,它主要会对ShuffleMapTask的中间输出结果来执行shuffle操作以及我们定义的算子和函数,因此它可能会去MapOutputTracker中拉取输出的中间数据,源码如下:

  override def runTask(context: TaskContext): U = {
    //通过广播变量反序列化RDD
    // Deserialize the RDD and the func using the broadcast variables.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    func(context, rdd.iterator(partition, context))
  }

然后会进行shuffle操作将task的执行结果放入缓存或者写入磁盘中,在后面的文章中主要会介绍shuffle的原理。

猜你喜欢

转载自blog.csdn.net/qq_37142346/article/details/81395725