Spark Core(十七)Spark的Shuffle原理与源码分析

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Suubyy/article/details/82023369
  1. Shuffle的定义
    1. 我们都知道Spark是一个基于内存的、分布式的、迭代计算框架。在执行Spark作业的时候,会将数据先加载到Spark内存中,内存不够就会存储在磁盘中,那么数据就会以Partition的方式存储在各个节点上,我们编写的代码就是操作节点上的Partiton数据。之前我们也分析了怎么我们的代码是怎么做操Partition上的数据,其实就是有DriverTask发送到每个节点上的Executor上去执行,在操作Partiton上的数据时候,遇到Action操作的时候会生成一个新的Partition,而这个Partition是由多个节点上的Partition组成的,这样就实现了跨界点,我们管这种操作就叫SparkShuffle操作。其实比较比较通俗的来说,其实就是上一个Stage的输出,下一个Stage拉取这个输出的过程就是Shuffle
  2. Shuffle的原理

    1. 普通Shuffle

      1. Spark1.6的版本之前,Shuffle就是一个没有经过优化的Shuffle,它的原理就是每一个ShuffleMapTask都会根据ResultTask数量在内存创建多个bucket,并且会在该ShuffleMapTask节点上的磁盘上为每一个bucket创建一个blockFile文件,如果是4个ResultTask,那么就会有4*4=16blockFile文件。
      2. Task的运行结果全部写入到缓存bucket中,然后将bucket的数据全部刷新到blockFile文件中。
      3. ShuffleMapTask就会将Task的执行状态以及结果存储的地址封装成MapStatus对象发送给Driver
      4. ResultTask运行的时候,就会向Driver发送请求来获取它所依赖的blockFile文件的信息。
      5. ResultTask根据这些信息利用BlockManger来将本地数据或者远程的数据通过网络或者直接读取的方式拉取过来存到内存中作为自己的输入数据,ResultTask计算结束以后,将结果返回给我们。这就是早起版本Shuffle的原理。

        这里写图片描述

    2. 优化后的Shuffle
      1. Spark1.6以后,对shuffle进行了优化。优化的原理是根据core来优化的,因为运行在每一个Executor上的Task都是并行运行的,例如有两个core,如果有4Task4ResultTask。这个时候只能并行运行两个Task,然后ShuffleMapTask会根据ResultTask的数量来创建8bucket,然后在根据bucket在本地磁盘创建8BlockFile
      2. 当另外两个Task执行的时候也会根据ResultTask创建8bucket,但是这个时候,不会在本地磁盘上创建BlockFile了,而是将结果追加到前两个Task对应的8Block这里写代码片File文件中。
      3. 将结果写入BlockFile的时候,首先不会等到结果全部写入内存以后在刷新到磁盘上,而是当内存达到一定的阈值就会将数据刷新到磁盘中个,这样防止了OOM
        这里写图片描述
  3. Shuffle的写源码分析

    1. ShuffleMapTaskrunTask方法:该方法中 writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])方法就是Shuffle写操作的开始,Spark默认的写操作是HashShuffleWriter

      //该方法的作用执行Task,然后将结果返回给调度器
       //其中MapStatus封装了块管理器的地址,以及每个reduce的输出大小
       //以便于传递reduce的任务
       override def runTask(context: TaskContext): MapStatus = {
         //记录反序列化RDD的开始时间
          val deserializeStartTime = System.currentTimeMillis()
          //创建一个序列化器
          val ser = SparkEnv.get.closureSerializer.newInstance()
           // 反序列化广播变量来的得到RDD
          val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
            ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
          _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
      
          metrics = Some(context.taskMetrics)
          var writer: ShuffleWriter[Any, Any] = null
          try {
            //获取ShuffleManager
            val manager = SparkEnv.get.shuffleManager
            //利用ShuffleManager获取ShuffleWriter,ShuffleWriter的功能就是将Task计算的结果
            //持久化到shuffle文件中,作为子Stage的输入
            writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
            //通过ShuffleWriter将结果进行持久化到shuffle文件中,作为子Stage的输入
            //rdd.iterator(partition, context)这个方法里就会执行我们自己编写的业务逻辑代码
            writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
            //关闭writer,将元数据写入MapStatus中,然后返回
            writer.stop(success = true).get
          } catch {
            case e: Exception =>
              try {
                if (writer != null) {
                  writer.stop(success = false)
                }
              } catch {
                case e: Exception =>
                  log.debug("Could not stop writer", e)
              }
              throw e
          }
        }
    2. HashShuffleWriter中的writer方法:

       //ShuffleMapTask计算出来的结果写入磁盘的方法
      override def write(records: Iterator[Product2[K, V]]): Unit = {
          //判断是否对ShuffleMapTask计算后得到的Partition对应的Iterator集合进行Map端的本地聚合
          val iter = if (dep.aggregator.isDefined) {
              //如果dep.mapSideCombine和dep.aggregator.isDefined为true,那么就进行Map端的本地聚合
            if (dep.mapSideCombine) {
              //开始本地聚合的方法
              dep.aggregator.get.combineValuesByKey(records, context)
            } else {
              records
            }
          } else {
            require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
            records
          }
          //遍历聚合以后的数据,调用partitioner,默认是HashPartitioner,生成bucketId
          //然后将数据写入bucketId对应的Bucket中去
          for (elem <- iter) {
            val bucketId = dep.partitioner.getPartition(elem._1)
            //shuffle:它就是ShuffleWriterGroup对象,其实就是为ShuffleWriterGroup定义的一组writer
            //它是由FileShuffleBlockResolver这类调用forMapTask方法生成的,
            // shuffle.writers(bucketId)根据bucketId从ShuffleWriterGroup对应的一组writer中找出对应的DiskBlockObjectWriter
            //然后将数据写入对应的文件中
            shuffle.writers(bucketId).write(elem._1, elem._2)
          }
      }
    3. FileShuffleBlockResolver类的forMapTask方法:该方法的作用就是根据Map Task得到一个ShuffleWriterGroup

      def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer,
            writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = {
         //为每个ShuffleMapTask实例化一个ShuffleWriterGroup   
         new ShuffleWriterGroup {
            //实例化ShuffleState并保存shuffleId与ShuffleState的对应关系
            shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
            //根据shuffleId获取ShuffleState
            private val shuffleState = shuffleStates(shuffleId)
      
            val openStartTime = System.nanoTime
            val serializerInstance = serializer.newInstance()
            //实例化Writer:DiskBlockObjectWriter
            val writers: Array[DiskBlockObjectWriter] = {
              Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
                //生成blockId
                val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
                //根据blcokId生成blockFile,用于存储写入的数据
                val blockFile = blockManager.diskBlockManager.getFile(blockId)
                //生成一个临时文件,也就是把数据写入那个目录里
                val tmp = Utils.tempFileWith(blockFile)
      
                //生成writer:实例化Writer:DiskBlockObjectWriter,用于将数据写入磁盘
                //这里的BufferSize的默认大小是32kb,可以通过spark.shuffle.file.buffer重新设置大小
                blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
              }
            }
            // Creating the file to write to and creating a disk writer both involve interacting with
            // the disk, so should be included in the shuffle write time.
            writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
      
            override def releaseWriters(success: Boolean) {
              shuffleState.completedMapTasks.add(mapId)
            }
          }
      }
  4. Shuffle的读源码分析

    1. 在上一讲Spark(十六)Executor执行Task的原理与源码分析(二)中我们已经对ResultTask进行了源码的分析,ResultTask里的runTask方法就是开始计算最终我们想要的结果,那么既然要计算结果就需要从远程或者本地拉去上一个Stage处理过后的结果,也就是Shuffle写入到磁盘中的数据。从上一篇的源码分析可以找出,ResultTask拉取数据的起点是ShuffledRDD类里的compute方法。
    2. ShuffleRDDcompute方法

      override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
          val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
          //回调用shuffleManger的getReader方法获取SortShuffleManager,
          //调用它的read方法拉取数据
          SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
            .read()
            .asInstanceOf[Iterator[(K, C)]]
      }
    3. SortShuffleManager里的getRead方法

       override def getReader[K, C](
            handle: ShuffleHandle,
            startPartition: Int,
            endPartition: Int,
            context: TaskContext): ShuffleReader[K, C] = {
          //创建一个BlockStoreShuffleReader
          new BlockStoreShuffleReader(
            handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
      }
    4. BlockStoreShuffleReaderread方法

      override def read(): Iterator[Product2[K, C]] = {
          //实例化ShuffleBlockFetcherIterator,在实例化这个对象的时候,回调用它内部的initialze方法,这个方法会调用splitLocalRemoteBlocks方法来路由拉取数据的策略,
          //拉取数据策略分为两种,一种是本地策略,一种是远程策略。
          val blockFetcherItr = new ShuffleBlockFetcherIterator(
            context,
            blockManager.shuffleClient,
            blockManager,
            //通过消息发送获取ShuffleMapTask存储数据的位置
            mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
            // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
            //设置每次拉取数据的大小
            SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
            //设置从远程节点拉取快的数量
            SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
      
          // 根据配置对流进行压缩和加密,构建一个包装流
          val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
            serializerManager.wrapStream(blockId, inputStream)
          }
          //构建一个序列化器
          val serializerInstance = dep.serializer.newInstance()
      
          // 为每个包装流创建一个键/值迭代器。
          val recordIter = wrappedStreams.flatMap { wrappedStream =>
            // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
            // NextIterator. The NextIterator makes sure that close() is called on the
            // underlying InputStream when all records have been read.
            serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
          }
      
          // 每条记录读取后更新指标。这样在UI界面上就能看见相关信息
          val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
          //生成一个完整的迭代器
          val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
            recordIter.map { record =>
              readMetrics.incRecordsRead(1)
              record
            },
            context.taskMetrics().mergeShuffleReadMetrics())
      
          // 为了这个任务可以取消,那么就必须使用可中断的迭代器
          val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
      
          //判断数据聚合操作是否被定义
          val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
            //判断是否启用数据的聚合操作,因为这个Stage相于下一个Stage是Map端
            if (dep.mapSideCombine) {
      
              val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
              //如果启用就调用combineCombinersByKey方法
              dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
            } else {//不启用数据聚合操作
              val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
              //如果不启用就调用combineValueByKey方法
              dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
            }
          } else {//如果聚合操作没有被定义就会报错
            require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
            interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
          }
      
          // 根据keyOrdering属性判断是否对输出结果
          dep.keyOrdering match {
            //如果keyOrdering不为空,就对输出结果进行排序
            case Some(keyOrd: Ordering[K]) =>
              // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
              // the ExternalSorter won't spill to disk.
              //创建ExternalSorter对象对数据进行排序,如果spark.shuffle.spill没有开启
              //ExternalSorter不会将数据持久化到磁盘上的
              val sorter =
                new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
              sorter.insertAll(aggregatedIter)
              context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
              context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
              context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
              CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
            case None =>
              aggregatedIter
          }
        }
    5. MapOutputTracker类里的getMapSizeByExecutorId方法,该方法的作用就是告诉Executor去获取每个shuffle block服务器的Url和输出大小。

        def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
            : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
          logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
          //根据ShuffleId获取ShuffleBlock的元数据,包含数据地址和数据的大小
          val statuses = getStatuses(shuffleId)
          //
          statuses.synchronized {
            //调用MapOutputTracker的converMapStatus方法,给定一组映射状态和一系列的映射分区,
            //然后返回一个Tuple2序列,该序列里的的元素是元组,元组的第一个元素是BlockManagerId,第二个元素是也是一个元组
            //这个元组的第一个元素是BlockId,第二个元素是block快的大小
            //也就是说这个方法返回的是数据在哪个节点以及这个节点的那些block
            return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
          }
        }
    6. MapOutputTracker类里的getStatus方法,该方法的作用就是利用ShuffleId过去对应的MapStatusBlock的元数据),它的原理就是首先从给本地获取MapStatus,如果没有就通过网络拉取MapStatus

       //这个方法利用了算法
       //这个方法就是根据shuffleId获取MapStatus方法
       private def getStatuses(shuffleId: Int): Array[MapStatus] = {
          //根据shuffleId获取本地的MapStatus数组,因为当ResultTask拉取MapStatus的时候
          //会把它放到内存缓存中。
          val statuses = mapStatuses.get(shuffleId).orNull
          //如果status为空,那么就说明本地没有对应的status,这样就会利用网络拉去MapStatus
          if (statuses == null) {
            logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
            //拉取的开始时间
            val startTime = System.currentTimeMillis
            //定义MapStatus数组
            var fetchedStatuses: Array[MapStatus] = null
            //由于是并行处理,可能其他Task也在利用网络拉取MapStatus
            //所以避免数据的同步问题需要加上synchronize关键字
            fetching.synchronized {
              // Someone else is fetching it; wait for them to be done
              while (fetching.contains(shuffleId)) {
                try {
                  //如果其他Task正在拉取数据,那么就等待它完成再继续执行
                  fetching.wait()
                } catch {
                  case e: InterruptedException =>
                }
              }
      
              // Either while we waited the fetch happened successfully, or
              // someone fetched it in between the get and the fetching.synchronized.
              //等待过后继续调用mapStatus的get方法获取MapStatus
              fetchedStatuses = mapStatuses.get(shuffleId).orNull
              if (fetchedStatuses == null) {
                // We have to do the fetch, get others to wait for us.
                //加入到内存缓存中,等待这个线程编程等待状态,因为上边的while循环就是
                //将线程编程等待状态,需要调用notifyAll来唤醒所有的线程
                fetching += shuffleId
              }
            }
            //如果fetchedStatuses还是等于空的话,就会真正的开始从MapOutputTracker中获取MapStatus
            if (fetchedStatuses == null) {
              // We won the race to fetch the statuses; do so
              logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
              // This try-finally prevents hangs due to timeouts:
              try {
                //调用askTracker方法,发送GetMapOutputStatuses消息获取MapStatus
                val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
                //反序列化传过来的MapStatus数组对象
                fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
                logInfo("Got the output locations")
                //将拉取过来的MapStatus放入到内存缓冲中,内存缓存结构为HashMap
                mapStatuses.put(shuffleId, fetchedStatuses)
              } finally {
                fetching.synchronized {
                  //移除当前执行完的fetch
                  fetching -= shuffleId
                  //由于当前线程正在执行的时候,其他线程正在处于等待状态
                  //需要调用notifyAll方法来唤醒其他线程,继续获取MapStatus
                  fetching.notifyAll()
                }
              }
            }
            logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
              s"${System.currentTimeMillis - startTime} ms")
            //如果fetchedStatuses不等于空的话就直接返回,如果为空就抛出异常
            if (fetchedStatuses != null) {
              return fetchedStatuses
            } else {
              logError("Missing all output locations for shuffle " + shuffleId)
              throw new MetadataFetchFailedException(
                shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
            }
      
          //如果本地有ShuffleId对应的MapStatus就直接返回
          } else {
            return statuses
          }
        }
    7. MapOutputTracker类里的askTracker方法,该方法的作用就是向MapOutputTrackerMasterEndpoint发送GetMapOutputStatus消息请求获取MapStatus

      protected def askTracker[T: ClassTag](message: Any): T = {
          try {
            //向MapOutputTrackerMasterEndpoint发送GetMapOutputStatus消息,并设定请求失败重试的次数,超时时间使用默认的
            trackerEndpoint.askWithRetry[T](message)
          } catch {
            case e: Exception =>
              logError("Error communicating with MapOutputTracker", e)
              throw new SparkException("Error communicating with MapOutputTracker", e)
          }
      }
    8. MapOutputTrackerMasterEndpoint类里的receiveAndReply方法,该方法的作用就是接收ResultTask发送过来的GetMapOutputStatus消息,调用MapOutputTrackerMasterpost方法,获取MapStatus

      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
          //当MapOutputTrackerMasterEndpoint接收到ResultTask发送的GetMapOutputStatus消息后,调用MapOutputTrackerMaster的post方法
          case GetMapOutputStatuses(shuffleId: Int) =>
            val hostPort = context.senderAddress.hostPort
            logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
            //MapOutputTrackerMaster调用post方法,获取MapStatus
            val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
      
          case StopMapOutputTracker =>
            logInfo("MapOutputTrackerMasterEndpoint stopped!")
            context.reply(true)
            stop()
      }
    9. MapOutputTrackerMasterpost方法,该方法的作用就是将请求用到的GetMapOutputMessage消息放入到队列里,后续会利用多线程的方式来执行请求,获取MapStatus

      def post(message: GetMapOutputMessage): Unit = {
          //不直接发送请求,而是将GetMapOutputMessage放入请求MapStatus对队列中,利用多线程的方式来请求MapStatus
          mapOutputRequests.offer(message)
      }
    10. MessageLoop类是一个发送MapOutputMessage消息的循环体,利用多线程的方式循环调用getSerializedMapOutputStatuses方法从本地获取MapStatus,然后返回给ResultTask

       //该类是继承了Runnable抽象类的一个调度消息的循环体
       private class MessageLoop extends Runnable {
          override def run(): Unit = {
            try {
              while (true) {
                try {
                  val data = mapOutputRequests.take()
                   if (data == PoisonPill) {
                    // Put PoisonPill back so that other MessageLoops can see it.
                    mapOutputRequests.offer(PoisonPill)
                    return
                  }
                  val context = data.context
                  val shuffleId = data.shuffleId
                  val hostPort = context.senderAddress.hostPort
                  logDebug("Handling request to send map output locations for shuffle " + shuffleId +
                    " to " + hostPort)
                  //获取MapStatus
                  val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
                  //返回数据
                  context.reply(mapOutputStatuses)
                } catch {
                  case NonFatal(e) => logError(e.getMessage, e)
                }
              }
            } catch {
              case ie: InterruptedException => // exit
            }
          }
        }
    11. MapOutputTracker类里的converMapStatus方法,该方法的作用就是根据参数组装数据结构。

      //该方法的作用就是利用传过来的参数组装成一个数据结构,以供后续使用
      //该数据结构就是一个序列里边的结构就是block数据在哪个节点上以及对应哪些block和block的大小
      private def convertMapStatuses(
            shuffleId: Int,
            startPartition: Int,
            endPartition: Int,
            statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
          assert (statuses != null)
          val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
          for ((status, mapId) <- statuses.zipWithIndex) {
            if (status == null) {
              val errorMessage = s"Missing an output location for shuffle $shuffleId"
              logError(errorMessage)
              throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
            } else {
              for (part <- startPartition until endPartition) {
                splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
                  ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
              }
            }
          }
      
          splitsByAddress.toSeq
      }
    12. ShuffleBlockFetcherIterator类里的initilize方法

      //在初始化ShuffleBlockFetcherIterator对象时候会调用这个initilize方法
      private[this] def initialize(): Unit = {
          // Add a task completion callback (called in both success case and failure case) to cleanup.
          context.addTaskCompletionListener(_ => cleanup())
      
          // Split local and remote blocks.
          val remoteRequests = splitLocalRemoteBlocks()
          // Add the remote requests into our queue in a random order
          //打散远程的block的容器中的元素,放入队列中
          fetchRequests ++= Utils.randomize(remoteRequests)
          assert ((0 == reqsInFlight) == (0 == bytesInFlight),
            "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
            ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
      
          // 从远程拉取block数据
          fetchUpToMaxBytes()
      
          val numFetches = remoteRequests.size - fetchRequests.size
          logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
      
          // 从本地拉取block数据
          fetchLocalBlocks()
          logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
        }
      
    13. ShuffleBlockFetcherIterator类里的splitLocalRemoteBlock方法,该方法的作用就是根据block所在的位置不同,封装不同的block信息,为后续拉取block数据做准备

      private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
          // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
          // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
          // nodes, rather than blocking on reading output from one node.
          //设置每次从5个节点上同时拉去数据
          val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
          logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
      
          //创建FetcherRequest容器
          //FetchRequest:封装了远程block块的信息
          val remoteRequests = new ArrayBuffer[FetchRequest]
      
          // Tracks total number of blocks (including zero sized blocks)
          //记录块的总数
          var totalBlocks = 0
          for ((address, blockInfos) <- blocksByAddress) {
            totalBlocks += blockInfos.size
            //判断当前block块是否在本地
            if (address.executorId == blockManager.blockManagerId.executorId) {
      
              //过滤掉为0的block块,并把blockId缓存在内存中
              localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
              //记录要获取本地的块的总数
              numBlocksToFetch += localBlocks.size
      
            //如果块不在本地,那么就需要从远程拉去块的信息
            } else {
              //生成blockInfos的迭代器
              val iterator = blockInfos.iterator
              //block数据量的边界值
              var curRequestSize = 0L
              var curBlocks = new ArrayBuffer[(BlockId, Long)]
              while (iterator.hasNext) {
                //获取远程block的Id以及大小
                val (blockId, size) = iterator.next()
                // 判断远程block块大小是否为0
                if (size > 0) {
                  //将block信息加入到容器中
                  curBlocks += ((blockId, size))
                  //将blockId加入到内存缓存中,缓存结构为HashSet
                  remoteBlocks += blockId
                  //需要从远程拉取block块的个数
                  numBlocksToFetch += 1
                  //记录循环到现在为止的所有块的大小,目的是为了定义一个边界,
                  //这个边界是为了防止拉取block数据的时候超出最大的允许拉取的block数据量
                  curRequestSize += size
                } else if (size < 0) {
                  //如果远程的块的大小为0,跑出异常
                  throw new BlockException(blockId, "Negative block size " + size)
                }
                //因为拉去远程block块的时候只能并行从5个节点上拉取数据,当curRequestSize大于等于最大的请求数量
      
                if (curRequestSize >= targetRequestSize) {
                  //就会将block块封装成FetcherRequest然后加入到容器总
                  remoteRequests += new FetchRequest(address, curBlocks)
                  //相当于清空curBlocks集合
                  curBlocks = new ArrayBuffer[(BlockId, Long)]
                  logDebug(s"Creating fetch request of $curRequestSize at $address")
                  //将curRequestSize
                  curRequestSize = 0
                }
              }
              // 因为block块信息遍历到最后curRequestSize >= targetRequestSize这个不成立
              //所以就把最后一个block封装成FetchRequest加入到remoteRequests容器中
              if (curBlocks.nonEmpty) {
                remoteRequests += new FetchRequest(address, curBlocks)
              }
            }
          }
          logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
          remoteRequests
        }
    14. ShuffleBlockFetcherIterator类里的fetchUpToMaxBytes方法,该方法的作用就是循环远程消息队列里的block信息,发送请求获取block信息

      private def fetchUpToMaxBytes(): Unit = {
          //while循环远程队列里的block信息,向远程发送请求获取block数据
          while (fetchRequests.nonEmpty &&
            (bytesInFlight == 0 ||
              (reqsInFlight + 1 <= maxReqsInFlight &&
                bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
            //调用sendRequest方法,发送请求获取block数据
            sendRequest(fetchRequests.dequeue())
          }
        }
    15. ShuffleBlockFetcherIterator类里的sendRequest方法,该方法的作用是发送请求获取block数据

       private[this] def sendRequest(req: FetchRequest) {
          logDebug("Sending request for %d blocks (%s) from %s".format(
            req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
          bytesInFlight += req.size
          reqsInFlight += 1
      
          // so we can look up the size of each blockID、
          //将blockid与block大小的元组结构转换成Map结构
          val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
          //将block的Id放到HashSet容器中
          val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
          val blockIds = req.blocks.map(_._1.toString)
          //  请求的远端地址
          val address = req.address
          //ShuffleClient是一个抽象类,默认调用的是BlockTransferService这个子类的fetchBlock方法,
          //BlockTransferService里的fetchBlock方法也是一个抽象方法,这个方法是NettyBlockTransferService这个类实现的
          shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
            new BlockFetchingListener {
              override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
                //请求成功
                ShuffleBlockFetcherIterator.this.synchronized {
                  if (!isZombie) {
                    // Increment the ref count because we need to pass this to a different thread.
                    // This needs to be released after use.
                    buf.retain()
                    remainingBlocks -= blockId
                    results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                      remainingBlocks.isEmpty))
                    logDebug("remainingBlocks: " + remainingBlocks)
                  }
                }
                logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
              }
              //请求失败
              override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
                logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
                results.put(new FailureFetchResult(BlockId(blockId), address, e))
              }
            }
          )
        }
      
    16. NettyBlockTransferServicefetcherBlock方法,该方的作用就是拉取block数据。NettyBlockTransferService必须在调用init方法后才能提供服务。这个方法在执行前,必须执行以下步骤才能成功拉取block数据

      1. 创建RpcServer(实际是其子类NettyBlockRpcServer
      2. 创建TransportContext
      3. 创建Rpc 客户端工厂 TransportClientFactory
      4. 创建Netty服务器 TransportServer,可以修改属性spark.blockManager.port改变TransportServer的端口
      override def fetchBlocks(
            host: String,
            port: Int,
            execId: String,
            blockIds: Array[String],
            listener: BlockFetchingListener): Unit = {
          logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
          try {
            val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
              override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
                val client = clientFactory.createClient(host, port)
                new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
              }
            }
      
            val maxRetries = transportConf.maxIORetries()
            if (maxRetries > 0) {
              // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
              // a bug in this code. We should remove the if statement once we're sure of the stability.
              new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
            } else {
              blockFetchStarter.createAndStart(blockIds, listener)
            }
          } catch {
            case e: Exception =>
              logError("Exception while beginning fetchBlocks", e)
              blockIds.foreach(listener.onBlockFetchFailure(_, e))
          }
        }
    17. NettyBlockTransferServiceinit方法,

      override def init(blockDataManager: BlockDataManager): Unit = {
          //初始化NettyRpcServer,用于接受上传或者拉取block数据的请求
          val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
          var serverBootstrap: Option[TransportServerBootstrap] = None
          var clientBootstrap: Option[TransportClientBootstrap] = None
          if (authEnabled) {
            serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager))
            clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
              securityManager.isSaslEncryptionEnabled()))
          }
          //初始化TransportContext,既可以创建Netty服务端,又可以创建Netty客户端
          //transportConf主要控制Netty客户端与服务端的线程数量
          //rpcHandle负责客户端请求服务端的时候,提供block的上传下载功能,其实就是NettyBlockRpcServer对象
          transportContext = new TransportContext(transportConf, rpcHandler)
          //实例化一个能够创建Netty客户端的工厂类,用于创建Netty客户端
          clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
          server = createServer(serverBootstrap.toList)
          appId = conf.getAppId
          logInfo(s"Server created on ${hostName}:${server.getPort}")
        }
    18. NettyBlockRpcServerreveive方法,该方法的作用就是接受拉取block的请求

       override def receive(
            client: TransportClient,
            rpcMessage: ByteBuffer,
            responseContext: RpcResponseCallback): Unit = {
          val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
          logTrace(s"Received request: $message")
      
          message match {
           //处理拉去block数据请求
            case openBlocks: OpenBlocks =>
              val blocks: Seq[ManagedBuffer] =
                openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
              val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
              logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
              responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)
      
            case uploadBlock: UploadBlock =>
              // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
              val (level: StorageLevel, classTag: ClassTag[_]) = {
                serializer
                  .newInstance()
                  .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
                  .asInstanceOf[(StorageLevel, ClassTag[_])]
              }
              val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
              val blockId = BlockId(uploadBlock.blockId)
              blockManager.putBlockData(blockId, data, level, classTag)
              responseContext.onSuccess(ByteBuffer.allocate(0))
          }
        }

猜你喜欢

转载自blog.csdn.net/Suubyy/article/details/82023369