spark源码《二》Task

上篇主要写了spark的基本数据抽象RDD,spark源码《一》RDD,这篇主要是spark的最小执行单元Task,

当发生shuffle依赖时,会切分stage,每个stage的task数量,由该stage最后rdd的partition数量决定,

task有两种,shufflemaptask和resulttask,resulttask是finalstage,也就是需要将结果返回给driver的stage,而

shufflemaptask无需将结果返回,需要将结果shuffle后传给后面的shufflemaptask或者resulttask,类似与

mapreduce的mapper,shuffle完后将数据传给reducer。

1.Task

class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) 
extends Serializable


abstract class Task[T] extends Serializable {
  def run(id: Int): T//运行该task
  def preferredLocations: Seq[String] = Nil//获取优先位置
  def generation: Option[Long] = None//当fetch数据失败时,该值+1
}

TaskContext有三个类参数,​​​​分别为:
stageId,表示该task属于哪个stage

splitId,RDD的partition

attemptId,运行Id

2.DAGTask

为Task的子类

abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
  //getGeneration通过请求worker或master来获取当前的generation数
val gen = SparkEnv.get.mapOutputTracker.getGeneration

  override def generation: Option[Long] = Some(gen)
}

3.ResultTask

为DAGTask的子类

class ResultTask[T, U](
    runId: Int,
    stageId: Int, 
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    val partition: Int, 
    locs: Seq[String],
    val outputId: Int)
  extends DAGTask[U](runId, stageId) {
  
  val split = rdd.splits(partition)//获取分区


  override def run(attemptId: Int): U = {

    val context = new TaskContext(stageId, partition, attemptId)
//实例化一个TaskContext对象

    func(context, rdd.iterator(split))
//返回一个方法,参数为TaskContext对象,rdd的某个分区数据
  }

  override def preferredLocations: Seq[String] = locs

  override def toString = "ResultTask(" + stageId + ", " + partition + ")"
}

可以看到ResultTask的run()方法返回的是一个用于计算某个rdd分区方法,方法可以是count(),take()等,直接计算出结果

4.ShuffleMapTask

class ShuffleMapTask(
    runId: Int,
    stageId: Int,
    rdd: RDD[_], 
    dep: ShuffleDependency[_,_,_],
    val partition: Int, 
    locs: Seq[String])
  extends DAGTask[String](runId, stageId)
  with Logging {
  
  val split = rdd.splits(partition)  

  override def run (attemptId: Int): String = {
    val numOutputSplits = dep.partitioner.numPartitions //获取分区数
    
    //将参数强转
    val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
    val partitioner = dep.partitioner.asInstanceOf[Partitioner]
    
    //创建一个长度为分区数的数组
    val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])

    for (elem <- rdd.iterator(split)) {//遍历rdd分区的元素
   
      val (k, v) = elem.asInstanceOf[(Any, Any)]
      var bucketId = partitioner.getPartition(k)//决定该元素去往后一个rdd的哪个分区
      val bucket = buckets(bucketId)//取数组下标为bucketId的数据
      var existing = bucket.get(k)//通过key获取value

      if (existing == null) {//如果为空
        bucket.put(k, aggregator.createCombiner(v))//新建累加器,将k,v放入
      } else {
        bucket.put(k, aggregator.mergeValue(existing, v))//否则,直接将v放入累加器
      }
    }
    val ser = SparkEnv.get.serializer.newInstance()
    for (i <- 0 until numOutputSplits) {
     //创建文件,准备写数据
      val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
       
    //创建写入文件数据的流
      val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
    
      out.writeObject(buckets(i).size)//先写入每个数组(分区)的元素数

      val iter = buckets(i).entrySet().iterator()
      while (iter.hasNext()) {//遍历,将数组数据写往对应的分区文件
        val entry = iter.next()
        out.writeObject((entry.getKey, entry.getValue))
      }
      
      out.close()
    }
    return SparkEnv.get.shuffleManager.getServerUri
    //返回uri,等待后续task拉取数据
  }

  override def preferredLocations: Seq[String] = locs

  override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}

可以看到ShuffleMapTask返回的是一个uri,等待后续的task拉取数据,

该方法主要分为两步,

第一步遍历rdd分区中数据,根据partitioner(可查看上一篇了解分区器的两种类型)决定分区中数据去往后续rdd的哪个分区,

并将去往同一分区的数据写入下标相同的数组

第二步遍历数组,按分区写入对应的文件(当后续rdd有n个分区时,会写n个文件),返回uri,等待后续节点拉取数据

如图所示,假设父RDD有4个分区,子RDD由3个分区,当父RDD第一个分区,调用run()方法时,

会先创建一个长度为3的数组,遍历分区元素,通过partitioner决定元素去往下标为0或1或2的位置,然后写入数组,

接下来创建3个文件,将数据数据写入对应的文件,返回uri,等待fetch

猜你喜欢

转载自blog.csdn.net/zhaolq1024/article/details/82662125