[spark] Shuffle Write Analysis (Sort Based Shuffle)

This analysis is based on Spark 2.1.

Foreword

Hash Based Shuffle was removed in Spark 2.0; for background, refer to the earlier article on the Shuffle process. This article explains Sort Based Shuffle.

The output of a ShuffleMapTask (the data of the final RDD in a ShuffleMapStage) is written to disk for subsequent stages to pull. The whole Shuffle therefore consists of the Shuffle Write of the earlier stage and the Shuffle Read of the later stage. Because there is a lot of material, this article analyzes Shuffle Write first.

Overview:

  • Records are written into an in-memory buffer (a map backed by an array); every insert/update checks whether the spill condition has been reached.
  • If a spill is required, the data in the collection is sorted by partitionId and then by key (if necessary) and written out in order to a temporary disk file; the memory is then released and a new map is created for the following data. Every spill produces a new temporary file.
  • A task ultimately produces a single data file. The data still in memory and the spilled files are merged by reduce-side partitionId; during the merge the data is aggregated and sorted again if necessary, then written to the final file in partition order. The offset of each partition within the file is returned to the driver as a MapStatus object and registered with the MapOutputTrackerMaster, through which the subsequent reduce tasks locate their data. A toy sketch of this flow follows the list.
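
To make the flow concrete, here is a toy model of the write path using plain Scala collections (no spilling, no serialization, illustrative names only); it merely mimics the partition-then-key ordering and the single-file-plus-offsets layout described above:

object SortShuffleWriteSketch {
  // Toy model: sort records by (partition, key) and compute the cumulative
  // offsets an index file would store. A sketch for illustration, not Spark code.
  def writeSketch[K: Ordering, V](records: Seq[(K, V)], numPartitions: Int): (Seq[(K, V)], Seq[Long]) = {
    def partitionOf(k: K): Int = (k.hashCode % numPartitions + numPartitions) % numPartitions

    // 1. Sort by partition id first, then by key (what the spill/merge path guarantees).
    val sorted = records.sortBy { case (k, _) => (partitionOf(k), k) }

    // 2. One "length" per reduce partition (record counts stand in for byte lengths here).
    val lengths = Array.fill(numPartitions)(0L)
    sorted.foreach { case (k, _) => lengths(partitionOf(k)) += 1 }

    // 3. The index file stores the running sum of the lengths as offsets.
    (sorted, lengths.scanLeft(0L)(_ + _).toSeq)
  }
}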

Entry point

Executing a ShuffleMapTask ultimately comes down to calling the runTask() method of the ShuffleMapTask class:

override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // Deserialize the finalRDD and its ShuffleDependency from the broadcast variable
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    var writer: ShuffleWriter[Any, Any] = null
    try {
      // Get the shuffleManager
      val manager = SparkEnv.get.shuffleManager
      // Get the shuffle writer via the shuffleManager's getWriter() method
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
       // Iterate over the specified partition's records via rdd.iterator and call the writer's write method on them
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

The finalRDD and the dependency were put into a broadcast variable when the stage was submitted by the DAGScheduler on the driver side.

The shuffleManager is then obtained through SparkEnv. The default is sort (corresponding to org.apache.spark.shuffle.sort.SortShuffleManager), and it can be set through spark.shuffle.manager.

manager.getWriter is then called. Inside it, Unsafe Shuffle is used automatically if its conditions are met; otherwise Sort Shuffle is used. Unsafe Shuffle has several restrictions: there can be no aggregation in the shuffle, and the number of partitions cannot exceed a certain limit (2^24 − 1, the largest partition id that can be encoded), so operators with aggregation such as reduceByKey cannot use Unsafe Shuffle.
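
For reference, here is a simplified, self-contained sketch of that condition (not the verbatim Spark source; the parameter names are illustrative):

// Sketch of the condition for choosing the Unsafe (serialized) shuffle writer.
def canUseUnsafeShuffle(
    hasAggregator: Boolean,
    serializerSupportsRelocation: Boolean,
    numPartitions: Int): Boolean = {
  serializerSupportsRelocation &&   // serialized records must be relocatable (e.g. Kryo)
    !hasAggregator &&               // no aggregation allowed in the shuffle
    numPartitions <= (1 << 24)      // the partition id must fit in 24 bits
}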

Here we temporarily discuss the Sort Shuffle case, i.e. getWriter returns a SortShuffleWriter. Let's look directly at what happens in writer.write:

override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
    // Write to the in-memory buffer, spilling to disk files when the threshold is exceeded
    sorter.insertAll(records)
    // Get the final output file for this task
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      // Write the merged data to the .data file
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      // Write the index file
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }

  • A different ExternalSorter is created depending on whether there is a map-side combine; if there is, the corresponding aggregator and keyOrdering are passed in as parameters.
  • sorter.insertAll(records) writes the records into the in-memory buffer, spilling to disk files when the threshold is exceeded.
  • The in-memory records and all files spilled to disk are merged and written to the final .data file.
  • The offset of each partition is written to the index file.

Let's take a closer look at how sorter.insertAll writes to memory and spills to disk files:

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined
    // If combining is needed
    if (shouldCombine) {
      // Function that merges a new value into an existing combiner
      val mergeValue = aggregator.get.mergeValue
      // Function that creates the initial combiner from the first value
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      // Use mergeValue to fold a new value into an existing combiner; use createCombiner to initialize the combiner for a key seen for the first time
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      // Iterate over the records
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        // Aggregate the value using the update function
        map.changeValue((getPartition(kv._1), kv._1), update)
        // Check whether the collection needs to be spilled to disk
        maybeSpillCollection(usingMap = true)
      }
    // No combining needed
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }
  }

  • If aggregation is required, iterate over the records to get each KV pair and aggregate the values of the same key through the map's changeValue method using the update function. The map here is a PartitionedAppendOnlyMap, which can only add data, never remove it; the underlying structure is an array that stores key/value pairs as [K1, V1, K2, V2, ...]. After each operation it checks whether the collection should be spilled to disk (a hypothetical example of such an aggregator follows this list).

  • If aggregation is not required, the record is appended to the buffer directly, and the same spill check is performed.
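
For intuition, here is a hypothetical example (not from the Spark source) of the two aggregator functions for a reduceByKey(_ + _) over (String, Int) records, and how the update closure above dispatches between them:

// Hypothetical aggregator for a reduceByKey(_ + _) over (String, Int) records.
val createCombiner: Int => Int = v => v              // first value seen for a key
val mergeValue: (Int, Int) => Int = (c, v) => c + v  // fold another value into the combiner

def update(hadValue: Boolean, oldValue: Int, newValue: Int): Int =
  if (hadValue) mergeValue(oldValue, newValue) else createCombiner(newValue)

// Inserting ("a", 1) and then ("a", 2) into the same (partitionId, "a") slot:
// update(false, 0, 1) == 1   (oldValue is ignored when hadValue is false)
// update(true, 1, 2)  == 3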

Let's first see how map.changeValue aggregates the data:

override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    // Compute the new value via the aggregation function
    val newValue = super.changeValue(key, updateFunc)
    // Update the size sampling of the map
    super.afterUpdate()
    newValue
  }

The implementation of super.changeValue:

def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    ...
    // Compute pos from the key's hash
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      // Get the existing key at this position in data
      val curKey = data(2 * pos)  
      // If the existing key equals the current key, aggregate the two values
      if (k.eq(curKey) || k.equals(curKey)) {
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
       // If this position holds no key yet, use the current key for this slot
       // and initialize its aggregate result via the update function
      } else if (curKey.eq(null)) {
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        // Increment the size (growing the table if the threshold is exceeded)
        incrementSize()
        return newValue
      // If the slot holds a different key (hash collision), keep probing
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }

pos is computed from K's hashCode via rehash and the mask; 2 * pos is the slot where k should live and 2 * pos + 1 is the slot of its value. The existing key at that slot is then examined:

  • If the existing key equals the current k, the two values are aggregated by the update function and the value slot is updated.
  • If the slot holds a key that is not equal to the current k, there is a hash collision, so pos is updated and probing continues.
  • If the slot holds no key, the current k becomes the key of that slot, its aggregate result is initialized through the update function, and incrementSize() is called, which grows the table if needed:

    private def incrementSize() {
      curSize += 1
      if (curSize > growThreshold) {
        growTable()
      }
    }

    After updating curSize, if the current size exceeds the threshold growThreshold (0.7 times the current capacity), the table is grown through growTable():

protected def growTable() {
    // Double the capacity
    val newCapacity = capacity * 2
    require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements")
    // Allocate a new array to hold the data
    val newData = new Array[AnyRef](2 * newCapacity)
    val newMask = newCapacity - 1
    var oldPos = 0
    while (oldPos < capacity) {
      // Recompute positions and move entries from the old array into the new one
      if (!data(2 * oldPos).eq(null)) {
        val key = data(2 * oldPos)
        val value = data(2 * oldPos + 1)
        var newPos = rehash(key.hashCode) & newMask
        var i = 1
        var keepGoing = true
        while (keepGoing) {
          val curKey = newData(2 * newPos)
          if (curKey.eq(null)) {
            newData(2 * newPos) = key
            newData(2 * newPos + 1) = value
            keepGoing = false
          } else {
            val delta = i
            newPos = (newPos + delta) & newMask
            i += 1
          }
        }
      }
      oldPos += 1
    }
    // Swap in the new array and update the bookkeeping variables
    data = newData
    capacity = newCapacity
    mask = newMask
    growThreshold = (LOAD_FACTOR * newCapacity).toInt
  }

Here an array with twice the capacity is allocated, the entries of the original array are rehashed into their new positions, data is switched to the new array, and the related variables (capacity, mask, growThreshold) are updated.
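
The probing and resizing logic can be condensed into a small self-contained toy model (a sketch only; Spark's AppendOnlyMap additionally handles a null key specially, rehashes the hash code, and samples its size for spill decisions):

// Toy model of the flat-array, open-addressing changeValue/growTable logic.
class FlatArrayMap[K <: AnyRef, V <: AnyRef](initialCapacity: Int = 8) {
  private var capacity = initialCapacity              // must stay a power of two
  private var mask = capacity - 1
  private var data = new Array[AnyRef](2 * capacity)  // layout: [K1, V1, K2, V2, ...]
  private var curSize = 0

  def changeValue(k: K, updateFunc: (Boolean, V) => V): V = {
    var pos = k.hashCode & mask                       // Spark applies a rehash here
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (curKey eq null) {                           // empty slot: insert the new key
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue
        curSize += 1
        if (curSize > 0.7 * capacity) grow()          // LOAD_FACTOR = 0.7
        return newValue
      } else if (curKey == k) {                       // same key: aggregate in place
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue
        return newValue
      } else {                                        // collision: keep probing
        pos = (pos + i) & mask
        i += 1
      }
    }
    null.asInstanceOf[V]                              // never reached
  }

  private def grow(): Unit = {
    val old = data
    capacity *= 2
    mask = capacity - 1
    data = new Array[AnyRef](2 * capacity)
    curSize = 0
    var p = 0
    while (2 * p < old.length) {                      // re-insert every old entry
      if (old(2 * p) ne null) {
        changeValue(old(2 * p).asInstanceOf[K], (_, _) => old(2 * p + 1).asInstanceOf[V])
      }
      p += 1
    }
  }
}

In the real ExternalSorter the key stored here is actually the pair (partitionId, key), which is exactly what PartitionedAppendOnlyMap receives in insertAll above.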

At this point the aggregation is done. Back in changeValue, super.afterUpdate() is executed next to sample the size of the map:

protected def afterUpdate(): Unit = {
    numUpdates += 1
    if (nextSampleNum == numUpdates) {
      takeSample()
    }
  }

If the map were sampled to estimate its size for every new record, and a single sampling took 1 ms, then 1,000,000 samplings would take about 16.7 minutes and performance would drop badly. So takeSample() is only invoked when the number of updates reaches nextSampleNum:

private def takeSample(): Unit = {
    samples.enqueue(Sample(SizeEstimator.estimate(this), numUpdates))
    // Only use the last two samples to extrapolate
    if (samples.size > 2) {
      samples.dequeue()
    }
    // Estimate how much the size changes per update
    val bytesDelta = samples.toList.reverse match {
      case latest :: previous :: tail =>
        (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
      // If fewer than 2 samples, assume no change
      case _ => 0
    }
    // Update the per-update delta
    bytesPerUpdate = math.max(0, bytesDelta)
    // Compute the update count at which to sample next
    nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong
  }

The change per update is estimated as: (size of the latest sample − size of the previous sample) / (update count of the latest sample − update count of the previous sample).

The update count at which to take the next sample is then computed. It grows geometrically with base 1.1: the second sample happens after 1.1 times as many updates as the first, the third after 1.1 × 1.1 times, and so on. The intervals are small at the beginning and become very large later on.
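
A quick way to see this (a throwaway snippet, not Spark source): with SAMPLE_GROWTH_RATE = 1.1, reaching 1,000,000 updates costs on the order of a hundred samples rather than a million:

// The update counts at which takeSample() runs, assuming the counter starts at 1.
val sampleGrowthRate = 1.1
val samplePoints = Iterator
  .iterate(1L)(n => math.ceil(n * sampleGrowthRate).toLong)
  .takeWhile(_ <= 1000000L)
  .toList
// samplePoints.take(8) == List(1, 2, 3, 4, 5, 6, 7, 8)  -- dense at the start
// samplePoints.length  is on the order of 100 to 150    -- sparse later on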

After the sampling completes, we return to the insertAll method, where maybeSpillCollection determines whether a spill is needed:

 private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    if (usingMap) {
      estimatedSize = map.estimateSize()
      if (maybeSpill(map, estimatedSize)) {
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else {
      estimatedSize = buffer.estimateSize()
      if (maybeSpill(buffer, estimatedSize)) {
        buffer = new PartitionedPairBuffer[K, C]
      }
    }

    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
  }

The size of the collection is estimated by its estimateSize method. If a spill is needed, the data in the collection is spilled to a disk file and a new collection object is created to hold the following data. Let's first look at the estimateSize method:

 def estimateSize(): Long = {
    assert(samples.nonEmpty)
    val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates)
    (samples.last.size + extrapolatedDelta).toLong
  }

The bytesPerUpdate computed at the last sampling is taken as the latest average size of an update, and the current memory usage is estimated as: size of the last sample + (current update count − update count at the last sample) × bytesPerUpdate.
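
A worked example with made-up numbers:

// Suppose the last sample recorded size = 10,000,000 bytes at numUpdates = 100,000,
// bytesPerUpdate was estimated at 96.0 bytes, and we are now at numUpdates = 130,000.
val lastSampleSize  = 10000000L
val lastSampleCount = 100000L
val bytesPerUpdate  = 96.0
val numUpdates      = 130000L
val estimatedSize = (lastSampleSize + bytesPerUpdate * (numUpdates - lastSampleCount)).toLong
// estimatedSize == 12880000 bytes, i.e. roughly 12.3 MB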

After getting the size of the current collection, call maybeSpill to determine whether spill is needed:

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      // Update the amount of memory acquired so far
      myMemoryThreshold += granted 
      // Is the collection still larger than the memory acquired? spill : no spill
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    // Actually spill
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
    shouldSpill
  }

There are two situations that can lead to spill:

  • The number of records in the current collection exceeds numElementsForceSpillThreshold (the default is Long.MaxValue; it can be set via spark.shuffle.spill.numElementsForceSpillThreshold).
  • The number of records in the current collection is a multiple of 32 and the estimated size of the collection exceeds the memory acquired so far, myMemoryThreshold (initially 5 * 1024 * 1024, configurable via spark.shuffle.spill.initialMemoryThreshold). Even then it does not spill immediately: it first tries to acquire more memory in order to avoid the spill. The amount requested is twice the current collection size minus the memory already acquired (the amount actually obtained is granted); if the collection is still larger than the total memory acquired, a spill is required (a worked example follows this list).
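
A worked example of the second rule (the numbers are illustrative):

// The collection is estimated at 6 MB while myMemoryThreshold is still the initial 5 MB.
val currentMemory     = 6L * 1024 * 1024                        // estimated collection size
val myMemoryThreshold = 5L * 1024 * 1024                        // initial threshold (5 MB)
val amountToRequest   = 2 * currentMemory - myMemoryThreshold   // ask for 7 MB more

// Case 1: the full 7 MB is granted -> threshold becomes 12 MB, 6 MB < 12 MB, no spill yet.
val spillIfAllGranted    = currentMemory >= (myMemoryThreshold + amountToRequest)   // false

// Case 2: only 512 KB is granted -> threshold becomes 5.5 MB, 6 MB >= 5.5 MB, spill.
val spillIfLittleGranted = currentMemory >= (myMemoryThreshold + 512L * 1024)       // true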

If a spill is needed, the spill(collection) method is called to spill to disk, the spill count is updated, and the memory is released.
Let's follow the spill method to see its concrete implementation:

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
    // Pass in the comparator so the collection's data is sorted by partition and then by key, and return an iterator
    val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
    // Write to a disk file and return a SpilledFile describing it
    val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
    // Append it to the array of spill files
    spills.append(spillFile)
  }

Continue to follow up to see the implementation of spillMemoryIteratorToDisk:

private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
      : SpilledFile = {
    // Create a temporary file and a blockId
    val (blockId, file) = diskBlockManager.createTempShuffleBlock()

    // These values are reset after every flush
    var objectsWritten: Long = 0
    var spillMetrics: ShuffleWriteMetrics = null
    var writer: DiskBlockObjectWriter = null
    def openWriter(): Unit = {
      assert (writer == null && spillMetrics == null)
      spillMetrics = new ShuffleWriteMetrics
      writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
    }
    openWriter()

    // Record the size of each batch in the order it is written to disk
    val batchSizes = new ArrayBuffer[Long]

    // Record how many elements each partition has
    val elementsPerPartition = new Array[Long](numPartitions)

    // Flush the writer's contents to disk and update the related variables
    def flush(): Unit = {
      val w = writer
      writer = null
      w.commitAndClose()
      _diskBytesSpilled += spillMetrics.bytesWritten
      batchSizes.append(spillMetrics.bytesWritten)
      spillMetrics = null
      objectsWritten = 0
    }

    var success = false
    try {
      // Iterate over the in-memory iterator
      while (inMemoryIterator.hasNext) {
        val partitionId = inMemoryIterator.nextPartition()
        require(partitionId >= 0 && partitionId < numPartitions,
          s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
        inMemoryIterator.writeNext(writer)
        elementsPerPartition(partitionId) += 1
        objectsWritten += 1
        // When the element count reaches the serializer batch size, flush to disk
        if (objectsWritten == serializerBatchSize) {
          flush()
          openWriter()
        }
      }
      // Flush the remaining data
      if (objectsWritten > 0) {
        flush()
      } else if (writer != null) {
        val w = writer
        writer = null
        w.revertPartialWritesAndClose()
      }
      success = true
    } finally {
        ...
    }
    // Return the SpilledFile
    SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
  }

A temporary file and a blockId are created through the diskBlockManager; the temporary file name has the format "temp_shuffle_" + id. The in-memory data iterator is then traversed and each record is written through the writer (a DiskBlockObjectWriter); whenever the number of written objects reaches the serializer batch size, the data is flushed to the disk file, the writer is reopened, and batchSizes and the other bookkeeping are updated.

Finally a SpilledFile object is returned; it holds the temporary file, its blockId, the size of every batch flushed to disk, and the number of records per partition.

The spill is now complete and insertAll has finished; we return to the write method of SortShuffleWriter shown at the beginning:

override def write(records: Iterator[Product2[K, V]]): Unit = {
    ...
    // Write to the in-memory buffer, spilling to disk files when the threshold is exceeded
    sorter.insertAll(records)
    // Get the final output file for this task
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      // Write the merged data to the .data file
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      // Write the index file
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }

The final output file and its blockId are obtained; the data file name has the format:

 "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"

Then sorter.writePartitionedFile writes the file, which includes merging the in-memory data with all spill files. Let's look at its implementation:

def writePartitionedFile(
      blockId: BlockId,
      outputFile: File): Array[Long] = {

    val writeMetrics = context.taskMetrics().shuffleWriteMetrics

    // Track each partition's range within the file
    val lengths = new Array[Long](numPartitions)
    // The data exists only in memory
    if (spills.isEmpty) { 
      val collection = if (aggregator.isDefined) map else buffer
      // Sort the in-memory data by partitionId and then by key, returning an iterator
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
      // Iterate over the data and write it to disk
      while (it.hasNext) {
        val writer = blockManager.getDiskWriter(
          blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
        val partitionId = it.nextPartition()
        // Once one partition's data has been written, flush it to the disk file
        while (it.hasNext && it.nextPartition() == partitionId) {
          it.writeNext(writer)
        }
        writer.commitAndClose()
        val segment = writer.fileSegment()
        // Record the length of each partition's data
        lengths(partitionId) = segment.length
      }
    } else {
      // Some data was spilled to disk, so merge first
      for ((id, elements) <- this.partitionedIterator) {
        if (elements.hasNext) {
          val writer = blockManager.getDiskWriter(
            blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
          for (elem <- elements) {
            writer.write(elem._1, elem._2)
          }
          writer.commitAndClose()
          val segment = writer.fileSegment()
          lengths(id) = segment.length
        }
      }
    }

    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

    lengths
  }

  • If the data exists only in memory and there is no spill file: using the comparator passed in, the data in the collection is sorted first by partition and then by key, and an iterator is returned. The iterator is traversed to get all records; each partition corresponds to one writer, and after one partition's data has been written it is flushed to the disk file and the partition's data length is recorded.
  • If there are spill files: the partitionedIterator method merge-sorts the in-memory data with the spill files and returns an iterator of (partitionId, iterator over that partition's data). Again each partition corresponds to one writer; after a partition has been written it is flushed to disk and its length is recorded.

Next, let's see how this.partitionedIterator merge-sorts the in-memory data with the spill files:

def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
    val usingMap = aggregator.isDefined
    val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
    if (spills.isEmpty) {
      if (!ordering.isDefined) {
        // Sort by partitionId only; no key ordering needed
        groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
      } else {
        // Sort by partitionId and then by key
        groupByPartition(destructiveIterator(
          collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
      }
    } else {
      // Merge spilled and in-memory data
      merge(spills, destructiveIterator(
        collection.partitionedDestructiveSortedIterator(comparator)))
    }
  }

When there are spill files, the merge method below is executed. It is given the array of spill files together with an iterator over the in-memory data already sorted by partitionId and key. Let's look at merge:

private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[Product2[K, C]])] = {
    // One SpillReader per spill file
    val readers = spills.map(new SpillReader(_)) 
    val inMemBuffered = inMemory.buffered
    (0 until numPartitions).iterator.map { p =>
      // Get the iterator over the in-memory data for the current partition
      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
      // Combine each spill file's data for this partition with the in-memory data for the same partition
      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
      if (aggregator.isDefined) {
        // Aggregate by key and sort
        (p, mergeWithAggregation(
          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
      } else if (ordering.isDefined) {
        // Sort by key
        (p, mergeSort(iterators, ordering.get))
      } else {
        (p, iterators.iterator.flatten)
      }
    }
  }

The merge method merges the in-memory data belonging to each reduce-side partition with the corresponding data from the spill files, aggregates and sorts it if necessary, and finally returns pairs of (reduce-side partitionId, iterator over that partition's data).
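
The per-partition merge is essentially a k-way merge of iterators that are already sorted by key. A minimal sketch of that idea (simplified, with no aggregation; not Spark's actual mergeSort, although the real one also keeps buffered iterators in a priority queue):

import scala.collection.mutable

// Merge k key-sorted iterators (one per spill file plus the in-memory one) for one partition.
def mergeSortSketch[K, V](iterators: Seq[Iterator[(K, V)]], keyOrdering: Ordering[K]): Iterator[(K, V)] = {
  val buffered = iterators.map(_.buffered).filter(_.hasNext)
  // PriorityQueue is a max-heap, so reverse the ordering to pop the smallest head key first.
  implicit val headOrdering: Ordering[BufferedIterator[(K, V)]] =
    Ordering.by[BufferedIterator[(K, V)], K](_.head._1)(keyOrdering).reverse
  val heap = mutable.PriorityQueue(buffered: _*)

  new Iterator[(K, V)] {
    override def hasNext: Boolean = heap.nonEmpty
    override def next(): (K, V) = {
      val it = heap.dequeue()            // the iterator whose current head key is smallest
      val kv = it.next()
      if (it.hasNext) heap.enqueue(it)   // put it back so its new head is considered
      kv
    }
  }
}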

After the data has been merge-sorted and written to the final file, the offset of each partition must be persisted so that each reduce task can later read its own data by offset. The logic is simple: using the array of partition lengths obtained above, the offsets are written to the index file, whose name has the format "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index":

def writeIndexFileAndCommit(
      shuffleId: Int,
      mapId: Int,
      lengths: Array[Long],
      dataTmp: File): Unit = {
    val indexFile = getIndexFile(shuffleId, mapId)
    val indexTmp = Utils.tempFileWith(indexFile)
    try {
      val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
      Utils.tryWithSafeFinally {
        // We take in lengths of each block, need to convert it to offsets.
        var offset = 0L
        out.writeLong(offset)
        for (length <- lengths) {
          offset += length
          out.writeLong(offset)
        }
      } 
    ......
    }
  }

The index file is located according to shuffleId and mapId and a file stream is created; the cumulative offsets of the reduce-side partitions are then written in turn, for example:
0,
length(partition1),
length(partition1)+length(partition2),
length(partition1)+length(partition2)+length(partition3)
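
In other words the offsets are just a running sum of the per-partition lengths, and on the read side reduce task i fetches the byte range [offsets(i), offsets(i + 1)) from the .data file. A tiny illustration with made-up lengths:

val lengths = Array(100L, 250L, 0L, 75L)     // bytes written for each reduce partition
val offsets = lengths.scanLeft(0L)(_ + _)    // Array(0, 100, 350, 350, 425)
// an empty partition simply yields two equal consecutive offsets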

Finally, a MapStatus instance is created and returned; it contains the location of the shuffle output (the block manager's shuffleServerId) and the size of the data for each reduce-side partition.

This object is returned to the DAGScheduler on the driver side and added to the outputLocs of the corresponding stage. When all tasks of the stage have completed, the results are registered with the MapOutputTrackerMaster, so that the tasks of the next stage can use it to obtain the metadata of the shuffle output.

So far Shuffle Write is complete!

For the Shuffle Read part, see Shuffle Read Analysis.
