A Look at Spark Core Source Code Through a Small Program

2017-8-29
liyakun

Starting from the execution of a small program, let's walk through what happens underneath Spark Core.

1. The small program: GroupByTest

This is a small program that performs a groupBy with Spark.

It is one of the usage examples shipped with Spark; its location in the source tree is: spark/spark-branch-2.0/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala

import java.util.Random

import org.apache.spark.sql.SparkSession

/**
 * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
 */
object GroupByTest {
  def main(args: Array[String]) {
    val spark = SparkSession.builder.appName("GroupBy Test").getOrCreate()

    val numMappers = if (args.length > 0) args(0).toInt else 2
    val numKVPairs = if (args.length > 1) args(1).toInt else 1000
    val valSize = if (args.length > 2) args(2).toInt else 1000
    val numReducers = if (args.length > 3) args(3).toInt else numMappers

    val pairs1 = spark.sparkContext.parallelize(0 until numMappers, numMappers).flatMap { p =>
      val ranGen = new Random
      val arr1 = new Array[(Int, Array[Byte])](numKVPairs)
      for (i <- 0 until numKVPairs) {
        val byteArr = new Array[Byte](valSize)
        ranGen.nextBytes(byteArr)
        arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
      }
      arr1
    }.cache()
    // Enforce that everything has been calculated and in cache
    pairs1.count()

    println(pairs1.groupByKey(numReducers).count())

    spark.stop()
  }
}

2. Entry point

The entry point into Spark is the code that creates the SparkSession: SparkSession.builder.appName("GroupBy Test").getOrCreate()

In early versions of Spark, SparkContext was the entry point. The RDD is Spark's central API, and creating and operating on RDDs requires the APIs provided by SparkContext; for anything beyond RDDs we had to use other contexts: a StreamingContext for stream processing, a SQLContext for SQL, and a HiveContext for Hive. As the APIs offered by Dataset and DataFrame gradually became the new standard, a single entry point was needed to build them, so Spark 2.0 introduced a new entry point: SparkSession.

SparkSession is essentially a combination of SQLContext and HiveContext (a StreamingContext may be added in the future), so the APIs available on SQLContext and HiveContext are also available on SparkSession. SparkSession wraps a SparkContext internally, so the actual computation is still carried out by the SparkContext.

The source path of SparkSession is: spark-branch-2.0/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Looking at the logic of its getOrCreate() method: it first checks whether a global session or an active session already exists. If neither does, it builds a SparkConf from all the options supplied by the user, creates a SparkContext from that SparkConf, and then creates a SparkSession on top of that SparkContext.

The SparkSession object contains a SQLContext, and Hive support (the old HiveContext functionality) can be switched on via the enableHiveSupport method.
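
As a quick illustration of this entry point, here is a minimal sketch (not part of the article's example program) that enables Hive support and pulls out the wrapped SparkContext:

import org.apache.spark.sql.SparkSession

// A minimal sketch: getOrCreate() reuses an existing session if there is one,
// otherwise it builds SparkConf -> SparkContext -> SparkSession.
val spark = SparkSession.builder
  .appName("GroupBy Test")
  .enableHiveSupport()   // switches on Hive support (the former HiveContext functionality)
  .getOrCreate()

val sc = spark.sparkContext   // the wrapped SparkContext that actually runs the jobs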

3. Creating the first RDD

The SparkSession object holds a SparkContext member, and that SparkContext's parallelize method produces the first RDD.

Here is the source of that method:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/SparkContext.scala
def parallelize[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  assertNotStopped()
  new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}

From the source above we can see that parallelize ultimately creates and returns an RDD of type ParallelCollectionRDD.

4. RDD transformations

ParallelCollectionRDD does not implement flatMap() itself; the method is implemented in its parent class RDD.

flatMap() is a transformation on an RDD.

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/RDD.scala
def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF))
}

flatMap first cleans the closure, checking that the function is serializable, and then creates and returns a new MapPartitionsRDD. Next, let's look at the source of MapPartitionsRDD.

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev) {

  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }
}

MapPartitionsRDD extends the abstract class RDD and provides its own partitioner, getPartitions, compute and clearDependencies. The most important part of this code is the overridden compute method.

As everyone knows, Spark uses a lazy computation model: an RDD transformation does not by itself trigger any real computation. Instead, every transformation stacks its own compute function on top of its parent's, and when a computation is finally needed these stacked functions are evaluated layer by layer; more on that later.
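
A small usage sketch of that laziness, reusing pairs1 from the example program (the map step here is illustrative and not part of the original code):

// Transformations only build the RDD lineage; no job is launched by this line.
val keyLengths = pairs1.map { case (k, v) => (k, v.length) }   // returns a MapPartitionsRDD immediately

// Only an action walks the stacked compute() functions and actually runs a job.
val n = keyLengths.count()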

5. RDD persistence

The .cache() call in the example persists the RDD; it is a method implemented in the abstract class RDD. Here is its source:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/RDD.scala
def cache(): this.type = persist()

5.1 Defining the cache

cache calls persist(); let's look at persist() next:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/RDD.scala
def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)

After a few simple calls, one of the persist overloads assigns the RDD's storage level: storageLevel = newLevel.

After a few more simple calls, SparkContext's persistRDD method is eventually invoked. It is trivial: it puts the RDD into a map named persistentRdds, keyed by the RDD's own id.

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/SparkContext.scala
private[spark] def persistRDD(rdd: RDD[_]) {
  persistentRdds(rdd.id) = rdd
}

Clearly, nothing in this whole process actually caches any data; it only records a reference to the RDD that has been marked for persistence.

In the cache-definition phase only two things actually happen (see the sketch after this list):

  • the RDD's StorageLevel is set
  • the RDD is added to persistentRdds and registered with the ContextCleaner
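
A short sketch of what this phase amounts to at the API level (rdd is a placeholder; the storage level is chosen for illustration):

import org.apache.spark.storage.StorageLevel

// cache() is just persist(StorageLevel.MEMORY_ONLY); another level can be requested explicitly,
// as long as no level has been assigned to this RDD yet.
rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)

// unpersist() reverses the bookkeeping: it removes the RDD from persistentRdds and drops its blocks.
rdd.unpersist()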

5.2 Triggering the cache

Spark's computation is lazy: the data of an RDD is only computed when an action runs. To make this easier to follow, that part of the work is introduced here ahead of time.

When a Spark executor runs a task, it calls the RDD's iterator method to compute the requested partition (for the details, see the runTask methods in spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala and spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala).

Let's take a closer look at RDD.iterator:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/RDD.scala
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    getOrCompute(split, context)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}

It first checks whether storageLevel is something other than NONE. Since the definition phase already set a non-NONE level, we continue into getOrCompute:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/RDD.scala
private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
val blockId = RDDBlockId(id, partition.index)
var readCachedBlock = true
// This method is called on executors, so we need call SparkEnv.get instead of sc.env.
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}) match {
case Left(blockResult) =>
if (readCachedBlock) {
val existingMetrics = context.taskMetrics().inputMetrics
existingMetrics.incBytesRead(blockResult.bytes)
new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
override def next(): T = {
existingMetrics.incRecordsRead(1)
delegate.next()
}
}
} else {
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
}
case Right(iter) =>
new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
}
}

First a unique block id is obtained; the scheme is simply "rdd_" + rddId + "_" + splitIndex.
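
A tiny sketch of that naming scheme using the RDDBlockId case class (the ids are made up):

import org.apache.spark.storage.RDDBlockId

// Partition 3 of the RDD whose id is 7:
val blockId = RDDBlockId(7, 3)
println(blockId.name)   // prints "rdd_7_3"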

Then getOrElseUpdate is called and its return value is handled case by case. Let's look at getOrElseUpdate first:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
def getOrElseUpdate[T](
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[T],
makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
// Attempt to read the block from local or remote storage. If it's present, then we don't need
// to go through the local-get-or-put path.
get(blockId) match {
case Some(block) =>
return Left(block)
case _ =>
// Need to compute the block.
}
// Initially we hold no locks on this block.
doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
case None =>
// doPut() didn't hand work back to us, so the block already existed or was successfully
// stored. Therefore, we now hold a read lock on the block.
val blockResult = getLocalValues(blockId).getOrElse {
// Since we held a read lock between the doPut() and get() calls, the block should not
// have been evicted, so get() not returning the block indicates some internal error.
releaseLock(blockId)
throw new SparkException(s"get() failed for block $blockId even though we held a lock")
}
// We already hold a read lock on the block from the doPut() call and getLocalValues()
// acquires the lock again, so we need to call releaseLock() here so that the net number
// of lock acquisitions is 1 (since the caller will only call release() once).
releaseLock(blockId)
Left(blockResult)
case Some(iter) =>
// The put failed, likely because the data was too large to fit in memory and could not be
// dropped to disk. Therefore, we need to pass the input iterator back to the caller so
// that they can decide what to do with the values (e.g. process them without caching).
Right(iter)
}
}

It first calls get(blockId) to check whether the block already exists: get tries getLocal first and falls back to getRemote if the block is not local. If the block is found, Left(block) is returned; if not, doPutIterator computes the block and puts it into the store.

6. RDD actions

An action actually launches a Spark job. The whole computation is triggered from the driver, which plans the work for each executor and pulls the executors into a genuinely distributed computation; the results are finally collected back to the driver and returned to the user.

6.1 The driver creates the TaskSet

Back in the example program, pairs1.count() is an action. pairs1 is a MapPartitionsRDD, which does not define count itself; the method comes from its parent class RDD. Here is the implementation:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\rdd\RDD.scala
def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum

As you can see, it calls SparkContext's runJob, which goes through a chain of runJob overloads until it reaches the runJob method in spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\DAGScheduler.scala.

That method calls submitJob, which creates a JobWaiter object to wait for the job to finish and then posts all the information about the job to the eventProcessLoop queue, where it waits to be processed. The relevant code is:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\DAGScheduler.scala
def submitJob[T, U](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: CallSite,
resultHandler: (Int, U) => Unit,
properties: Properties): JobWaiter[U] = {
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
throw new IllegalArgumentException(
"Attempting to access a non-existent partition: " + p + ". " +
"Total number of partitions: " + maxPartitions)
}

val jobId = nextJobId.getAndIncrement()
if (partitions.size == 0) {
// Return immediately if the job is running 0 tasks
return new JobWaiter[U](this, jobId, 0, resultHandler)
}

assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, callSite, waiter,
SerializationUtils.clone(properties)))
waiter
}

At this point eventProcessLoop's onReceive fires, and the JobSubmitted event is handled by dagScheduler.handleJobSubmitted.

The first thing handleJobSubmitted does is create the stage that will produce the final result RDD:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\DAGScheduler.scala
private[scheduler] def handleJobSubmitted(jobId: Int,
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
callSite: CallSite,
listener: JobListener,
properties: Properties) {
var finalStage: ResultStage = null
try {
// New stage creation may throw an exception if, for example, jobs are run on a
// HadoopRDD whose underlying HDFS files have been deleted.
finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
} catch {
case e: Exception =>
logWarning("Creating new stage failed due to exception - job: " + jobId, e)
listener.jobFailed(e)
return
}

val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job %s (%s) with %d output partitions".format(
job.jobId, callSite.shortForm, partitions.length))
logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))

val jobSubmissionTime = clock.getTimeMillis()
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.setActiveJob(job)
val stageIds = jobIdToStageIds(jobId).toArray
val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
listenerBus.post(
SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
submitStage(finalStage)
}

The method then creates an ActiveJob object, binds it to finalStage, and finally calls submitStage(finalStage). Stepping into submitStage(finalStage), the source is:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\DAGScheduler.scala
private def submitStage(stage: Stage) {
val jobId = activeJobForStage(stage)
if (jobId.isDefined) {
logDebug("submitStage(" + stage + ")")
if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing.isEmpty) {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage, jobId.get)
} else {
for (parent <- missing) {
submitStage(parent)
}
waitingStages += stage
}
}
} else {
abortStage(stage, "No active job for stage " + stage.id, None)
}
}

submitStage is called recursively, starting from the final result stage, with the goal of submitting every stage. Note the getMissingParentStages method, whose job is to find the earlier stages that the current stage depends on. Its source is:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\DAGScheduler.scala
private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
if (rddHasUncachedPartitions) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
if (!mapStage.isAvailable) {
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
waitingForVisit.push(narrowDep.rdd)
}
}
}
}
}
waitingForVisit.push(stage.rdd)
while (waitingForVisit.nonEmpty) {
visit(waitingForVisit.pop())
}
missing.toList
}

getMissingParentStages starts from the RDD of the stage being processed and, with the help of an explicit stack, walks over all the RDDs that this RDD depends on. The one difference from a plain traversal is that whenever a dependency is a ShuffleDependency, the walk does not go any deeper along that edge; instead a new stage is added there, and all such stages are collected and returned.

Back in submitStage, for every stage returned by getMissingParentStages, submitStage is called again recursively. If getMissingParentStages returns nothing, we have reached the front of the stage chain, and submitMissingTasks is called; more on that in a moment.

To summarize, stages are carved out by walking backwards from the final result stage and creating a new stage at every ShuffleDependency.

Continuing along this thread: whenever getMissingParentStages returns an empty list (this can happen several times during the recursion), we have reached the front-most stages and submitMissingTasks is called. Let's step into submitMissingTasks next.

Before that, however, one thing is worth knowing: internally Spark has only two kinds of stages, ShuffleMapStage and ResultStage (the sketch after this list relates them to the example job):

  • ShuffleMapStage
    • this kind of stage ends at a shuffle boundary (its output is a shuffle)
    • its input can come from an external source or from the output of another ShuffleMapStage
    • its output is the beginning of another stage
    • the final tasks of a ShuffleMapStage are ShuffleMapTasks
    • a job may or may not contain stages of this type
  • ResultStage
    • this kind of stage produces the final output directly
    • its input can come from an external source or from the output of another ShuffleMapStage
    • the final tasks of a ResultStage are ResultTasks
    • every job contains a stage of this type
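
Applying this to the example job: groupByKey introduces a ShuffleDependency, so pairs1.groupByKey(numReducers).count() runs as one ShuffleMapStage followed by one ResultStage, while the earlier pairs1.count() on the cached RDD is a single ResultStage. One way to see the boundary is the RDD lineage (the output below only indicates its shape):

// toDebugString prints the lineage; the indentation step marks the shuffle (i.e. stage) boundary.
println(pairs1.groupByKey(numReducers).toDebugString)
// (2) ShuffledRDD[2] at groupByKey ...
//  +-(2) MapPartitionsRDD[1] at flatMap ...
//     |  ParallelCollectionRDD[0] at parallelize ...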

Inside submitMissingTasks, data locality is taken into account and the best location is chosen for each task's computation (more on that later). Assuming the locations have been chosen, the binary the tasks need is produced by serialization and broadcast via sc.broadcast(taskBinaryBytes). Then, for a ShuffleMapStage a TaskSet of ShuffleMapTasks is produced, and for a ResultStage a TaskSet of ResultTasks. Finally the TaskSet is handed to the taskScheduler as follows:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\DAGScheduler.scala
taskScheduler.submitTasks(new TaskSet(
tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))

Now let's dig into the TaskScheduler code:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
val stage = taskSet.stageId
val stageTaskSets =
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.stageAttemptId) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
}
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
"and have sufficient resources")
} else {
this.cancel()
}
}
}, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
}
hasReceivedTask = true
}
backend.reviveOffers()
}

TaskSchedulerImpl's submitTasks first creates a TaskSetManager. The TaskSetManager manages a single TaskSet inside TaskSchedulerImpl: it tracks every task and, if a task fails, retries it until the maximum number of attempts is reached; it also performs locality-aware scheduling of the tasks through delay scheduling. The manager is then added to the schedulableBuilder.

The schedulableBuilder decides the scheduling order. Its type is SchedulableBuilder, a trait with two implementations, FIFOSchedulableBuilder and FairSchedulableBuilder; FIFO is the default.
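
The pool type is driven by the spark.scheduler.mode setting; a minimal sketch of switching to fair scheduling (it must be set before the SparkContext is created):

import org.apache.spark.SparkConf

// FIFO is the default; FAIR builds scheduling pools (optionally from a fairscheduler.xml allocation file).
val conf = new SparkConf()
  .setAppName("GroupBy Test")
  .set("spark.scheduler.mode", "FAIR")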

The schedulableBuilder itself is created when SparkContext, after creating the SchedulerBackend and TaskScheduler, calls TaskSchedulerImpl's initialize method:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
def initialize(backend: SchedulerBackend) {
this.backend = backend
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
schedulingMode match {
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
new FairSchedulableBuilder(rootPool, conf)
case _ =>
throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode")
}
}
schedulableBuilder.buildPools()
}

schedulableBuilder is an important member of the TaskScheduler: it determines the order in which TaskSetManagers are scheduled, according to the scheduling policy.

Back in submitTasks, the last line calls the SchedulerBackend's reviveOffers method to schedule the tasks, that is, to decide which executor each task will run on.

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

backend.reviveOffers()

Here backend is an instance of SchedulerBackend. That object is created when the SparkContext instance is constructed, in the createTaskScheduler method, which builds a different kind of SchedulerBackend depending on the master name (see the createTaskScheduler method of the SparkContext class). We will only dig into the case where the master is yarn, which is somewhat special: it loads an external ExternalClusterManager via getClusterManager, and once it is loaded, the scheduler and backend are created and initialized as follows:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/SparkContext.scala
case masterUrl =>
val cm = getClusterManager(masterUrl) match {
case Some(clusterMgr) => clusterMgr
case None => throw new SparkException("Could not parse Master URL: '" + master + "'")
}
try {
val scheduler = cm.createTaskScheduler(sc, masterUrl)
val backend = cm.createSchedulerBackend(sc, masterUrl, scheduler)
cm.initialize(scheduler, backend)
(backend, scheduler)
} catch {
case se: SparkException => throw se
case NonFatal(e) =>
throw new SparkException("External scheduler cannot be instantiated", e)
}

These two objects are very important members of SparkContext: the scheduler splits the local computation into stages and generates TaskSets, while the backend is responsible for talking to the executors.

Returning to the thread we were following and looking at backend: with master set to yarn and client mode, the backend is a YarnClientSchedulerBackend. That class only overrides a few simple methods; most of them live in its parent YarnSchedulerBackend, which still has no reviveOffers implementation. Going one level deeper, to CoarseGrainedSchedulerBackend, we finally find reviveOffers:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\cluster\CoarseGrainedSchedulerBackend.scala
override def reviveOffers() {
driverEndpoint.send(ReviveOffers)
}

It calls driverEndpoint.send(ReviveOffers). The driverEndpoint is an RPC endpoint reference backed by a NettyRpcEnv; that class lives at spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala.

RpcEnv is the execution environment for communication between components. Messages and method calls between the Endpoints and their corresponding EndpointRefs on each node (driver or worker) are coordinated through the RpcEnv, and the underlying transport is the Netty NIO framework. (Early versions of Spark communicated through Akka and used Netty for large file transfers; since 2.0.0 Akka has been replaced by Netty, unifying the transport layer.)

6.2 The executor runs the task

As described above, the driver serializes its tasks and sends them to the executors over RPC. The receiving entry point on the executor side is spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala. When a message arrives, its receive method recognizes a LaunchTask message, deserializes it, and calls executor.launchTask to launch the task.

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\executor\CoarseGrainedExecutorBackend.scala
case LaunchTask(data) =>
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
taskDesc.name, taskDesc.serializedTask)
}

Executor's launchTask method is also fairly simple: it wraps the task in a new TaskRunner object, records it, and hands it to the thread pool to run.

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\executor\Executor.scala
def launchTask(
context: ExecutorBackend,
taskId: Long,
attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer): Unit = {
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}

Next we need to look closely at the run method of the TaskRunner class, the inner class TaskRunner of spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/executor/Executor.scala.

It first creates a taskMemoryManager object (analyzed later), then calls execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) to tell the driver that the task is starting. It then deserializes the task that was sent over, sets the taskMemoryManager and other fields, updates the epoch on the mapOutputTracker, and immediately afterwards calls the task object's own run method to obtain the task result:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\executor\Executor.scala
val res = task.run(
taskAttemptId = taskId,
attemptNumber = attemptNumber,
metricsSystem = env.metricsSystem)

Before continuing with the remaining steps of TaskRunner's run method, let's go a little deeper into task.run here, because the logic differs for the two task types:

  • a ResultTask returns the result of its function;
  • a ShuffleMapTask returns a MapStatus.
//ResultTask
//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\ResultTask.scala
override def runTask(context: TaskContext): U = {
// Deserialize the RDD and the func using the broadcast variables.
val deserializeStartTime = System.currentTimeMillis()
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

func(context, rdd.iterator(partition, context))
}
//ShuffleMapTask
//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\scheduler\ShuffleMapTask.scala
override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD using the broadcast variable.
val deserializeStartTime = System.currentTimeMillis()
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
case e: Exception =>
try {
if (writer != null) {
writer.stop(success = false)
}
} catch {
case e: Exception =>
log.debug("Could not stop writer", e)
}
throw e
}
}

6.2.1 Writing the ShuffleMapTask output (shuffle write)

ShuffleMapTask deserves a closer look. Since a ShuffleMapTask ends a stage and its output serves as the input of the next stage's shuffle read, we should care about how that output is stored.

In ShuffleMapTask's runTask above, SparkEnv.get.shuffleManager yields a SortShuffleManager object (currently the only shuffle manager). The class is located at spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Here is its getWriter method:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\shuffle\sort\SortShuffleManager.scala
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext): ShuffleWriter[K, V] = {
numMapsForShuffle.putIfAbsent(
handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
val env = SparkEnv.get
handle match {
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
context,
env.conf)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
}
}

As you can see, there are three kinds of writers:

  • UnsafeShuffleWriter: in the serialized sorting model, incoming records are serialized as soon as they reach the shuffle writer and stay in serialized form in the sort buffer, which reduces memory footprint and GC overhead. During sorting it provides a cache-efficient sorter that works on 8-byte pointers, turning the sort into a sort over an array of pointers, which greatly improves sorting performance.

    • Pros: greatly reduced memory usage and GC overhead, higher efficiency
    • Cons: three conditions must hold: the serializer supports relocation of serialized values, the dependency has no aggregation, and the number of output partitions is below 16777216. Operators with aggregation such as reduceByKey cannot use the unsafe shuffle and fall back to the sort shuffle.
  • BypassMergeSortShuffleWriter: a hash-flavored variant of the sort-based shuffle. It builds one output file per reduce-side task, writes each incoming record into the file of its partition, and at the end merges these per-partition files into a single output file.

    • Pros: no sorting within partitions, so it is fairly fast for small data volumes
    • Cons: with K cores per executor and R downstream reducers, K*R files are open at the same time; for very large jobs R grows and this becomes infeasible.
  • SortShuffleWriter: in the map stage (shuffle write), records are sorted by partition id (and, when required, by key), and the data of all partitions is written into a single file, one partition after another in partition-id order, with the records inside each partition stored sorted by key when needed. While the map task runs it writes each partition's data sequentially and records each partition's size and offset in an index file. Each map task therefore only keeps two file descriptors open at a time, one for data and one for the index, which removes the file-descriptor pressure of hash shuffle: even with K cores per executor, at most 2K descriptors are open at once.

    • Pros: can handle data of any size
    • Cons: when the conditions of the other two writers are satisfied, it is less efficient than either of them

The writer is chosen by the type of the handle, which raises the obvious question: how is the handle determined?
Looking inside ShuffledRDD's getDependencies method, a ShuffleDependency object is created there:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
override def getDependencies: Seq[Dependency[_]] = {
val serializer = userSpecifiedSerializer.getOrElse {
val serializerManager = SparkEnv.get.serializerManager
if (mapSideCombine) {
serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
} else {
serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
}
}
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}

The constructor of ShuffleDependency creates its own shuffleHandle; the code that creates it is:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/Dependency.scala
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, _rdd.partitions.length, this)

The shuffleManager instance in the source above is a SortShuffleManager (again, currently the only shuffle manager). Following the code into SortShuffleManager's registerShuffle method:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
override def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need map-side aggregation, then write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
new SerializedShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// Otherwise, buffer map outputs in a deserialized form:
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}
}

This code shows how the three handles are selected (the sketch after the list below applies the rules to the example job).

  • First, when no map-side combine is needed and the number of output partitions of the dependency is below the configured spark.shuffle.sort.bypassMergeThreshold (default 200), a BypassMergeSortShuffleHandle is returned
  • Next, if all three of the following conditions hold, a SerializedShuffleHandle is returned
    • the shuffle serializer supports relocation of serialized values (KryoSerializer does)
    • the shuffle dependency has no aggregation, because serialized data cannot be aggregated directly
    • the shuffle produces fewer than 16777216 output partitions, since only 24 bits are reserved for addressing the partition
  • Finally, when neither of the above applies, the default BaseShuffleHandle is returned
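
A hedged sketch applying these rules to the example job, assuming default settings: groupByKey does not combine on the map side, and GroupByTest's default numReducers is 2, far below the bypass threshold, so this shuffle should end up with a BypassMergeSortShuffleHandle.

// Spelling out the first branch of registerShuffle for the example job (all values are assumptions):
val bypassMergeThreshold = 200   // spark.shuffle.sort.bypassMergeThreshold default
val mapSideCombine       = false // groupByKey sets mapSideCombine = false
val numOutputPartitions  = 2     // GroupByTest's default numReducers

val usesBypassHandle = !mapSideCombine && numOutputPartitions <= bypassMergeThreshold   // true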

6.2.1.1 SortShuffleWriter

Now let's go back to the getWriter method.

Of the three shuffle writers, SortShuffleWriter is the most general. First, let's see what its write method looks like:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\shuffle\sort\SortShuffleWriter.scala
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)

// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}

Depending on whether dep.mapSideCombine is set, this method creates an ExternalSorter object in one of two ways, then calls sorter.insertAll(records) to feed in all the records (a very important step, described in detail below), then writes the data into the data file via sorter.writePartitionedFile, and finally writes the index file, so that only two files are produced in the end. That is the overall flow; now let's analyze it carefully, starting with sorter.insertAll(records).

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\util\collection\ExternalSorter.scala
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined

if (shouldCombine) {
// Combine values in-memory first using our AppendOnlyMap
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (records.hasNext) {
addElementsRead()
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
}
}
}

The logic first branches on the aggregator that was passed in when the sorter was created:

  • Consider the combine case first. update is the required aggregation closure. A while loop walks over every record: the read counter is incremented, the current key-value pair kv is obtained, and the pair is inserted into map (a PartitionedAppendOnlyMap, a map type written by Spark). The insertion rule is: if the map already contains the key, merge the new value into the existing one with the aggregation function; otherwise initialize the key in the map. Finally, maybeSpillCollection is called to check whether the collection needs to be spilled to disk. Let's analyze maybeSpillCollection in detail:
//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\util\collection\ExternalSorter.scala
private def maybeSpillCollection(usingMap: Boolean): Unit = {
var estimatedSize = 0L
if (usingMap) {
estimatedSize = map.estimateSize()
if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
buffer = new PartitionedPairBuffer[K, C]
}
}

if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
}

It first estimates how much memory the map may occupy, then uses maybeSpill to decide whether a spill is needed:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\util\collection\Spillable.scala
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = acquireMemory(amountToRequest)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
shouldSpill = currentMemory >= myMemoryThreshold
}
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
// Actually spill
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
spill(collection)
_elementsRead = 0
_memoryBytesSpilled += currentMemory
releaseMemory()
}
shouldSpill
}

Inside this method, when the number of elements read is a multiple of 32 and the current estimated memory usage exceeds spark.shuffle.spill.initialMemoryThreshold (default 5 * 1024 * 1024), it first tries to acquire more memory (execution memory of the executor); if none can be acquired, or not enough to cover the estimate, shouldSpill becomes true. In addition, if the number of elements read already exceeds spark.shuffle.spill.numElementsForceSpillThreshold (default Long.MaxValue), shouldSpill is also set to true.

If shouldSpill ends up true, a spill(collection) is actually triggered: the current contents of the map are written into a disk file, the file's metadata is appended to the spills list, the in-memory usage is cleared, and a fresh, empty PartitionedAppendOnlyMap is created.

  • Back in insertAll, the non-combine case simply iterates over all records: each iteration increments the counter, obtains the current record, and inserts it into buffer (of type PartitionedPairBuffer; this different data structure is the main difference between the combine and non-combine paths). It then likewise checks whether a spill is needed. The overall process is very similar, so it is not repeated here.

After the whole process finishes, sorter.writePartitionedFile writes the data into the data file and the index file is written, so only two files remain in the end. Overall, during execution a sorted temporary file is produced from time to time and they are merged at the end. During the process at most 2K files may exist at the same time, K being the number of executor cores; when it finishes, only two files are kept.
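
The spill behaviour described above is governed by the two configuration keys mentioned; a hedged sketch of overriding them (the values shown are the defaults discussed above):

import org.apache.spark.SparkConf

// 5 MB initial in-memory threshold, and a forced spill after this many records.
val conf = new SparkConf()
  .set("spark.shuffle.spill.initialMemoryThreshold", (5 * 1024 * 1024).toString)
  .set("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue.toString)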

6.2.1.2 UnsafeShuffleWriter

This writer serializes the data and writes it into Tungsten-managed (possibly off-heap) memory; only the record addresses need to be sorted, by partition id, and the whole process involves no deserialization.

Conditions:

  1. The serializer in use must support relocation, meaning the Serializer can reorder already-serialized objects with the same effect as sorting the data first and then serializing it; currently only KryoSerializer qualifies.
  2. No map-side aggregation is allowed, i.e. no aggregator may be defined.
  3. The number of partitions must not exceed the supported upper limit (2^24), since only 24 bits are reserved for addressing the partition.

If UnsafeShuffleWriter is used, its write operation looks like this:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\UnsafeShuffleWriter.java
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
// Keep track of success so we know if we encountered an exception
// We do this rather than a standard try/catch/re-throw to handle
// generic throwables.
boolean success = false;
try {
while (records.hasNext()) {
insertRecordIntoSorter(records.next());
}
closeAndWriteOutput();
success = true;
} finally {
if (sorter != null) {
try {
sorter.cleanupResources();
} catch (Exception e) {
// Only throw this error if we won't be masking another
// error.
if (success) {
throw e;
} else {
logger.error("In addition to a failure during writing, we failed during " +
"cleanup.", e);
}
}
}
}
}

In the code above, insertRecordIntoSorter is called for each record first; here is its source:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\UnsafeShuffleWriter.java
void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
assert(sorter != null);
final K key = record._1();
final int partitionId = partitioner.getPartition(key);
serBuffer.reset();
serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
serOutputStream.flush();

final int serializedRecordSize = serBuffer.size();
assert (serializedRecordSize > 0);

sorter.insertRecord(
serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
}

The most important part of the code above is the last line, which inserts the serialized data into sorter, an instance of ShuffleExternalSorter. Let's keep following that code.

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\ShuffleExternalSorter.java
public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
throws IOException {

// for tests
assert(inMemSorter != null);
if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
logger.info("Spilling data because number of spilledRecords crossed the threshold " +
numElementsForSpillThreshold);
spill();
}

growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int required = length + 4;
acquireNewPageIfNecessary(required);

assert(currentPage != null);
final Object base = currentPage.getBaseObject();
final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
Platform.putInt(base, pageCursor, length);
pageCursor += 4;
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
inMemSorter.insertRecord(recordAddress, partitionId);
}

The code above first checks whether a spill to disk is needed; there is a lot to say about that, so we will come back to it. Assume for now that no spill is needed yet and keep going.

It then checks whether extra memory must be acquired for the pointer array (introduced in a moment), then whether a new memory page is needed, and then produces the recordAddress. Let's first dig into how that address is produced:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\memory\TaskMemoryManager.java
public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
if (tungstenMemoryMode == MemoryMode.OFF_HEAP) {
// In off-heap mode, an offset is an absolute address that may require a full 64 bits to
// encode. Due to our page size limitation, though, we can convert this into an offset that's
// relative to the page's base offset; this relative offset will fit in 51 bits.
offsetInPage -= page.getBaseOffset();
}
return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
}

public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
}

As you can see, there is a small difference between on-heap and off-heap mode. Off-heap, offsetInPage is already an absolute 64-bit address, so the page's base offset is subtracted from it to obtain a smaller relative value; on-heap, offsetInPage is already relative to the base, so this is not an issue.

The encodePageNumberAndOffset method called next shifts pageNumber left by 51 bits and keeps only the low 51 bits of the offset, splicing them together:

[13 bit memory page number][51 bit offset in page]

This is not the final form, however, because the partition id still has to be folded in.
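
A worked sketch of the encoding (constants written out by hand; OFFSET_BITS is 51 in TaskMemoryManager):

// Pack a 13-bit page number and a 51-bit in-page offset into one long.
val pageNumber: Long   = 5
val offsetInPage: Long = 1024
val lower51Bits: Long  = (1L << 51) - 1

val recordAddress = (pageNumber << 51) | (offsetInPage & lower51Bits)
// layout: [13-bit page number][51-bit offset in page]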

Now back to insertRecord. After obtaining the record's start address, the length of the serialized data is first written into the page via Platform; since it is an int, the cursor advances by 4. Right after that, the serialized data itself is written into the page, and since its length is known the cursor advances by length. The whole record is now stored in memory. The last line is important: inMemSorter.insertRecord(recordAddress, partitionId). Let's go into that function, located at spark/spark-branch-2.0/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\ShuffleInMemorySorter.java
public void insertRecord(long recordPointer, int partitionId) {
if (!hasSpaceForAnotherRecord()) {
throw new IllegalStateException("There is no space for new record");
}
array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId));
pos++;
}

The main job of this function is to combine the record's start address in memory with the partition id into a PackedRecordPointer and place that value into the array. Here is PackedRecordPointer.packPointer:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\PackedRecordPointer.java
public static long packPointer(long recordPointer, int partitionId) {
assert (partitionId <= MAXIMUM_PARTITION_ID);
// Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
// Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
return (((long) partitionId) << 40) | compressedAddress;
}

This code first keeps only the high 13 bits of recordPointer (the page number) and shifts the whole value right by 24 bits, producing a new page-number field of the shape [24 zero bits][13-bit page number][27 zero bits]; it then takes the low 27 bits of recordPointer and ORs them in, giving [24 zero bits][13-bit page number][27-bit in-page offset]; finally the partition id is shifted left by 40 bits and ORed in, giving the final layout:

[24-bit partition id][13-bit page number][27-bit offset in page]

This packed value is then stored in the array.
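
A worked sketch equivalent to packPointer, for partition 7, page 5 and in-page offset 1024 (all values made up):

val partitionId: Long = 7
val pageNumber: Long  = 5
val offset: Long      = 1024

// [24-bit partition id][13-bit page number][27-bit offset in page]
val packed = (partitionId << 40) | (pageNumber << 27) | (offset & ((1L << 27) - 1))

// Because the partition id occupies the highest bits, sorting these longs
// orders the records by partition id without touching the serialized data.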

That is the write path for each record. Before going back to the spill logic mentioned earlier, let's discuss this memory manager for a moment.

There can be 2^13 pages in total, and each page can address 2^27 locations, so the total addressable range is 2^40 = 1T. This is essentially virtual addressing: there are 1T pointers, each pointing to a real physical location, and since memory on 64-bit systems is 8-byte aligned, one pointer can cover 8 bytes, so the total amount of memory that can be used is 1T * 8B = 8TB.

There is another very clever idea here: on top of the page-based memory management, the partition id is written directly into the top 24 bits of the logical address, so that later, when the records are sorted, no deserialization is needed at all; the addresses themselves can be compared.

Now back to insertRecord and the spill (write-to-disk) path. insertRecord checks whether the number of records exceeds spark.shuffle.spill.numElementsForceSpillThreshold (default about 1G records) and calls spill() if so. Inside spill(), the parent class's spill is called first, which in turn calls the overload spill(Long.MAX_VALUE, this); since the subclass overrides that method, control returns to the class we were just in. The code is:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\ShuffleExternalSorter.java
public long spill(long size, MemoryConsumer trigger) throws IOException {
if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
return 0L;
}

logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
spills.size(),
spills.size() > 1 ? " times" : " time");

writeSortedFile(false);
final long spillSize = freeMemory();
inMemSorter.reset();
// Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
// records. Otherwise, if the task is over allocated memory, then without freeing the memory
// pages, we might not be able to get memory for the pointer array.
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
return spillSize;
}

The most important thing this function does is call writeSortedFile and then clean up. Let's go straight to writeSortedFile.

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\ShuffleExternalSorter.java
private void writeSortedFile(boolean isLastFile) throws IOException {

final ShuffleWriteMetrics writeMetricsToUse;

if (isLastFile) {
// We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
writeMetricsToUse = writeMetrics;
} else {
// We're spilling, so bytes written should be counted towards spill rather than write.
// Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
// them towards shuffle bytes written.
writeMetricsToUse = new ShuffleWriteMetrics();
}

// This call performs the actual sort.
final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();

// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
// after SPARK-5581 is fixed.
DiskBlockObjectWriter writer;

// Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
// be an API to directly transfer bytes from managed memory to the disk writer, we buffer
// data through a byte array. This array does not need to be large enough to hold a single
// record;
final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];

// Because this output will be read during shuffle, its compression codec must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more details.
final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = spilledFileInfo._2();
final TempShuffleBlockId blockId = spilledFileInfo._1();
final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);

// Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
// Our write path doesn't actually use this serializer (since we end up calling the `write()`
// OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
// around this, we pass a dummy no-op serializer.
final SerializerInstance ser = DummySerializerInstance.INSTANCE;

writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);

int currentPartition = -1;
while (sortedRecords.hasNext()) {
sortedRecords.loadNext();
final int partition = sortedRecords.packedRecordPointer.getPartitionId();
assert (partition >= currentPartition);
if (partition != currentPartition) {
// Switch to the new partition
if (currentPartition != -1) {
writer.commitAndClose();
spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
}
currentPartition = partition;
writer =
blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
}

final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
final Object recordPage = taskMemoryManager.getPage(recordPointer);
final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
long recordReadPosition = recordOffsetInPage + 4; // skip over record length
while (dataRemaining > 0) {
final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
Platform.copyMemory(
recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
writer.write(writeBuffer, 0, toTransfer);
recordReadPosition += toTransfer;
dataRemaining -= toTransfer;
}
writer.recordWritten();
}

if (writer != null) {
writer.commitAndClose();
// If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
// then the file might be empty. Note that it might be better to avoid calling
// writeSortedFile() in that case.
if (currentPartition != -1) {
spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
spills.add(spillInfo);
}
}

if (!isLastFile) { // i.e. this is a spill file
// The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
// are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
// relies on its `recordWritten()` method being called in order to trigger periodic updates to
// `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
// counter at a higher-level, then the in-progress metrics for records written and bytes
// written would get out of sync.
//
// When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
// in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
// metrics to the true write metrics here. The reason for performing this copying is so that
// we can avoid reporting spilled bytes as shuffle write bytes.
//
// Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
// Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
// This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
}
}

This function takes a parameter saying whether this is the last write. The earlier writes happen whenever the record-count threshold is reached, and normally a little data remains in memory at the end; when that remaining in-memory data is written to the file, the parameter is true. The parameter does not matter much, though: it only affects some real-time metrics. The following line is crucial:

// This call performs the actual sort.
final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();

This is the line where the actual sort happens. Let's follow it and look at the source of getSortedIterator:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\ShuffleInMemorySorter.java
public ShuffleSorterIterator getSortedIterator() {
int offset = 0;
if (useRadixSort) {
offset = RadixSort.sort(
array, pos,
PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
} else {
MemoryBlock unused = new MemoryBlock(
array.getBaseObject(),
array.getBaseOffset() + pos * 8L,
(array.size() - pos) * 8L);
LongArray buffer = new LongArray(unused);
Sorter<PackedRecordPointer, LongArray> sorter =
new Sorter<>(new ShuffleSortDataFormat(buffer));

sorter.sort(array, 0, pos, SORT_COMPARATOR);
}
return new ShuffleSorterIterator(pos, array, offset);
}

This code contains two sorting algorithms, but both sort by partition. We won't go any deeper here; back in writeSortedFile, the remaining logic writes the records, already sorted by partition, into the actual block file.

That covers iterating over all records and writing them to files. Now back to the original entry point, UnsafeShuffleWriter's write method, and on to closeAndWriteOutput(); here is its source:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\UnsafeShuffleWriter.java
void closeAndWriteOutput() throws IOException {
assert(sorter != null);
updatePeakMemoryUsed();
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final File tmp = Utils.tempFileWith(output);
try {
partitionLengths = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
}
}
}
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

As with the other shuffle handles, the spill files are merged, a single result file plus an index file are written, and the mapStatus is completed.

6.2.1.3 BypassMergeSortShuffleWriter

If a BypassMergeSortShuffleHandle was chosen, the write operation (BypassMergeSortShuffleWriter.write) looks like this:

//位置:spark\spark-branch-2.1.0\core\src\main\java\org\apache\spark\shuffle\sort\BypassMergeSortShuffleWriter.java
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
partitionWriters = new DiskBlockObjectWriter[numPartitions];
partitionWriterSegments = new FileSegment[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
partitionWriters[i] =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
// included in the shuffle write time.
writeMetrics.incWriteTime(System.nanoTime() - openStartTime);

while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}

for (int i = 0; i < numPartitions; i++) {
final DiskBlockObjectWriter writer = partitionWriters[i];
partitionWriterSegments[i] = writer.commitAndGet();
writer.close();
}

File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
File tmp = Utils.tempFileWith(output);
try {
partitionLengths = writePartitionedFile(tmp);
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

One DiskBlockObjectWriter, i.e. one temporary-file handler, is created per partition according to numPartitions, and each handler needs a 32 KB buffer.
The main flow: iterate over the map task's records, use partitioner.getPartition(key) to determine which temporary file each record belongs to, and write it there through the DiskBlockObjectWriter. Once all the data has been written to the temporary files, they are merged into one final file plus its index file.
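
A rough back-of-the-envelope sketch of the cost mentioned in the earlier comparison (all numbers are assumptions; 32 KB corresponds to the usual spark.shuffle.file.buffer default):

// With K cores per executor and R reduce partitions, K * R temp files can be open at once.
val coresPerExecutor = 8          // K (assumed)
val reducePartitions = 200        // R (assumed)
val fileBufferBytes  = 32 * 1024

val openFiles   = coresPerExecutor * reducePartitions    // 1600 temp files open at once
val bufferBytes = openFiles.toLong * fileBufferBytes     // ~50 MB of write buffers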

6.2.2 Notifying the driver of the result

Back in the TaskRunner class: after the task's own run method finishes and the scene is cleaned up, the next question is how to handle the result that was just obtained.

The result is first serialized. If its size exceeds the maximum (configurable via spark.driver.maxResultSize), it is turned into an IndirectTaskResult placeholder (the value itself is dropped). Otherwise, if it exceeds maxDirectResultSize (configurable via spark.task.maxDirectResultSize), it is stored in the blockManager and the storage location and size are returned as an IndirectTaskResult. If it is below both limits, the result is returned directly. The code fragment:

//位置:spark\spark-branch-2.1.0\core\src\main\scala\org\apache\spark\executor\Executor.scala
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize > maxDirectResultSize) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
serializedDirectResult
}
}

execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
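
The two thresholds used above are ordinary configuration keys; a hedged sketch of setting them explicitly (the values shown are the usual defaults):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.driver.maxResultSize", "1g")      // larger results are dropped (only a placeholder is sent)
  .set("spark.task.maxDirectResultSize", "1m")  // larger results travel via the BlockManager instead of the RPC reply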

Finally, execBackend.statusUpdate tells the driver that the computation has finished and describes the result. The notification is sent via driverRef.send(msg), similar to how the driver sent the task earlier; the source is:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
val msg = StatusUpdate(executorId, taskId, state, data)
driver match {
case Some(driverRef) => driverRef.send(msg)
case None => logWarning(s"Drop $msg because has not yet connected to driver")
}
}

6.3 The driver handles the result

After an executor finishes a task, it sends a StatusUpdate message to the driver. Back in the driver's source, the message is received in spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala and handled in the receive method of the CoarseGrainedSchedulerBackend class:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
executorInfo.freeCores += scheduler.CPUS_PER_TASK
makeOffers(executorId)
case None =>
// Ignoring the update since we don't know about the executor.
logWarning(s"Ignored task status update ($taskId state $state) " +
s"from unknown executor with ID $executorId")
}
}

The handling logic calls scheduler.statusUpdate(taskId, state, data.value).

Jumping into the statusUpdate method, the main logic is: remove this task id from its TaskSetManager's bookkeeping, decrement the recorded number of tasks running on the corresponding executor, remove the task from the set of running tasks, and finally call taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) to handle the result.
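
As a rough illustration of that bookkeeping, here is a hedged Scala sketch; the map names (taskIdToManager, executorRunningTasks) and the printed call are hypothetical stand-ins for TaskSchedulerImpl's internal structures and for the handoff to taskResultGetter.enqueueSuccessfulTask:

import scala.collection.mutable

// Hypothetical, simplified bookkeeping; the real TaskSchedulerImpl.statusUpdate
// also handles lost executors, failed states, locking, and more.
object StatusUpdateSketch {
  val taskIdToManager = mutable.Map[Long, String]()                   // tid -> task set name
  val executorRunningTasks = mutable.Map[String, mutable.Set[Long]]() // execId -> running tids

  def statusUpdateFinished(tid: Long, execId: String, serializedData: Array[Byte]): Unit = {
    // 1. Forget which TaskSetManager owned this task id.
    val taskSet = taskIdToManager.remove(tid)
    // 2. One fewer running task on that executor.
    executorRunningTasks.get(execId).foreach(_ -= tid)
    // 3. Hand the serialized result on for processing (done asynchronously in real Spark).
    taskSet.foreach(ts => println(s"enqueueSuccessfulTask($ts, $tid, ${serializedData.length} bytes)"))
  }
}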

enqueueSuccessfulTask lives in the TaskResultGetter class, at: spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

  • If the result is a directResult, the result is taken out directly;
  • If it is an IndirectTaskResult, blockManager.getRemoteBytes() has to be called to fetch the actual result. Once the result is obtained, the running total of the size of all received results is updated, and then scheduler.handleSuccessfulTask(taskSetManager, tid, result) is called for the next step. A sketch of these two branches is given below.
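
Here is a hedged sketch of those two branches; the types (DirectResultSketch, IndirectResultSketch) and fetchRemoteBytes are hypothetical simplifications, with fetchRemoteBytes standing in for blockManager.getRemoteBytes():

// Simplified stand-ins for the two result shapes seen by the scheduler.
sealed trait TaskResultSketch
case class DirectResultSketch(value: Array[Byte]) extends TaskResultSketch
case class IndirectResultSketch(blockId: String, size: Long) extends TaskResultSketch

object ResultFetchSketch {
  // Hypothetical remote fetch; the real code calls blockManager.getRemoteBytes(blockId).
  def fetchRemoteBytes(blockId: String): Option[Array[Byte]] =
    Some(Array[Byte](1, 2, 3))

  def resolve(result: TaskResultSketch): Option[Array[Byte]] = result match {
    case DirectResultSketch(value) =>
      // Direct result: the bytes travelled with the status update, use them as-is.
      Some(value)
    case IndirectResultSketch(blockId, _) =>
      // Indirect result: only a reference was sent, fetch the bytes remotely first.
      fetchRemoteBytes(blockId)
  }

  def main(args: Array[String]): Unit = {
    println(resolve(DirectResultSketch(Array[Byte](42))).map(_.length))      // Some(1)
    println(resolve(IndirectResultSketch("taskresult_7", 3L)).map(_.length)) // Some(3)
  }
}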

handleSuccessfulTask then calls sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info). Inside DAGScheduler's taskEnded method, a CompletionEvent is posted to eventProcessLoop; the source is as follows:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Seq[AccumulatorV2[_, _]],
taskInfo: TaskInfo): Unit = {
eventProcessLoop.post(
CompletionEvent(task, reason, result, accumUpdates, taskInfo))
}

Going back to where the JobSubmitted event was handled earlier, in the doOnReceive method of DAGSchedulerEventProcessLoop we find the handling of CompletionEvent, namely dagScheduler.handleTaskCompletion(completion). Moving on into handleTaskCompletion, for an event whose reason is Success the handling differs depending on the task type:

  • For a ResultTask, updateAccumulators(event) accumulates the results. If the number of finished tasks has reached the number of partitions of this Stage, the Stage is marked as finished; and because the task type is ResultTask, the completion of this Stage means the whole Job is complete.

  • For a ShuffleMapTask, the location of its output is recorded first. If the list of running stages contains this stage and the stage has no pending partitions left, the Stage is marked as successfully finished, and the map outputs are registered with mapOutputTracker so that they can be read by the subsequent shuffle read.

In both cases submitWaitingStages() is called to submit the remaining stages. After a ResultStage finishes there are no more waiting stages and the program simply finishes normally; after a ShuffleMapStage finishes, submitWaitingStages() submits the stages that come next. Its internal logic is very simple: iterate over the list of stages currently waiting to be executed and call the submitStage(stage) method already discussed. The code is as follows, and a sketch of the two completion branches is given right after it:

private def submitWaitingStages() {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logTrace("Checking for newly runnable parent stages")
logTrace("running: " + runningStages)
logTrace("waiting: " + waitingStages)
logTrace("failed: " + failedStages)
val waitingStagesCopy = waitingStages.toArray
waitingStages.clear()
for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) {
submitStage(stage)
}
}
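
To summarize the two completion branches described above, the following is a hedged, heavily simplified sketch; the types and counters are hypothetical and only mirror the idea that a finished ResultStage ends the job while a finished ShuffleMapStage registers its map outputs and lets the waiting child stages be submitted:

import scala.collection.mutable

// Hypothetical, heavily simplified view of the two Success branches in
// handleTaskCompletion; the real DAGScheduler tracks far more state.
sealed trait FinishedTaskSketch
case class ResultTaskDone(stageId: Int) extends FinishedTaskSketch
case class ShuffleMapTaskDone(stageId: Int, partition: Int, location: String) extends FinishedTaskSketch

class CompletionSketch(partitionsPerStage: Map[Int, Int]) {
  private val finished = mutable.Map[Int, Int]().withDefaultValue(0)
  private val mapOutputs = mutable.Map[(Int, Int), String]() // (stageId, partition) -> location

  def handle(event: FinishedTaskSketch): Unit = event match {
    case ResultTaskDone(stageId) =>
      finished(stageId) = finished(stageId) + 1
      if (finished(stageId) == partitionsPerStage(stageId))
        println(s"ResultStage $stageId finished -> the whole job is done")
    case ShuffleMapTaskDone(stageId, partition, location) =>
      // Record where this partition's shuffle output lives so the next stage's
      // shuffle read can find it (mapOutputTracker plays this role in Spark).
      mapOutputs((stageId, partition)) = location
      finished(stageId) = finished(stageId) + 1
      if (finished(stageId) == partitionsPerStage(stageId))
        println(s"ShuffleMapStage $stageId finished -> submit waiting child stages")
  }
}

object CompletionSketchDemo {
  def main(args: Array[String]): Unit = {
    val dag = new CompletionSketch(Map(0 -> 2, 1 -> 1))
    dag.handle(ShuffleMapTaskDone(0, 0, "exec-1"))
    dag.handle(ShuffleMapTaskDone(0, 1, "exec-2"))
    dag.handle(ResultTaskDone(1))
  }
}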

6.4 ShuffleRead

From the flow analyzed above, the overall computation of a Spark job is: the Driver creates TaskSets, hands them to Executors for execution, and the execution results are reported back to the Driver. That is enough to complete a single Stage, but stage boundaries in Spark are defined by shuffles, which means a later Stage very often needs the output of an earlier Stage as its input.

The following discusses how the shuffle read is performed.

A reducer first needs to know on which nodes the FileSegments produced by the parent stage's ShuffleMapTasks live. This information was sent to the driver's mapOutputTrackerMaster when each ShuffleMapTask finished and is stored in the mapStatuses HashMap: given a shuffle id, the Array[MapStatus] describing the FileSegments produced by that stage's ShuffleMapTasks can be retrieved, and indexing into the array by map task id gives the location (blockManagerId) of that task's output as well as the size of each FileSegment.
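
The driver-side bookkeeping can be pictured roughly as follows; this is a hedged sketch with illustrative field names, not Spark's MapOutputTracker API:

import scala.collection.mutable

// Simplified picture of the driver-side bookkeeping: for each shuffle, one entry
// per map task recording where its output lives and how large each reduce-side
// segment is. (Field names here are illustrative, not Spark's.)
case class MapStatusSketch(blockManagerHost: String, segmentSizes: Array[Long])

object MapOutputLookupSketch {
  private val mapStatuses = mutable.HashMap[Int, Array[MapStatusSketch]]()

  // Called conceptually when the ShuffleMapTasks of a shuffle have finished.
  def register(shuffleId: Int, statuses: Array[MapStatusSketch]): Unit =
    mapStatuses(shuffleId) = statuses

  // Called conceptually by a reducer: where is map task `mapId`'s output for my partition?
  def locate(shuffleId: Int, mapId: Int, reduceId: Int): (String, Long) = {
    val status = mapStatuses(shuffleId)(mapId)
    (status.blockManagerHost, status.segmentSizes(reduceId))
  }

  def main(args: Array[String]): Unit = {
    register(0, Array(
      MapStatusSketch("executor-1", Array(10L, 20L)),
      MapStatusSketch("executor-2", Array(5L, 15L))))
    println(locate(shuffleId = 0, mapId = 1, reduceId = 0)) // (executor-2, 5)
  }
}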

The story starts with ShuffledRDD, which lives at: spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala

Its compute method is as follows:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}

The shuffleManager in the method above is an instance of SortShuffleManager (the only shuffle manager in current Spark). The getReader method called on SortShuffleManager looks like this:

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}

The read() call in the code above therefore ends up in BlockStoreShuffleReader's read() method; this class is located at: spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

//位置:spark/spark-branch-2.0/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
override def read(): Iterator[Product2[K, C]] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

// Wrap the streams for compression based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
serializerManager.wrapForCompression(blockId, inputStream)
}

val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}

// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())

// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}

// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}

Inside this read() method, a ShuffleBlockFetcherIterator instance is created first. Looking into ShuffleBlockFetcherIterator, its constructor calls an initialize method, which first calls splitLocalRemoteBlocks to split all requests into remote and local ones and to determine how many blocks need to be fetched in total. It then keeps fetching all remote blocks via fetchUpToMaxBytes() (two parameters throttle the rate: spark.reducer.maxSizeInFlight and spark.reducer.maxReqsInFlight), and finally fetches the local blocks via fetchLocalBlocks().
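
A rough sketch of that split step, with hypothetical simplified types (the real splitLocalRemoteBlocks additionally groups the remote blocks into size-bounded fetch requests):

// Hypothetical simplified description of a block to fetch: its id, size,
// and the executor that holds it.
case class BlockRefSketch(blockId: String, size: Long, executorId: String)

object SplitBlocksSketch {
  // The essence of splitLocalRemoteBlocks: blocks held by this executor can be
  // read directly from local disk, everything else has to come over the network.
  def split(blocks: Seq[BlockRefSketch], localExecutorId: String): (Seq[BlockRefSketch], Seq[BlockRefSketch]) =
    blocks.partition(_.executorId == localExecutorId)

  def main(args: Array[String]): Unit = {
    val blocks = Seq(
      BlockRefSketch("shuffle_0_0_1", 1024L, "exec-1"),
      BlockRefSketch("shuffle_0_1_1", 2048L, "exec-2"),
      BlockRefSketch("shuffle_0_2_1", 512L, "exec-1"))
    val (local, remote) = split(blocks, localExecutorId = "exec-1")
    println(s"local blocks:  ${local.map(_.blockId).mkString(", ")}")
    println(s"remote blocks: ${remote.map(_.blockId).mkString(", ")}")
  }
}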

One elegant aspect of ShuffleBlockFetcherIterator is that it extends Iterator and implements next() itself: if the next block has not yet arrived over the network, next() waits on take(); if it has already arrived, the next block is returned immediately.
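
The pattern is essentially an iterator whose next() blocks on a queue that asynchronous fetch callbacks feed; a minimal hedged sketch of that pattern (not the real class) looks like this:

import java.util.concurrent.LinkedBlockingQueue

// Minimal illustration of "results arrive asynchronously, next() blocks until
// the next one is available", which is what ShuffleBlockFetcherIterator does by
// calling take() on its results queue inside next().
class BlockingResultIterator[T](expected: Int) extends Iterator[T] {
  private val results = new LinkedBlockingQueue[T]()
  private var returned = 0

  // Called by the (asynchronous) fetch callbacks when a block has arrived.
  def offer(result: T): Unit = results.put(result)

  override def hasNext: Boolean = returned < expected

  override def next(): T = {
    val r = results.take() // blocks until a fetched block is available
    returned += 1
    r
  }
}

object BlockingResultIteratorDemo {
  def main(args: Array[String]): Unit = {
    val it = new BlockingResultIterator[String](expected = 2)
    new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(100)
        it.offer("block-A")
        it.offer("block-B")
      }
    }).start()
    it.foreach(println) // prints block-A and block-B once they arrive
  }
}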

Back in BlockStoreShuffleReader's read method: once all the data has been fetched, it checks whether an aggregator is defined. If so and mapSideCombine is set, combineCombinersByKey is used to compute the result; if mapSideCombine is not set, combineValuesByKey is used directly instead.

If keyOrdering is defined, an ExternalSorter is created to externally sort all of the data.

To sum up: because of Spark's lazy computation model, real computation only starts when a Task is executed on the Executor side, and it is within that logic that a ShuffledRDD first pulls its own data; the shuffle-read logic was written into its compute method long before that point.

7. End point

spark.stop() is the end point of the whole program. The stop method called on SparkSession directly calls sparkContext's stop method, as shown below:

def stop(): Unit = {
sparkContext.stop()
}

The sparkContext.stop() that SparkSession delegates to then performs all kinds of cleanup work.