A Quick Look at the Spark Source: Stage Division and Task Submission

Spark version

Spark 1.6.1

Tracing the source from any RDD action, every code path eventually funnels into SparkContext.runJob.
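For instance, count() in RDD.scala (1.6.x) is a one-liner that hands the work straight to SparkContext.runJob; collect and the other actions bottom out the same way:

// RDD.scala
/**
 * Return the number of elements in the RDD.
 */
def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum

From there we land in the runJob overload below: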

      
      
// SparkContext.scala
/**
 * Run a function on a given set of partitions in an RDD and pass the results to the given
 * handler function. This is the main entry point for all actions in Spark.
 */
def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    resultHandler: (Int, U) => Unit): Unit = {
  if (stopped.get()) {
    throw new IllegalStateException("SparkContext has been shutdown")
  }
  val callSite = getCallSite
  val cleanedFunc = clean(func)
  logInfo("Starting job: " + callSite.shortForm)
  if (conf.getBoolean("spark.logLineage", false)) {
    logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
  }
  //== Next Step ==
  dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
  progressBar.foreach(_.finishAll())
  rdd.doCheckpoint()
}

As you can see, this in turn calls dagScheduler.runJob:

      
      
// DAGScheduler.scala
def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  //== Next Step ==
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  waiter.awaitResult() match {
    case JobSucceeded =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case JobFailed(exception: Exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}

def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): JobWaiter[U] = {
  // Check to make sure we are not launching a task on a partition that does not exist.
  val maxPartitions = rdd.partitions.length
  partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
    throw new IllegalArgumentException(
      "Attempting to access a non-existent partition: " + p + ". " +
        "Total number of partitions: " + maxPartitions)
  }

  val jobId = nextJobId.getAndIncrement()
  if (partitions.size == 0) {
    // Return immediately if the job is running 0 tasks
    return new JobWaiter[U](this, jobId, 0, resultHandler)
  }

  assert(partitions.size > 0)
  val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
  //== Next Step ==
  val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
  eventProcessLoop.post(JobSubmitted(
    jobId, rdd, func2, partitions.toArray, callSite, waiter,
    SerializationUtils.clone(properties)))
  waiter
}

eventProcessLoop is an instance of DAGSchedulerEventProcessLoop, a subclass of EventLoop. EventLoop wraps a blocking queue: post puts an event onto the queue, and a dedicated thread keeps taking events off and processing them.
In short, the code above posts a JobSubmitted event to the queue, so the next step is handling that event.
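The pattern behind EventLoop is easy to model. Below is a minimal sketch with simplified names (the real class is org.apache.spark.util.EventLoop, which additionally handles onError callbacks and graceful shutdown), just to make the post/consume handoff concrete:

import java.util.concurrent.LinkedBlockingDeque

// A stripped-down model of Spark's EventLoop: post() enqueues an event,
// a single consumer thread dequeues it and dispatches it to onReceive.
abstract class MiniEventLoop[E](name: String) {
  private val eventQueue = new LinkedBlockingDeque[E]()
  @volatile private var stopped = false

  private val eventThread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped) {
          onReceive(eventQueue.take()) // take() blocks until an event arrives
        }
      } catch {
        case _: InterruptedException => // stop() interrupts a blocked take()
      }
    }
  }

  def start(): Unit = eventThread.start()
  def stop(): Unit = { stopped = true; eventThread.interrupt() }

  // Producers (here: submitJob) call post; it returns immediately.
  def post(event: E): Unit = eventQueue.put(event)

  // Subclasses decide what each event means (cf. doOnReceive below).
  protected def onReceive(event: E): Unit
}

So the JobSubmitted event posted by submitJob is picked up by the event thread and dispatched via onReceive/doOnReceive, which is where we go next: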

      
      
// DAGScheduler.scala
private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
  extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {

  private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer

  /**
   * The main event loop of the DAG scheduler.
   */
  override def onReceive(event: DAGSchedulerEvent): Unit = {
    val timerContext = timer.time()
    try {
      doOnReceive(event)
    } finally {
      timerContext.stop()
    }
  }

  private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
    //== Next Step ==
    case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
    ...
  }
}

/*
 * In the code below, newResultStage performs the stage division. It is the core method for
 * splitting the job into stages: starting from the final RDD's dependencies and recursing
 * backwards, every wide (shuffle) dependency becomes a stage boundary. Once the stages have
 * been carved out, execution continues to submitStage(finalStage), where submission begins.
 */
private[scheduler] def handleJobSubmitted(jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    callSite: CallSite,
    listener: JobListener,
    properties: Properties) {
  var finalStage: ResultStage = null
  try {
    // New stage creation may throw an exception if, for example, jobs are run on a
    // HadoopRDD whose underlying HDFS files have been deleted.
    //== Next Step 1 ==
    finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
  } catch {
    case e: Exception =>
      logWarning("Creating new stage failed due to exception - job: " + jobId, e)
      listener.jobFailed(e)
      return
  }

  val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
  clearCacheLocs()
  logInfo("Got job %s (%s) with %d output partitions".format(
    job.jobId, callSite.shortForm, partitions.length))
  logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
  logInfo("Parents of final stage: " + finalStage.parents)
  logInfo("Missing parents: " + getMissingParentStages(finalStage))

  val jobSubmissionTime = clock.getTimeMillis()
  jobIdToActiveJob(jobId) = job
  activeJobs += job
  finalStage.setActiveJob(job)
  val stageIds = jobIdToStageIds(jobId).toArray
  val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
  listenerBus.post(
    SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
  //== Next Step 2 ==
  submitStage(finalStage)

  submitWaitingStages()
}

newResultStage takes the finalRDD and ultimately returns the finalStage. (Stages, like RDDs, form a dependency graph; newResultStage divides stages by following the dependencies between RDDs.)
On with the source:

      
      
// DAGScheduler.scala
/**
 * Create a ResultStage associated with the provided jobId.
 */
private def newResultStage(
    rdd: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    jobId: Int,
    callSite: CallSite): ResultStage = {
  //== Next Step ==
  val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
  val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite)
  stageIdToStage(id) = stage
  updateJobIdStageIdMaps(jobId, stage)
  stage
}

private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = {
  //== Next Step ==
  val parentStages = getParentStages(rdd, firstJobId)
  val id = nextStageId.getAndIncrement()
  (parentStages, id)
}

The code above gathers all the parent Stages that finalStage, built from the finalRDD, depends on.
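The walk that follows distinguishes stages purely by dependency type: a ShuffleDependency is a wide dependency and hence a stage boundary; everything else is narrow and just gets traversed. This is easy to verify in spark-shell (a minimal sketch with made-up data):

val pairs  = sc.parallelize(1 to 10).map(x => (x % 3, x)) // narrow: OneToOneDependency
val summed = pairs.reduceByKey(_ + _)                     // wide: ShuffleDependency

println(pairs.dependencies)  // e.g. List(org.apache.spark.OneToOneDependency@...)
println(summed.dependencies) // e.g. List(org.apache.spark.ShuffleDependency@...)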

      
      
// DAGScheduler.scala
/**
 * Get or create the list of parent stages for a given RDD. The new Stages will be created with
 * the provided firstJobId.
 */
private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
  val parents = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  def visit(r: RDD[_]) {
    if (!visited(r)) {
      visited += r
      // Kind of ugly: need to register RDDs with the cache here since
      // we can't do it in its constructor because # of partitions is unknown
      for (dep <- r.dependencies) {
        dep match {
          case shufDep: ShuffleDependency[_, _, _] =>
            //== Next Step == PS. a direct wide dependency of the final RDD: time to cut a stage
            parents += getShuffleMapStage(shufDep, firstJobId)
          case _ =>
            waitingForVisit.push(dep.rdd)
        }
      }
    }
  }

  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    visit(waitingForVisit.pop())
  }
  parents.toList
}

/**
 * Get or create a shuffle map stage for the given shuffle dependency's map side.
 */
private def getShuffleMapStage(
    shuffleDep: ShuffleDependency[_, _, _],
    firstJobId: Int): ShuffleMapStage = {
  shuffleToMapStage.get(shuffleDep.shuffleId) match {
    case Some(stage) => stage
    case None =>
      // We are going to register ancestor shuffle dependencies
      //== Next Step == PS. first find and register the ancestor wide dependencies of this RDD
      getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
        shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
      }
      // Then register current shuffleDep
      val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
      shuffleToMapStage(shuffleDep.shuffleId) = stage
      stage
  }
}

The code above finds the direct parent Stages of the final RDD. Whenever a direct wide dependency is encountered, all of that RDD's ancestor wide dependencies are located first and registered (stages are carved out here as well).
The code below shows how all ancestor wide dependencies of an RDD, direct or indirect, are found:

      
      
// DAGScheduler.scala
/** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
  val parents = new Stack[ShuffleDependency[_, _, _]]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  def visit(r: RDD[_]) {
    if (!visited(r)) {
      visited += r
      for (dep <- r.dependencies) {
        dep match {
          case shufDep: ShuffleDependency[_, _, _] =>
            if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
              parents.push(shufDep)
            }
          case _ =>
        }
        waitingForVisit.push(dep.rdd)
      }
    }
  }

  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    visit(waitingForVisit.pop())
  }
  parents
}

An example makes the stage-division process concrete.

Figure legend: RDD 12 is the finalRdd; blue edges are wide dependencies, black edges are narrow dependencies. [original figure not reproduced]
1. newResultStage is called with final RDD 12. Through getParentStagesAndId it obtains the stages containing RDDs 9, 10, 4, and 5 (the direct wide dependencies of the final RDD), and also registers all of their direct or indirect parent Stages (via getShuffleMapStage, called from getParentStagesAndId).
2. getShuffleMapStage, taking the stage containing RDD 10 as an example, uses getAncestorShuffleDependencies to find all of RDD 10's direct or indirect wide dependencies (here RDDs 2 and 3) and registers each of them as a new Stage via newOrUsedShuffleStage.
The resulting stage division appeared in a second figure (not reproduced here); the sketch below shows the same boundaries on a concrete job.
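In this sketch (the input path is made up), each shuffle operator introduces a ShuffleDependency, so the lineage splits into two ShuffleMapStages plus the final ResultStage; rdd.toDebugString reflects the boundaries through its indentation:

val counts = sc.textFile("hdfs://.../input.txt") // narrow chain begins here
  .flatMap(_.split(" "))
  .map(word => (word, 1))
  .reduceByKey(_ + _)                            // shuffle #1 -> ShuffleMapStage
  .map { case (word, n) => (n, word) }
  .sortByKey()                                   // shuffle #2 -> ShuffleMapStage

println(counts.toDebugString) // indentation changes mark the stage boundaries
counts.collect()              // the action triggers newResultStage + submitStage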

Back to DAGScheduler.handleJobSubmitted: the part above completes the newResultStage step (stage division), but the function still has the submitStage step.

      
      
/** Submits stage, but first recursively submits any missing parents. */
// Stages are submitted recursively
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      //== Next Step 1 ==
      val missing = getMissingParentStages(stage).sortBy(_.id) // direct parent stages not yet submitted
      logDebug("missing: " + missing)
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        //== Next Step 2 ==
        submitMissingTasks(stage, jobId.get) // submit tasks: all parent stages have been submitted
      } else {
        for (parent <- missing) {
          // submit any parent stage that has not been submitted yet
          submitStage(parent)
        }
        waitingStages += stage // remember this stage as waiting to be submitted
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}

private def getMissingParentStages(stage: Stage): List[Stage] = {
  val missing = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  def visit(rdd: RDD[_]) {
    if (!visited(rdd)) {
      visited += rdd
      val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
      if (rddHasUncachedPartitions) {
        for (dep <- rdd.dependencies) {
          dep match {
            case shufDep: ShuffleDependency[_, _, _] =>
              val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
              if (!mapStage.isAvailable) {
                missing += mapStage
              }
            case narrowDep: NarrowDependency[_] =>
              waitingForVisit.push(narrowDep.rdd)
          }
        }
      }
    }
  }

  waitingForVisit.push(stage.rdd)
  while (waitingForVisit.nonEmpty) {
    visit(waitingForVisit.pop())
  }
  missing.toList
}

Continuing with the figure above: submitStage(Stage8) => getMissingParentStages(Stage8) yields Stage4, Stage5, Stage6, and Stage7, which are then submitted recursively: submitStage(Stage4), submitStage(Stage5), submitStage(Stage6), submitStage(Stage7)…
In the end, submitMissingTasks is actually reached for Stage1, Stage2, Stage3, Stage4, and Stage5, while waitingStages records Stage6, Stage7, and Stage8 as pending. Those pending stages are later handled by the submitWaitingStages call at the end of DAGScheduler.handleJobSubmitted:

      
      
// DAGScheduler.scala
private def submitWaitingStages() {
  // TODO: We might want to run this less often, when we are sure that something has become
  // runnable that wasn't before.
  logTrace("Checking for newly runnable parent stages")
  logTrace("running: " + runningStages)
  logTrace("waiting: " + waitingStages)
  logTrace("failed: " + failedStages)
  val waitingStagesCopy = waitingStages.toArray
  waitingStages.clear()
  for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) {
    submitStage(stage)
  }
}

That's all for this installment.


Reposted from www.cnblogs.com/sanxiandoupi/p/11698686.html