9. spark源代码分析(基于yarn cluster模式)- Task执行,Reduce端读取shuffle数据文件

本系列基于spark-2.4.6
通过上一节的分析,我们了解了Spark中ShuflleMapTask中Map端数据的写入流程,这个章节我们分析下Reduce端是如何读取数据的。
ShulleMapTask.runTask中,有这么一个步骤:

writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

其中rdd.iterator:

  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    
    
    if (storageLevel != StorageLevel.NONE) {
    
    
      getOrCompute(split, context)
    } else {
    
    
      computeOrReadCheckpoint(split, context)
    }
  }

最后都会调用RDD·如下方法:

def compute(split: Partition, context: TaskContext): Iterator[T]

而RDD有多重实现,我们看看RDD中groupBy,返回的是一个ShuffledRDD,而ShuffledRDD中对应的compute实现如下:

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    
    
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

这里的read实现在BlockStoreShuffleReader中:

override def read(): Iterator[Product2[K, C]] = {
    
    
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      serializerManager.wrapStream,
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

    val serializerInstance = dep.serializer.newInstance()
    val recordIter = wrappedStreams.flatMap {
    
     case (blockId, wrappedStream) =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map {
    
     record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    
    
      if (dep.mapSideCombine) {
    
    
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
    
    
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
    
    
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }
    val resultIter = dep.keyOrdering match {
    
    
        val sorter =  new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        // Use completion callback to stop sorter if task was finished/cancelled.
        context.addTaskCompletionListener[Unit](_ => {
    
    
          sorter.stop()
        })
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }

这里首先需要注意下mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),,这是去master获取当前节点需要获取的shuffle数据。会向Master节点发送GetMapOutputStatuses信息。

重要的逻辑在ShuffleBlockFetcherIterator,另外这里需要注意几个参数:

spark.reducer.maxSizeInFlight  --每次请求能够拉取的数据最大大小(为了并行,spark取该值 / 5)
spark.reducer.maxReqsInFlight  -- 每次请求最大拉取块数据的个数

ShuffleBlockFetcherIterator在生成后立马执行初始化方法initialize

private[this] def initialize(): Unit = {
    
    
    context.addTaskCompletionListener[Unit](_ => cleanup())
    val remoteRequests = splitLocalRemoteBlocks()
    fetchRequests ++= Utils.randomize(remoteRequests)
    fetchUpToMaxBytes()
    val numFetches = remoteRequests.size - fetchRequests.size
    fetchLocalBlocks()
  }

首先通过splitLocalRemoteBlocks,划分需要拉取哪些数据:

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    
    
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val remoteRequests = new ArrayBuffer[FetchRequest]
    for ((address, blockInfos) <- blocksByAddress) {
    
    
      if (address.executorId == blockManager.blockManagerId.executorId) {
    
    
        blockInfos.find(_._2 <= 0) match {
    
    
          case Some((blockId, size)) if size < 0 =>
            throw new BlockException(blockId, "Negative block size " + size)
          case Some((blockId, size)) if size == 0 =>
            throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
          case None => // do nothing.
        }
        localBlocks ++= blockInfos.map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
    
    
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
    
    
          val (blockId, size) = iterator.next()
          if (size < 0) {
    
    
            throw new BlockException(blockId, "Negative block size " + size)
          } else if (size == 0) {
    
    
            throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
          } else {
    
    
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          }
          if (curRequestSize >= targetRequestSize ||
              curBlocks.size >= maxBlocksInFlightPerAddress) {
    
    
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
        if (curBlocks.nonEmpty) {
    
    
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    remoteRequests
  }

可以看到这里会区分需要拉取的数据是本地数据还是远程数据(这里数据用Block表示),如果是本地数据则会放入把数据对应的BlockId放入到localBlocks集合中。如果是远端的数据,这里是按照一个节点一个节点来遍历节点下的所有数据,是按照节点来拉取节点上的所有数据。这里会判断当前节点遍历的Block,如果遍历到当前Block,所有Block的大小 >= targetRequestSize 或者Block的个数大于maxBlocksInFlightPerAddress的时候,则会将已经遍历当前节点的Block放到一次请求中去拉取数据,这里的targetRequestSize是前面说的"spark.reducer.maxSizeInFlight/5这里除以5是为了增加并行度maxBlocksInFlightPerAddress则是每次请求一个节点额数据最多请求多少个Block,默认情况下这个是Int.MAX.到这里就将本地和远端需要拉取的数据分好了,然后会通过fetchUpToMaxBytes获取对应节点上的Block的信息,然后拉取Block数据,
发送拉取数据请求sendRequest,这里需要注意有一个处理逻辑:

    if (req.size > maxReqSizeShuffleToMem) {
    
    
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, this)
    } else {
    
    
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, null)
    }

这里会判断拉取数据的大小,如果待拉取的数据大小> maxReqSizeShuffleToMem ,那么会将数据写入到本地磁盘,这里的maxReqSizeShuffleToMem通过spark.maxRemoteBlockSizeFetchToMem来配置,默认是Int.MaxValue - 512 字节

最终会调用NettyBlockTransferService.fetchBlocks:

override def fetchBlocks(
      host: String,
      port: Int,
      execId: String,
      blockIds: Array[String],
      listener: BlockFetchingListener,
      tempFileManager: DownloadFileManager): Unit = {
    
    
    try {
    
    
      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
    
    
        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
    
    
          val client = clientFactory.createClient(host, port)
          new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
            transportConf, tempFileManager).start()
        }
      }
      val maxRetries = transportConf.maxIORetries()
      if (maxRetries > 0) {
    
    
        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
      } else {
    
    
        blockFetchStarter.createAndStart(blockIds, listener)
      }
    } catch {
    
    
  }

可以看到最后启动了OneForOneBlockFetcher:

public void start() {
    
    
    if (blockIds.length == 0) {
    
    
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }
    client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
    
    
      @Override
      public void onSuccess(ByteBuffer response) {
    
    
        try {
    
    
          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
          for (int i = 0; i < streamHandle.numChunks; i++) {
    
    
            if (downloadFileManager != null) {
    
    
              client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
                new DownloadCallback(i));
            } else {
    
    
              client.fetchChunk(streamHandle.streamId, i, chunkCallback);
            }
          }
        } catch (Exception e) {
    
    
          failRemainingBlocks(blockIds, e);
        }
      }
      @Override
      public void onFailure(Throwable e) {
    
    
        failRemainingBlocks(blockIds, e);
      }
    });
  }

这里先是给要拉取数据的节点发送了一个OpenBlocks信息,如果成功后,则会调用TransportClient获取对饮的数据,这里会判断downloadFileManager是否为空,就是上面说的这个条件如果待拉取的数据大小> maxReqSizeShuffleToMem,如果满足则要写文件downloadFileManager不为空,否则直接写内存。

  • 写文件方式最后底层是发送了一个StreamRequest请求
  • 写内存方式发送了一个ChunkFetchRequest请求

同时,当节点返回成功之后,会通过对应Callback进行处理:

public void stream(String streamId, StreamCallback callback) {
    
    
    StdChannelListener listener = new StdChannelListener(streamId) {
    
    
      void handleFailure(String errorMsg, Throwable cause) throws Exception {
    
    
        callback.onFailure(streamId, new IOException(errorMsg, cause));
      }
    };
    synchronized (this) {
    
    
      handler.addStreamCallback(streamId, callback);
      channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener);
    }
  }
  public void fetchChunk(
      long streamId,
      int chunkIndex,
      ChunkReceivedCallback callback) {
    
    
    StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
    StdChannelListener listener = new StdChannelListener(streamChunkId) {
    
    
      void handleFailure(String errorMsg, Throwable cause) {
    
    
        handler.removeFetchRequest(streamChunkId);
        callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
      }
    };
    handler.addFetchRequest(streamChunkId, callback);
    channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
  }

到这里reduce单已经将请求发送出去,接下来我们看下数据端怎么处理请求


首先是对应OpenBlocks请求,最后在NettyBlockRpcServer进行处理:

override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    
    
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
    
    
      case openBlocks: OpenBlocks =>
        val blocksNum = openBlocks.blockIds.length
        val blocks = for (i <- (0 until blocksNum).view)
          yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
          client.getChannel)
        logTrace(s"Registered streamId $streamId with $blocksNum buffers")
        responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)

      case uploadBlock: UploadBlock =>
        // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
        val (level: StorageLevel, classTag: ClassTag[_]) = {
    
    
          serializer
            .newInstance()
            .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
            .asInstanceOf[(StorageLevel, ClassTag[_])]
        }
        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
        val blockId = BlockId(uploadBlock.blockId)
        logDebug(s"Receiving replicated block $blockId with level ${level} " +
          s"from ${client.getSocketAddress}")
        blockManager.putBlockData(blockId, data, level, classTag)
        responseContext.onSuccess(ByteBuffer.allocate(0))
    }
  }

这里会对每个请求注册一个StreamId和对应的StreamState,返回个拉取端一个StreamHandle信息,包含了StreamId和Block的个数。在开始的时候会把每个要拉取的Block的数据读取出来通过getBlockData实现

  override def getBlockData(blockId: BlockId): ManagedBuffer = {
    
    
    if (blockId.isShuffle) {
    
    
shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
    } else {
    
    
      getLocalBytes(blockId) match {
    
    
        case Some(blockData) =>
          new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true)
        case None =>
          reportBlockStatus(blockId, BlockStatus.empty)
          throw new BlockNotFoundException(blockId.toString)
      }
    }
  }

这里我们是reduce读取,blockId.isShuffle=true

val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
      val buf = new ChunkedByteBuffer( shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
      Some(new ByteBufferBlockData(buf, true))

最后通过IndexShuffleBlockResolver来进行读取,这也就是上一节我们说的,Map端的写入同时会生成一个索引文件,这里会通过所以文件获取对应数据的信息:

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
    
    
    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
    val channel = Files.newByteChannel(indexFile.toPath)
    channel.position(blockId.reduceId * 8L)
    val in = new DataInputStream(Channels.newInputStream(channel))
    try {
    
    
      val offset = in.readLong()
      val nextOffset = in.readLong()
      val actualPosition = channel.position()
      val expectedPosition = blockId.reduceId * 8L + 16
      if (actualPosition != expectedPosition) {
    
    
       ....
       }
      new FileSegmentManagedBuffer(
        transportConf,
        getDataFile(blockId.shuffleId, blockId.mapId),
        offset,
        nextOffset - offset)
    } finally {
    
    
      in.close()
    }
  }

最终返回的是一个FileSegmentManagedBuffer.最后返回一个StreamHandle给到客户端。
这里上来就是根据reduceId去索引文件中获取对应reduce端需要拉取数据的位置和数据的大小,后面读取数据的时候,根据这个位置信息读取数据文件中对应的数据
可以看到发送OpenBlocks只是给数据端生成FileSegmentManagedBuffer,知道需要拉取的是哪些数据,并没有其他操作。


然后真正拉取数据则是发送ChunkFetchRequest请求,我们看下是怎么处理的:
TransportRequestHandler会对这些请求进行处理:

  public void handle(RequestMessage request) {
    
    
    if (request instanceof ChunkFetchRequest) {
    
    
      processFetchRequest((ChunkFetchRequest) request);
    } else if (request instanceof RpcRequest) {
    
    
      processRpcRequest((RpcRequest) request);
    } else if (request instanceof OneWayMessage) {
    
    
      processOneWayMessage((OneWayMessage) request);
    } else if (request instanceof StreamRequest) {
    
    
      processStreamRequest((StreamRequest) request);
    } else if (request instanceof UploadStream) {
    
    
      processStreamUpload((UploadStream) request);
    } else {
    
    
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }

我们先来看下ChunkFetchRequest处理:

private void processFetchRequest(final ChunkFetchRequest req) {
    
    
    long chunksBeingTransferred = streamManager.chunksBeingTransferred();
    if (chunksBeingTransferred >= maxChunksBeingTransferred) {
    
    
      logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
        chunksBeingTransferred, maxChunksBeingTransferred);
      channel.close();
      return;
    }
    ManagedBuffer buf;
    try {
    
    
      streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
      buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
    } catch (Exception e) {
    
    
      respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
      return;
    }
    streamManager.chunkBeingSent(req.streamChunkId.streamId);
    respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> {
    
    
      streamManager.chunkSent(req.streamChunkId.streamId);
    });
  }

这里首先会判断当前steam的数据是否已经拉取完毕,如果拉取完毕直接关闭通道。然后会获取对应chunk块对应的ManagedBuffer,我们上面知道,这里返回的就是一个FileSegmentManagedBuffer,但是我们详细看这个Buffer,并没有任何数据相关,那么数据是怎么读取传输回去的呢 ?关键就在Netty的编解码中。这里数据端开启server是通过NettyBlockTransferService,其创建createServer方法最后生成一个TransportServer,初始化的时候会调用init方法,其初始化Netty的pipline如下:

public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    
    
    try {
    
    
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      channel.pipeline()
        .addLast("encoder", ENCODER)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", DECODER)
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
    
    
   
      throw e;
    }
  }

这里的编码实现为:MessageEncoder,在其encode方法中,会调用FileSegmentManagedBuffer.convertToNetty方法:

public Object convertToNetty() throws IOException {
    
    
    if (conf.lazyFileDescriptor()) {
    
    
      return new DefaultFileRegion(file, offset, length);
    } else {
    
    
      FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ);
      return new DefaultFileRegion(fileChannel, offset, length);
    }
  }

可以看到这里返回的是一个DefaultFileRegion,还是没有将文件转化成流,继续看MessageEncoder.encode:

public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
    
    
    Object body = null;
    long bodyLength = 0;
    boolean isBodyInFrame = false;
    if (in.body() != null) {
    
    
      try {
    
    
        bodyLength = in.body().size();
        body = in.body().convertToNetty();
        isBodyInFrame = in.isBodyInFrame();
      } catch (Exception e) {
    
    
        in.body().release();
        if (in instanceof AbstractResponseMessage) {
    
    
          AbstractResponseMessage resp = (AbstractResponseMessage) in;
          // Re-encode this message as a failure response.
          String error = e.getMessage() != null ? e.getMessage() : "null";
          logger.error(String.format("Error processing %s for client %s",
            in, ctx.channel().remoteAddress()), e);
          encode(ctx, resp.createFailureResponse(error), out);
        } else {
    
    
          throw e;
        }
        return;
      }
    }
    Message.Type msgType = in.type();
    int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
    long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
    header.writeLong(frameLength);
    msgType.encode(header);
    in.encode(header);
    if (body != null) {
    
    
      out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
    } else {
    
    
      out.add(header);
    }
  }

最后返回的out中是一个MessageWithHeader,而MessageWithHeader实现了netty的FileRegion接口,当进行网络传输的时候,会调用FileRegion.transferTo方法,在MessageWithHeader实现如下:

public long transferTo(final WritableByteChannel target, final long position) throws IOException {
    
    
    Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
    long writtenHeader = 0;
    if (header.readableBytes() > 0) {
    
    
      writtenHeader = copyByteBuf(header, target);
      totalBytesTransferred += writtenHeader;
      if (header.readableBytes() > 0) {
    
    
        return writtenHeader;
      }
    }
    long writtenBody = 0;
    if (body instanceof FileRegion) {
    
    
      writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength);
    } else if (body instanceof ByteBuf) {
    
    
      writtenBody = copyByteBuf((ByteBuf) body, target);
    }
    totalBytesTransferred += writtenBody;
    return writtenHeader + writtenBody;
  }

最后还是调用FileRegion.transferTo这里就是我们上面生成的DefaultFileRegion。其实现就是通过零拷贝将文件中内容传输到网络中。到此数据就完成了传输


接下来我们在回到数据拉取端,上面数据端返回了一个ChunkFetchSuccess,然后在拉取端TransportResponseHandler进行处理:

if (message instanceof ChunkFetchSuccess) {
    
    
      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
      if (listener == null) {
    
    
        logger.warn("Ignoring response for block {} from {} since it is not outstanding",
          resp.streamChunkId, getRemoteAddress(channel));
        resp.body().release();
      } else {
    
    
        outstandingFetches.remove(resp.streamChunkId);
        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
        resp.body().release();
      }
    }

这里的lister就是上面我们传入的ChunkCallback,onsuccess方法如下:

  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    
    
      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
    }

这里的litener使我们前面传入,实现如下:

val blockFetchingListener = new BlockFetchingListener {
    
    
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
    
    
        ShuffleBlockFetcherIterator.this.synchronized {
    
    
          if (!isZombie) {
    
    
            buf.retain()
            remainingBlocks -= blockId
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
              remainingBlocks.isEmpty))
          }
        }
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
    
    
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }

可以看到,这里没有做任何特殊处理,只是通过返回的结果,实例化了一个SuccessFetchResult。这里private[this] val results = new LinkedBlockingQueue[FetchResult]只是一个LinkedBlockingQueue缓存。

好,到这里我们拉取数据分析了这么多,但是数据并没有进行实际的拉取,那么在什么时候拉取的呢?
这里的拉取数据实现是一个ShuffleBlockFetcherIterator,在其迭代方法next实现中实现了数据拉取:

override def next(): (BlockId, InputStream) = {
    
    
    if (!hasNext) {
    
    
      throw new NoSuchElementException
    }
    numBlocksProcessed += 1
    var result: FetchResult = null
    var input: InputStream = null
    while (result == null) {
    
    
      val startFetchWait = System.currentTimeMillis()
      result = results.take()
      val stopFetchWait = System.currentTimeMillis()
      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)

      result match {
    
    
        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
          if (address != blockManager.blockManagerId) {
    
    
            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
            shuffleMetrics.incRemoteBytesRead(buf.size)
            if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
    
    
              shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
            }
            shuffleMetrics.incRemoteBlocksFetched(1)
          }
          if (!localBlocks.contains(blockId)) {
    
    
            bytesInFlight -= size
          }
          if (isNetworkReqDone) {
    
    
            reqsInFlight -= 1
            logDebug("Number of requests in flight " + reqsInFlight)
          }

          if (buf.size == 0) 
            throwFetchFailedException(blockId, address, new IOException(msg))
          }

          val in = try {
    
    
            buf.createInputStream()
          } catch {
    
    
            case e: IOException =>
              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
              logError("Failed to create input stream from local block", e)
              buf.release()
              throwFetchFailedException(blockId, address, e)
          }
          var isStreamCopied: Boolean = false
          try {
    
    
            input = streamWrapper(blockId, in)
            if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
    
    
              isStreamCopied = true
              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
              Utils.copyStream(input, out, closeStreams = true)
              input = out.toChunkedByteBuffer.toInputStream(dispose = true)
            }
          } catch {
    
    
            case e: IOException =>
              buf.release()
              if (buf.isInstanceOf[FileSegmentManagedBuffer]
                || corruptedBlocks.contains(blockId)) {
    
    
                throwFetchFailedException(blockId, address, e)
              } else {
    
    
                corruptedBlocks += blockId
                fetchRequests += FetchRequest(address, Array((blockId, size)))
                result = null
              }
          } finally {
    
    
            if (isStreamCopied) {
    
    
              in.close()
            }
          }
        case FailureFetchResult(blockId, address, e) =>
          throwFetchFailedException(blockId, address, e)
      }
      fetchUpToMaxBytes()
    }

    currentResult = result.asInstanceOf[SuccessFetchResult]
    (currentResult.blockId, new BufferReleasingInputStream(input, this))
  }

这里将接收到的数据写入到了ChunkedByteBufferOutputStream中,然后将输出流改变为输入流返回给上游。
这里我们之前分析的上游就是BlockStoreShuffleReader,在其read方法中会迭代调用上述数据,执行聚合算子Aggregator中插入到一个Map中:

def combineValuesByKey(
      iter: Iterator[_ <: Product2[K, V]],
      context: TaskContext): Iterator[(K, C)] = {
    
    
    val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    combiners.iterator
  }

  def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
    
    
    var curEntry: Product2[K, V] = null
    val update: (Boolean, C) => C = (hadVal, oldVal) => {
    
    
      if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
    }

    while (entries.hasNext) {
    
    
      curEntry = entries.next()
      val estimatedSize = currentMap.estimateSize()
      if (estimatedSize > _peakMemoryUsedBytes) {
    
    
        _peakMemoryUsedBytes = estimatedSize
      }
      if (maybeSpill(currentMap, estimatedSize)) {
    
    
        currentMap = new SizeTrackingAppendOnlyMap[K, C]
      }
      currentMap.changeValue(curEntry._1, update)
      addElementsRead()
    }
  }
   override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    
    
    val newValue = super.changeValue(key, updateFunc)
    super.afterUpdate()
    newValue
  }
  def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    
    
    assert(!destroyed, destructionMessage)
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
    
    
      if (!haveNullValue) {
    
    
        incrementSize()
      }
      nullValue = updateFunc(haveNullValue, nullValue)
      haveNullValue = true
      return nullValue
    }
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
    
    
      val curKey = data(2 * pos)
      if (curKey.eq(null)) {
    
    
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        incrementSize()
        return newValue
      } else if (k.eq(curKey) || k.equals(curKey)) {
    
    
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
      } else {
    
    
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }

这里插入到map的时候,会根据传入的算子对数据进行聚合运算。

分析到这里,我们简单总结一下:

  1. 当我们RDD遇到类似reparation这种算子的时候,通过BlockStoreShuffleReader去读取shffle数据
  2. BlockStoreShuffleReader首先回去master获取当前节点需要拉取的数据
  3. 然后通过ShuffleBlockFetcherIterator去进行数据拉取
  4. ShuffleBlockFetcherIterator首先会区分其他节点和本地节点数据,本地节点数据直接读取,其他节点需要通过网络传输
  5. ShuffleBlockFetcherIterator获取其他节点数据发送FetchRequest(通过NettyBlockTransferService来获取数据),在发送FetchRequest之前,首先会发送OpenBlocks请求(通过OneForOneBlockFetcher),返回的响应数据中会给出需要拉取数据的相关信息
  6. 数据端收到OpenBlocks请求后,会根据请求中数据信息获取相关索引文件,获取索引文件中对应的要拉取数据在数据文件中位移,生成FileSegmentManagedBuffer集合,同时封装成一个StreamHandle返回给客户端,StreamHandle相当于是一个包含了此次数据传输会话信息
  7. 拉区端收到返回的信息后开始发送FetchRequest给数据端
  8. 数据端收到FetchRequest之后,根据StreamHandle的信息找到之前OpenBlocks请求生成的FileSegmentManagedBuffer集合,返回给客户端,这里需要注意的是,返回FileSegmentManagedBuffer会通过单独的MessageEncoder来进行处理,最后是转换成了Netty文件传输
  9. 拉取端获取到数据后,根据相应的算子把数据放入到一个Map中,如果超过一定容量也会溢写到磁盘,如果需要排序,最后会将磁盘和缓存中的数据读取出来进行排序
  10. .结束

猜你喜欢

转载自blog.csdn.net/LeoHan163/article/details/120979442