Spark Storage Model

This post briefly introduces how Spark stores data. It is a set of reading notes on the book 《图解Spark核心技术与案例实战》.

Components

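Spark's storage layer is a master/slave model: the Driver is the master and each Executor is a slave. The Driver manages the metadata of the data blocks, while the Executors store the data and execute the data operations sent by the Driver.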

Driver

When an application starts, SparkContext creates a SparkEnv on the Driver. The SparkEnv creates a BlockManager and a BlockManagerMaster, and a BlockManagerMasterEndpoint is created inside the BlockManagerMaster for communication.

Executor

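An Executor also creates a SparkEnv when it starts. The Executor's SparkEnv creates a BlockManager and a BlockTransferService. While the BlockManager initializes, it obtains a reference to the BlockManagerMasterEndpoint, creates a BlockManagerSlaveEndpoint, and registers a reference to that BlockManagerSlaveEndpoint with the Driver, so that the Driver and the Executor can communicate with each other.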

Metadata

BlockManagerMasterEndpoint keeps the metadata of the data blocks in three fields:
1 blockManagerInfo:
private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
A HashMap keyed by BlockManagerId whose values are BlockManagerInfo objects. A BlockManagerInfo records the Executor's memory usage, the status of its data blocks, the blocks already cached, and a reference to the endpoint used to communicate with that Executor.
2 blockManagerIdByExecutor:
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
Also a HashMap; it maps each executor ID to its BlockManagerId.
3 blockLocations:
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
Maps each block ID to the set of BlockManagerIds that hold the block. It is consulted when looking up where a block is stored.
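
To make the role of these maps concrete, here is a minimal sketch of a driver-side registry. BmId and ToyBlockMaster are hypothetical stand-ins invented for this illustration, not Spark classes; only the two field names are borrowed from the snippets above.

```scala
import scala.collection.mutable

// Hypothetical stand-in for Spark's BlockManagerId (illustration only).
final case class BmId(executorId: String, host: String, port: Int)

// Toy model of the driver-side metadata; not Spark's actual implementation.
class ToyBlockMaster {
  private val blockManagerIdByExecutor = mutable.HashMap.empty[String, BmId]
  private val blockLocations = mutable.HashMap.empty[String, mutable.HashSet[BmId]]

  // An executor announces its block manager to the driver.
  def register(id: BmId): Unit = {
    blockManagerIdByExecutor.put(id.executorId, id)
  }

  // An executor reports that it now holds `blockId`; unknown executors are ignored.
  def reportBlock(executorId: String, blockId: String): Unit =
    blockManagerIdByExecutor.get(executorId).foreach { id =>
      blockLocations.getOrElseUpdate(blockId, mutable.HashSet.empty[BmId]) += id
    }

  // Location query: every block manager known to hold `blockId`.
  def getLocations(blockId: String): Set[BmId] =
    blockLocations.get(blockId) match {
      case Some(holders) => holders.toSet
      case None          => Set.empty
    }
}
```

A location query is then a single map lookup, which is why the driver can answer it cheaply.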

Storage Levels

The persist() Method

cache() and persist() explicitly mark an RDD to be kept in memory or on disk. cache() is simply a wrapper around persist() with the default level MEMORY_ONLY.

  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
  def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)

  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
  def cache(): this.type = persist()

persist() is implemented as follows:

  private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
    // TODO: Handle changes of StorageLevel
    // Once an RDD's storage level has been set, it cannot be changed.
    if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
      throw new UnsupportedOperationException(
        "Cannot change storage level of an RDD after it was already assigned a level")
    }
    // If this is the first time this RDD is marked for persisting, register it
    // with the SparkContext for cleanups and accounting. Do this only once.
    if (storageLevel == StorageLevel.NONE) {
      sc.cleaner.foreach(_.registerRDDForCleanup(this))
      sc.persistRDD(this)
    }
    storageLevel = newLevel
    this
  }

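Calling persist() only records the storage level; the data is cached according to that level the first time the RDD is computed. Once a level has been assigned it cannot be changed, unless allowOverride is true.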
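
The "assign once" rule can be illustrated with a small stand-alone model. ToyRdd and the Level objects below are hypothetical names made up for this sketch; the guard mirrors the one in the quoted persist() but leaves out allowOverride and the cleaner registration.

```scala
// Hypothetical miniature of an RDD's persistence flag; not Spark's classes.
sealed trait Level
case object NoneLevel extends Level
case object MemoryOnly extends Level
case object DiskOnly extends Level

class ToyRdd {
  private var level: Level = NoneLevel
  def storageLevel: Level = level

  // A level may be assigned once; re-assigning the same level is a no-op,
  // and assigning a different one is rejected.
  def persist(newLevel: Level): this.type = {
    if (level != NoneLevel && newLevel != level) {
      throw new UnsupportedOperationException(
        "Cannot change storage level of an RDD after it was already assigned a level")
    }
    level = newLevel
    this
  }
}
```

Calling persist(MemoryOnly) twice is accepted, while a later persist(DiskOnly) on the same object throws.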
All storage levels are listed in the figure below:
[Figure: RDD storage levels]
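
Since the figure is not reproduced here, the sketch below models a storage level as a handful of flags. LevelFlags and Levels are hypothetical names for this illustration, and the flag values are given as I recall them from Spark 2.x's StorageLevel, so treat them as an assumption and check them against the Spark version you use.

```scala
// Simplified model of a storage level: which stores may be used, whether the data
// is kept as deserialized objects, and how many replicas exist. Illustrative only.
final case class LevelFlags(
    useDisk: Boolean,
    useMemory: Boolean,
    deserialized: Boolean,
    replication: Int = 1)

object Levels {
  // Values assumed from Spark 2.x; verify against your version.
  val MEMORY_ONLY         = LevelFlags(useDisk = false, useMemory = true, deserialized = true)
  val MEMORY_ONLY_SER     = LevelFlags(useDisk = false, useMemory = true, deserialized = false)
  val MEMORY_AND_DISK     = LevelFlags(useDisk = true, useMemory = true, deserialized = true)
  val MEMORY_AND_DISK_SER = LevelFlags(useDisk = true, useMemory = true, deserialized = false)
  val DISK_ONLY           = LevelFlags(useDisk = true, useMemory = false, deserialized = false)
  // The "_2" variants only differ in the replica count.
  val MEMORY_ONLY_2       = MEMORY_ONLY.copy(replication = 2)

  // Human-readable summary of a level's flags.
  def describe(l: LevelFlags): String = {
    val where = (l.useMemory, l.useDisk) match {
      case (true, true)   => "memory, spilling to disk"
      case (true, false)  => "memory only"
      case (false, true)  => "disk only"
      case (false, false) => "not stored"
    }
    val form = if (l.deserialized) "objects" else "serialized bytes"
    s"$where, $form, x${l.replication}"
  }
}
```

The same flags (useMemory, useDisk, deserialized, replication) are the ones the read and write paths branch on later in this post.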

Where Storage Is Invoked

Data is read and written only while a job is running. The real entry point is getOrCompute(), which is called from the iterator() method that runs during the job:

  /**
   * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
   */
  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
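    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
      // The block is not cached: read it from a checkpoint if there is one, otherwise compute it.
      readCachedBlock = false
      computeOrReadCheckpoint(partition, context)
    }) match {
      // The block was read from the cache, or was computed and cached successfully.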
      case Left(blockResult) =>
        if (readCachedBlock) {
          val existingMetrics = context.taskMetrics().inputMetrics
          existingMetrics.incBytesRead(blockResult.bytes)
          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
            override def next(): T = {
              existingMetrics.incRecordsRead(1)
              delegate.next()
            }
          }
        } else {
          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
        }
      case Right(iter) =>
        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
    }
  }

The read and write logic lives in getOrElseUpdate():

  /**
   * Retrieve the given block if it exists, otherwise call the provided `makeIterator` method
   * to compute the block, persist it, and return its values.
   *
   * @return either a BlockResult if the block was successfully cached, or an iterator if the block
   *         could not be cached.
   */
  def getOrElseUpdate[T](
      blockId: BlockId,
      level: StorageLevel,
      classTag: ClassTag[T],
      makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
    // Attempt to read the block from local or remote storage. If it's present, then we don't need
    // to go through the local-get-or-put path.
    // Read path: look for the block locally, then remotely.
    get(blockId) match {
      case Some(block) =>
        return Left(block)
      case _ =>
        // Need to compute the block.
    }
    // Initially we hold no locks on this block.
    // Write path: compute the block and try to store it.
    doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
      case None =>
        // doPut() didn't hand work back to us, so the block already existed or was successfully
        // stored. Therefore, we now hold a read lock on the block.
        val blockResult = getLocalValues(blockId).getOrElse {
          // Since we held a read lock between the doPut() and get() calls, the block should not
          // have been evicted, so get() not returning the block indicates some internal error.
          releaseLock(blockId)
          throw new SparkException(s"get() failed for block $blockId even though we held a lock")
        }
        // We already hold a read lock on the block from the doPut() call and getLocalValues()
        // acquires the lock again, so we need to call releaseLock() here so that the net number
        // of lock acquisitions is 1 (since the caller will only call release() once).
        releaseLock(blockId)
        Left(blockResult)
      case Some(iter) =>
        // The put failed, likely because the data was too large to fit in memory and could not be
        // dropped to disk. Therefore, we need to pass the input iterator back to the caller so
        // that they can decide what to do with the values (e.g. process them without caching).
        Right(iter)
    }
  }
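
The Either-based contract of getOrElseUpdate() can be sketched in isolation. ToyBlockCache is a hypothetical class written for this post; unlike Spark it materializes the whole block before deciding and uses a simple element budget instead of real memory accounting.

```scala
import scala.collection.mutable

// Toy cache with a fixed element budget; only illustrates the Left/Right contract.
class ToyBlockCache(maxElements: Int) {
  private val store = mutable.HashMap.empty[String, Vector[Int]]

  // Left(values): the block was found, or has just been computed and cached.
  // Right(iterator): the block could not be cached; the caller consumes the values directly.
  def getOrElseUpdate(
      blockId: String,
      makeIterator: () => Iterator[Int]): Either[Vector[Int], Iterator[Int]] =
    store.get(blockId) match {
      case Some(values) => Left(values)
      case None =>
        val computed = makeIterator().toVector
        if (computed.size <= maxElements) {
          store.put(blockId, computed)
          Left(computed)
        } else {
          Right(computed.iterator)
        }
    }
}
```

The point of returning Right is that a failed cache attempt never loses the computed values: the task can still process them, it just will not find them cached next time.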

Reading Data

The call chain is as follows:
[Figure: call graph for reading RDD data]

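BlockManager.get() is the entry point for reads. It first checks whether the block is available locally and reads it directly if so; otherwise it fetches the block from a remote node through BlockTransferService: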

  /**
   * Get a block from the block manager (either local or remote).
   *
   * This acquires a read lock on the block if the block was stored locally and does not acquire
   * any locks if the block was fetched from a remote block manager. The read lock will
   * automatically be freed once the result's `data` iterator is fully consumed.
   */
  def get(blockId: BlockId): Option[BlockResult] = {
    val local = getLocalValues(blockId)
    if (local.isDefined) {
      logInfo(s"Found block $blockId locally")
      return local
    }
    val remote = getRemoteValues(blockId)
    if (remote.isDefined) {
      logInfo(s"Found block $blockId remotely")
      return remote
    }
    None
  }

The actual reads happen in getLocalValues() and getRemoteValues(). Their implementations are examined in turn below:

  def getLocalValues(blockId: BlockId): Option[BlockResult] = {
    logDebug(s"Getting local block $blockId")
    // Acquire a read lock on the block.
    blockInfoManager.lockForReading(blockId) match {
      case None =>
        logDebug(s"Block $blockId was not found")
        None
      case Some(info) =>
        val level = info.level
        logDebug(s"Level for block $blockId is $level")
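        // Read from memory.
        if (level.useMemory && memoryStore.contains(blockId)) {
          // Deserialized level: the block is stored as objects, so use getValues().
          val iter: Iterator[Any] = if (level.deserialized) {
            memoryStore.getValues(blockId).get
          } else {
            // Serialized level: the block is stored as bytes, so use getBytes() and deserialize.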
            serializerManager.dataDeserializeStream(
              blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
          }
          val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
          // Return the result; the read lock is released once the iterator is fully consumed.
          Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
        } else if (level.useDisk && diskStore.contains(blockId)) {
          // The block is on disk: read it from there.
          val iterToReturn: Iterator[Any] = {
            // Read the raw bytes first.
            val diskBytes = diskStore.getBytes(blockId)
            if (level.deserialized) {
              val diskValues = serializerManager.dataDeserializeStream(
                blockId,
                diskBytes.toInputStream(dispose = true))(info.classTag)
              // Deserialize first, then try to cache the values in memory.
              maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
            } else {
              // Try to cache the serialized bytes in memory first.
              val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
                .map {_.toInputStream(dispose = false)}
                .getOrElse { diskBytes.toInputStream(dispose = true) }
              // Then deserialize the bytes into the values to return.
              serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
            }
          }
          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
          Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
        } else {
          handleLocalReadFailure(blockId)
        }
    }
  }

Reading from Memory

Start with the memory path by looking at memoryStore's getValues() and getBytes() methods:

  def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
    val entry = entries.synchronized { entries.get(blockId) }
    entry match {
      case null => None
      case e: DeserializedMemoryEntry[_] =>
        throw new IllegalArgumentException("should only call getBytes on serialized blocks")
      case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
    }
  }

  def getValues(blockId: BlockId): Option[Iterator[_]] = {
    val entry = entries.synchronized { entries.get(blockId) }
    entry match {
      case null => None
      case e: SerializedMemoryEntry[_] =>
        throw new IllegalArgumentException("should only call getValues on deserialized blocks")
      case DeserializedMemoryEntry(values, _, _) =>
        val x = Some(values)
        x.map(_.iterator)
    }
  }

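Both methods look the block up by blockId in entries, a LinkedHashMap, so at the lowest level Spark keeps cached blocks in a LinkedHashMap. A LinkedHashMap has a well-defined iteration order, so when memory runs low the blocks at the head of the map are evicted first (and written to disk if their storage level allows it). Note that in the Spark 2.x source the map is created with accessOrder = true, so the head is the least recently accessed block and eviction is LRU rather than strict FIFO by insertion time.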
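
How a LinkedHashMap orders its entries depends on the accessOrder flag passed to its constructor. The sketch below uses only java.util.LinkedHashMap; AccessOrderDemo is a hypothetical helper for this post, and it is my understanding (an assumption worth checking against your Spark version) that MemoryStore constructs its map with accessOrder = true.

```scala
import java.util.{LinkedHashMap => JLinkedHashMap}

object AccessOrderDemo {
  // Inserts a, b, c, then reads "a", and returns the keys in iteration order.
  def keysAfterTouch(accessOrder: Boolean): List[String] = {
    val m = new JLinkedHashMap[String, String](32, 0.75f, accessOrder)
    m.put("a", "1")
    m.put("b", "2")
    m.put("c", "3")
    m.get("a") // with accessOrder = true, a read moves the entry to the tail
    val it = m.keySet().iterator()
    var out = List.empty[String]
    while (it.hasNext) out = out :+ it.next()
    out
  }
}
```

With accessOrder = false the order stays a, b, c (insertion order, so evicting from the head is FIFO); with accessOrder = true it becomes b, c, a, so evicting from the head removes the least recently used entry.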

Reading from Disk

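Spark uses spark.local.dir to set the directories where files are stored. By default there is a single top-level directory, and at most 64 subdirectories are created under it. The subdirectories are named with two hex digits, 00 through 3f (see the "%02x" format in getFile() further down), and each file in them is named after blockId.name, which uniquely identifies a block.
[Figure: layout of Spark's local disk directories]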

  def getBytes(blockId: BlockId): ChunkedByteBuffer = {
    val file = diskManager.getFile(blockId.name)
    // Use an NIO channel to access the file
    val channel = new RandomAccessFile(file, "r").getChannel
    Utils.tryWithSafeFinally {
      // For small files, directly read into a buffer rather than memory map
      if (file.length < minMemoryMapBytes) {
        val buf = ByteBuffer.allocate(file.length.toInt)
        channel.position(0)
        while (buf.remaining() != 0) {
          if (channel.read(buf) == -1) {
            throw new IOException("Reached EOF before filling buffer\n" +
              s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
          }
        }
        buf.flip()
        new ChunkedByteBuffer(buf)
      } else {
        // Large files are memory-mapped through the channel
        new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length))
      }
    } {
      channel.close()
    }
  }

Before reading, getBytes() calls getFile() to locate the file that holds the data:

  /** Looks up a file by hashing it into one of our local subdirectories. */
  // This method should be kept in sync with
  // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getFile().
  def getFile(filename: String): File = {
    // Figure out which local directory it hashes to, and which subdirectory in that
    // The filename passed in is the block's name (blockId.name)
    val hash = Utils.nonNegativeHash(filename)
    val dirId = hash % localDirs.length
    // Work out which subdirectory holds this block's data
    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir

    // Create the subdirectory if it doesn't already exist
    val subDir = subDirs(dirId).synchronized {
      val old = subDirs(dirId)(subDirId)
      if (old != null) {
        old
      } else {
        val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
        if (!newDir.exists() && !newDir.mkdir()) {
          throw new IOException(s"Failed to create local dir in $newDir.")
        }
        subDirs(dirId)(subDirId) = newDir
        newDir
      }
    }
    new File(subDir, filename)
  }
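
The directory arithmetic in getFile() is easy to try out on its own. In this sketch DiskLayout is a hypothetical helper; nonNegativeHash is written to behave the way I assume Spark's Utils.nonNegativeHash does (absolute value of hashCode, with Int.MinValue mapped to 0).

```scala
object DiskLayout {
  // Assumed equivalent of Utils.nonNegativeHash for non-null strings.
  def nonNegativeHash(s: String): Int = {
    val h = s.hashCode
    if (h == Int.MinValue) 0 else math.abs(h)
  }

  // Returns (index of the local dir, two-digit hex name of the subdirectory).
  def locate(filename: String, numLocalDirs: Int, subDirsPerLocalDir: Int = 64): (Int, String) = {
    val hash = nonNegativeHash(filename)
    val dirId = hash % numLocalDirs
    val subDirId = (hash / numLocalDirs) % subDirsPerLocalDir
    (dirId, "%02x".format(subDirId))
  }
}
```

For example, "a".hashCode is 97, so with two local directories the file lands in directory 1 (97 % 2) and subdirectory 48 ((97 / 2) % 64), whose name is "30" in hex.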

Reading from a Remote Node

getRemoteValues() first calls getRemoteBytes() to fetch the data:

  /**
   * Get block from remote block managers.
   *
   * This does not acquire a lock on this block in this JVM.
   */
  private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
    getRemoteBytes(blockId).map { data =>
      val values =
        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
      new BlockResult(values, DataReadMethod.Network, data.size)
    }
  }

getRemoteBytes() in turn starts by calling getLocations() to find out where the block is stored:

 /**
   * Return a list of locations for the given block, prioritizing the local machine since
   * multiple block managers can share the same host.
   */
  private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
    // Call the master's getLocations()
    val locs = Random.shuffle(master.getLocations(blockId))
    val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host }
    preferredLocs ++ otherLocs
  }
  
  /** Get locations of the blockId from the driver */
  def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
    // Send a message to the driver asking where the block is stored
    driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
  }

  // Invoked when BlockManagerMasterEndpoint receives a GetLocations message: it looks the
  // block up in the blockLocations map and returns every BlockManagerId that holds it.
  private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
    if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
  }
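
The "shuffle, then prefer the local host" ordering used by the executor-side getLocations() can be sketched as follows. Loc and LocationOrder are hypothetical names used only for this illustration.

```scala
import scala.util.Random

// Hypothetical stand-in for a block location (illustration only).
final case class Loc(executorId: String, host: String)

object LocationOrder {
  // Shuffle to spread load across replicas, then move same-host locations to the front.
  def prioritize(localHost: String, locs: Seq[Loc], rnd: Random = new Random()): Seq[Loc] = {
    val (preferred, others) = rnd.shuffle(locs).partition(_.host == localHost)
    preferred ++ others
  }
}
```

Locations on the caller's own host always come first, while the order within each group changes from call to call, so repeated fetches do not always hit the same replica.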

Next, getRemoteBytes():

  /**
   * Get block from remote block managers as serialized bytes.
   */
  def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
    logDebug(s"Getting remote block $blockId")
    require(blockId != null, "BlockId is null")
    var runningFailureCount = 0
    var totalFailureCount = 0
    // First look up the block's locations: every BlockManagerId that holds this block
    val locations = getLocations(blockId)
    val maxFetchFailures = locations.size
    var locationIterator = locations.iterator
    while (locationIterator.hasNext) {
      val loc = locationIterator.next()
      logDebug(s"Getting remote block $blockId from $loc")
      val data = try {
        // Try to fetch the data from this location
        blockTransferService.fetchBlockSync(
          loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
      } catch {
        case NonFatal(e) =>
          runningFailureCount += 1
          totalFailureCount += 1

          if (totalFailureCount >= maxFetchFailures) {
            // Give up trying anymore locations. Either we've tried all of the original locations,
            // or we've refreshed the list of locations from the master, and have still
            // hit failures after trying locations from the refreshed list.
            throw new BlockFetchException(s"Failed to fetch block after" +
              s" ${totalFailureCount} fetch failures. Most recent failure cause:", e)
          }

          logWarning(s"Failed to fetch remote block $blockId " +
            s"from $loc (failed attempt $runningFailureCount)", e)

          // If there is a large number of executors then locations list can contain a
          // large number of stale entries causing a large number of retries that may
          // take a significant amount of time. To get rid of these stale entries
          // we refresh the block locations after a certain number of fetch failures
          if (runningFailureCount >= maxFailuresBeforeLocationRefresh) {
            locationIterator = getLocations(blockId).iterator
            logDebug(s"Refreshed locations from the driver " +
              s"after ${runningFailureCount} fetch failures.")
            runningFailureCount = 0
          }

          // This location failed, so we retry fetch from a different one by returning null here
          null
      }

      if (data != null) {
        return Some(new ChunkedByteBuffer(data))
      }
      logDebug(s"The value of block $blockId is null")
    }
    logDebug(s"Block $blockId not found")
    None
  }

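fetchBlockSync() fetches the data from the location described by a BlockManagerId. Note that, despite its name, BlockManagerId is not a plain ID field but a class with fields such as host, port and executorId, so the name is easy to misread.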

  /**
   * A special case of [[fetchBlocks]], as it fetches only one block and is blocking.
   *
   * It is also only available after [[init]] is invoked.
   */
  def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
    // A monitor for the thread to wait on.
    val result = Promise[ManagedBuffer]()
    fetchBlocks(host, port, execId, Array(blockId),
      new BlockFetchingListener {
        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
          result.failure(exception)
        }
        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
          val ret = ByteBuffer.allocate(data.size.toInt)
          ret.put(data.nioByteBuffer())
          ret.flip()
          result.success(new NioManagedBuffer(ret))
        }
      })
    ThreadUtils.awaitResult(result.future, Duration.Inf)
  }
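
The technique in fetchBlockSync(), turning a callback-based API into a blocking call by completing a Promise from the listener, can be shown without Spark. Everything in SyncOverAsync (the Listener trait, fetchAsync, the five-second timeout) is hypothetical and exists only for this sketch.

```scala
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

object SyncOverAsync {
  // Hypothetical callback interface standing in for BlockFetchingListener.
  trait Listener {
    def onSuccess(data: Array[Byte]): Unit
    def onFailure(e: Throwable): Unit
  }

  // Hypothetical asynchronous API: delivers its result on another thread.
  def fetchAsync(blockId: String, listener: Listener): Unit = {
    val t = new Thread(new Runnable {
      def run(): Unit =
        if (blockId.nonEmpty) listener.onSuccess(blockId.getBytes("UTF-8"))
        else listener.onFailure(new IllegalArgumentException("empty block id"))
    })
    t.setDaemon(true)
    t.start()
  }

  // Blocks the caller until the callback fires, by completing a Promise from the listener.
  def fetchSync(blockId: String): Array[Byte] = {
    val result = Promise[Array[Byte]]()
    fetchAsync(blockId, new Listener {
      def onSuccess(data: Array[Byte]): Unit = { result.success(data) }
      def onFailure(e: Throwable): Unit = { result.failure(e) }
    })
    // A finite timeout is used here; the quoted Spark code waits with Duration.Inf.
    Await.result(result.future, 5.seconds)
  }
}
```

A failure reported through onFailure resurfaces in the calling thread as the exception thrown by Await.result, which is what lets getRemoteBytes() catch it and move on to the next location.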

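The data is downloaded asynchronously and the listener is called back when it arrives. The real work happens in fetchBlocks(), which is abstract; the implementation that actually runs is the one in NettyBlockTransferService: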

  override def fetchBlocks(
      host: String,
      port: Int,
      execId: String,
      blockIds: Array[String],
      listener: BlockFetchingListener): Unit = {
    logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
    try {
      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
          // Create the transport client
          val client = clientFactory.createClient(host, port)
          // Fetch the blocks one-for-one
          new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
        }
      }

      val maxRetries = transportConf.maxIORetries()
      if (maxRetries > 0) {
        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
        // a bug in this code. We should remove the if statement once we're sure of the stability.
        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
      } else {
        blockFetchStarter.createAndStart(blockIds, listener)
      }
    } catch {
      case e: Exception =>
        logError("Exception while beginning fetchBlocks", e)
        blockIds.foreach(listener.onBlockFetchFailure(_, e))
    }
  }

The start() call inside fetchBlocks() (OneForOneBlockFetcher.start()) sends an RPC message to the Executor that holds the block:

  public void start() {
    if (blockIds.length == 0) {
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }
    client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
       …………
      }
      @Override
      public void onFailure(Throwable e) {
        …………
      }
    });
  }

On the target Executor the message is received by NettyBlockRpcServer.receive(), which calls getBlockData() to read the data:

override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
      case openBlocks: OpenBlocks =>
        val blocks: Seq[ManagedBuffer] =
          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
        logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

      case uploadBlock: UploadBlock =>
        ……
    }
  }

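getBlockData() ultimately checks whether the request is for shuffle output. If it is not, the read follows the same path as the local read in getLocalValues(); if it is, the data is read in a way that depends on the shuffle type, hash-based or sort-based.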

Writing Data

The call graph for writes:
[Figure: call graph for writing data]

doPutIterator(), called from getOrElseUpdate(), is the entry point for writes:

 /**
   * Put the given block according to the given level in one of the block stores, replicating
   * the values if necessary.
   *
   * If the block already exists, this method will not overwrite it.
   *
   * @param keepReadLock if true, this method will hold the read lock when it returns (even if the
   *                     block already exists). If false, this method will hold no locks when it
   *                     returns.
   * @return None if the block was already present or if the put succeeded, or Some(iterator)
   *         if the put failed.
   */
  private def doPutIterator[T](
      blockId: BlockId,
      iterator: () => Iterator[T],
      level: StorageLevel,
      classTag: ClassTag[T],
      tellMaster: Boolean = true,
      keepReadLock: Boolean = false): Option[PartiallyUnrolledIterator[T]] = {
    doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info =>
      val startTimeMs = System.currentTimeMillis
      var iteratorFromFailedMemoryStorePut: Option[PartiallyUnrolledIterator[T]] = None
      // Size of the block in bytes
      var size = 0L
      if (level.useMemory) {
        // Put it in memory first, even if it also has useDisk set to true;
        // We will drop it to disk later if the memory store can't hold it.
        if (level.deserialized) {
          memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match {
            case Right(s) =>
              size = s
            case Left(iter) =>
              // Not enough space to unroll this block; drop to disk if applicable
              if (level.useDisk) {
                logWarning(s"Persisting block $blockId to disk instead.")
                diskStore.put(blockId) { fileOutputStream =>
                  serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
                }
                size = diskStore.getSize(blockId)
              } else {
                iteratorFromFailedMemoryStorePut = Some(iter)
              }
          }
        } else { // !level.deserialized
          memoryStore.putIteratorAsBytes(blockId, iterator(), classTag, level.memoryMode) match {
            case Right(s) =>
              size = s
            case Left(partiallySerializedValues) =>
              // Not enough space to unroll this block; drop to disk if applicable
              if (level.useDisk) {
                logWarning(s"Persisting block $blockId to disk instead.")
                diskStore.put(blockId) { fileOutputStream =>
                  partiallySerializedValues.finishWritingToStream(fileOutputStream)
                }
                size = diskStore.getSize(blockId)
              } else {
                iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
              }
          }
        }

      } else if (level.useDisk) {
        diskStore.put(blockId) { fileOutputStream =>
          serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
        }
        size = diskStore.getSize(blockId)
      }

      val putBlockStatus = getCurrentBlockStatus(blockId, info)
      val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid
      if (blockWasSuccessfullyStored) {
        // Now that the block is in either the memory, externalBlockStore, or disk store,
        // tell the master about it.
        info.size = size
        if (tellMaster) {
          reportBlockStatus(blockId, info, putBlockStatus)
        }
        Option(TaskContext.get()).foreach { c =>
          c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
        }
        logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
        if (level.replication > 1) {
          val remoteStartTime = System.currentTimeMillis
          val bytesToReplicate = doGetLocalBytes(blockId, info)
          try {
            replicate(blockId, bytesToReplicate, level, classTag)
          } finally {
            bytesToReplicate.dispose()
          }
          logDebug("Put block %s remotely took %s"
            .format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
        }
      }
      assert(blockWasSuccessfullyStored == iteratorFromFailedMemoryStorePut.isEmpty)
      iteratorFromFailedMemoryStorePut
    }
  }

The logic of doPutIterator() can be summarized as:

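if the level uses memory
    if the level is deserialized (store objects)
        memoryStore.putIteratorAsValues
        if there is not enough space to unroll the block
            if the level also uses disk
                diskStore.put(blockId)
            else
                hand the partially unrolled iterator back to the caller
    else (store serialized bytes)
        memoryStore.putIteratorAsBytes
        if there is not enough space to unroll the block
            if the level also uses disk
                diskStore.put(blockId)
            else
                hand the iterator back to the caller
else if the level uses disk
    diskStore.put(blockId)
if the block was stored
    update the block status
    if replication > 1
        replicate(blockId, bytesToReplicate, level, classTag)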

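As this shows, doPutIterator() chooses how to write the data based on the storage level, then updates the block's status, and finally replicates the block if required.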
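
The branching can also be condensed into a small pure function. PutDecision and its Outcome values are hypothetical names for this sketch; fitsInMemory stands in for "the unroll succeeded", and the deserialized/serialized distinction is left out because both branches make the same storage decision.

```scala
object PutDecision {
  sealed trait Outcome
  case object StoredInMemory extends Outcome
  case object StoredOnDisk extends Outcome
  case object HandedBackToCaller extends Outcome // caller receives the iterator, nothing is cached
  case object NotStored extends Outcome          // the level uses neither memory nor disk

  // Where a block ends up, given the level's flags and whether it fits in memory.
  def decide(useMemory: Boolean, useDisk: Boolean, fitsInMemory: Boolean): Outcome =
    if (useMemory) {
      if (fitsInMemory) StoredInMemory
      else if (useDisk) StoredOnDisk
      else HandedBackToCaller
    } else if (useDisk) {
      StoredOnDisk
    } else {
      NotStored
    }
}
```

Memory is always tried first when the level allows it, and disk is only a fallback; a memory-only level never touches the disk, which is why the iterator has to be returned to the caller in that case.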

Writing to Memory

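Unrolling was mentioned above; what does it mean? When a block is written to memory it is not materialized in one go. Instead its elements are added gradually, and the code periodically checks whether enough memory remains to hold the block. If not, it tries to evict blocks already in memory (writing them to disk when their level allows) to free space for the new data. Here is putIteratorAsValues() in detail: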

  /**
   * Attempt to put the given block in memory store as values.
   *
   * It's possible that the iterator is too large to materialize and store in memory. To avoid
   * OOM exceptions, this method will gradually unroll the iterator while periodically checking
   * whether there is enough free memory. If the block is successfully materialized, then the
   * temporary unroll memory used during the materialization is "transferred" to storage memory,
   * so we won't acquire more memory than is actually needed to store the block.
   *
   * @return in case of success, the estimated size of the stored data. In case of
   *         failure, return an iterator containing the values of the block. The returned iterator
   *         will be backed by the combination of the partially-unrolled block and the remaining
   *         elements of the original input iterator. The caller must either fully consume this
   *         iterator or call `close()` on it in order to free the storage memory consumed by the
   *         partially-unrolled block.
   */
  private[storage] def putIteratorAsValues[T](
      blockId: BlockId,
      values: Iterator[T],
      classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {

    require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")

    // Number of elements unrolled so far
    var elementsUnrolled = 0
    // Whether there is still enough memory for us to continue unrolling this block
    var keepUnrolling = true
    // Initial per-task memory to request for unrolling blocks (bytes).
    val initialMemoryThreshold = unrollMemoryThreshold
    // How often to check whether we need to request more memory
    // (once every memoryCheckPeriod elements)
    val memoryCheckPeriod = 16
    // Memory currently reserved by this task for this particular unrolling operation
    var memoryThreshold = initialMemoryThreshold
    // Memory to request as a multiple of current vector size
    // Each request asks for (currentSize * memoryGrowthFactor) - memoryThreshold bytes.
    val memoryGrowthFactor = 1.5
    // Keep track of unroll memory used by this particular block / putIterator() operation
    var unrollMemoryUsedByThisBlock = 0L
    // Underlying vector for unrolling the block
    // It also tracks the estimated size of the elements unrolled so far.
    var vector = new SizeTrackingVector[T]()(classTag)

    // Request enough memory to begin unrolling
    keepUnrolling =
      reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP)

    if (!keepUnrolling) {
      logWarning(s"Failed to reserve initial memory threshold of " +
        s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
    } else {
      unrollMemoryUsedByThisBlock += initialMemoryThreshold
    }

    // Unroll this block safely, checking whether we have exceeded our threshold periodically
    while (values.hasNext && keepUnrolling) {
      vector += values.next()
      // Only check the memory budget once every memoryCheckPeriod (16) elements.
      if (elementsUnrolled % memoryCheckPeriod == 0) {
        // If our vector's size has exceeded the threshold, request more memory
        val currentSize = vector.estimateSize()
        if (currentSize >= memoryThreshold) {
          val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
          // Try to reserve the additional memory.
          keepUnrolling =
            reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP)
          if (keepUnrolling) {
            unrollMemoryUsedByThisBlock += amountToRequest
          }
          // New threshold is currentSize * memoryGrowthFactor
          memoryThreshold += amountToRequest
        }
      }
      elementsUnrolled += 1
    }

    // The block was unrolled completely: estimate how much memory it occupies.
    if (keepUnrolling) {
      // We successfully unrolled the entirety of this block
      val arrayValues = vector.toArray
      vector = null
      val entry =
        new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
      val size = entry.size
      def transferUnrollToStorage(amount: Long): Unit = {
        // Synchronize so that transfer is atomic
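        // Turn the unroll memory into storage memory: release the unroll reservation,
        // then acquire the same amount as storage memory for the block.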
        memoryManager.synchronized {
          releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount)
          val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP)
          assert(success, "transferring unroll memory to storage memory failed")
        }
      }
      // Acquire storage memory if necessary to store this block in memory.
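      // Check whether there is enough memory to hold this block.
      val enoughStorageMemory = {
        // The unroll memory reserved so far does not exceed the block's size.
        if (unrollMemoryUsedByThisBlock <= size) {
          // Acquire the missing amount as extra storage memory.
          val acquiredExtra =
            memoryManager.acquireStorageMemory(
              blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP)
          // On success, transfer the unroll memory to storage memory for this block.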
          if (acquiredExtra) {
            transferUnrollToStorage(unrollMemoryUsedByThisBlock)
          }
          acquiredExtra
        } else { // unrollMemoryUsedByThisBlock > size
          // If this task attempt already owns more unroll memory than is necessary to store the
          // block, then release the extra memory that will not be used.
          val excessUnrollMemory = unrollMemoryUsedByThisBlock - size
          releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory)
          transferUnrollToStorage(size)
          true
        }
      }
      // If there is enough memory, put the entry into the entries map.
      if (enoughStorageMemory) {
        entries.synchronized {
          entries.put(blockId, entry)
        }
        logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
          blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
        Right(size)
      } else {
        // NOTE: as quoted, this assertion compares the value with itself, so it always holds.
        assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask,
          "released too much unroll memory")
        Left(new PartiallyUnrolledIterator(
          this,
          unrollMemoryUsedByThisBlock,
          unrolled = arrayValues.toIterator,
          rest = Iterator.empty))
      }
    } else {
      // We ran out of space while unrolling the values for this block
      logUnrollFailureMessage(blockId, vector.estimateSize())
      Left(new PartiallyUnrolledIterator(
        this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
    }
  }
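
The reservation arithmetic of the unroll loop can be followed with a small simulation. UnrollSim is a hypothetical model written for this post: every element is assumed to have the same size, and a single counter of free bytes replaces Spark's memory manager.

```scala
object UnrollSim {
  // Simulates the reservation loop. Returns Some(total bytes reserved) if every element
  // was unrolled, or None if a reservation request was refused along the way.
  def unroll(
      numElements: Int,
      bytesPerElement: Long,
      initialThreshold: Long,
      availableMemory: Long,
      checkPeriod: Int = 16,
      growthFactor: Double = 1.5): Option[Long] = {
    var free = availableMemory
    def reserve(amount: Long): Boolean =
      if (amount <= free) { free -= amount; true } else false

    var keepUnrolling = reserve(initialThreshold)
    var threshold = initialThreshold
    var reserved = if (keepUnrolling) initialThreshold else 0L
    var unrolled = 0
    while (unrolled < numElements && keepUnrolling) {
      // Size of the buffer after adding the current element.
      val currentSize = (unrolled + 1) * bytesPerElement
      // Check only every `checkPeriod` elements, and only grow when the threshold is reached.
      if (unrolled % checkPeriod == 0 && currentSize >= threshold) {
        val request = (currentSize * growthFactor - threshold).toLong
        keepUnrolling = reserve(request)
        if (keepUnrolling) reserved += request
        threshold += request
      }
      unrolled += 1
    }
    if (keepUnrolling) Some(reserved) else None
  }
}
```

With 100 elements of 10 bytes, an initial threshold of 100 bytes and plenty of free memory, the reservation grows in three steps (to 255, 495 and 975 bytes), each time to 1.5 times the size measured at that check. This over-reservation is why the quoted code later releases the excess before transferring unroll memory to storage memory.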

Writing to Disk

Blocks are written to disk with diskStore.put():

def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = {
    if (contains(blockId)) {
      throw new IllegalStateException(s"Block $blockId is already present in the disk store")
    }
    logDebug(s"Attempting to put block $blockId")
    val startTime = System.currentTimeMillis
    val file = diskManager.getFile(blockId)
    val fileOutputStream = new FileOutputStream(file)
    var threwException: Boolean = true
    try {
      writeFunc(fileOutputStream)
      threwException = false
    } finally {
      try {
        Closeables.close(fileOutputStream, threwException)
      } finally {
        if (threwException) {
          remove(blockId)
        }
      }
    }
    val finishTime = System.currentTimeMillis
    logDebug("Block %s stored as %s file on disk in %d ms".format(
      file.getName,
      Utils.bytesToString(file.length()),
      finishTime - startTime))
  }
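
The pattern used by put(), handing the stream to a caller-supplied function and deleting the file if that function throws, can be tried with plain java.io. SafeWrite is a hypothetical helper for this sketch and leaves out the logging and the disk block manager.

```scala
import java.io.{File, FileOutputStream}

object SafeWrite {
  // Writes through `writeFunc`; if it throws, the partially written file is deleted
  // and the exception propagates to the caller.
  def put(file: File)(writeFunc: FileOutputStream => Unit): Unit = {
    if (file.exists()) {
      throw new IllegalStateException(s"$file is already present")
    }
    val out = new FileOutputStream(file)
    var succeeded = false
    try {
      writeFunc(out)
      succeeded = true
    } finally {
      try {
        out.close()
      } finally {
        if (!succeeded) {
          file.delete()
        }
      }
    }
  }
}
```

Because the write logic is passed in as a function, the same put() can serve both the plain serialization path and the finishWritingToStream path without knowing which one it is running.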

The flow mirrors the read path, except that put() calls back into the writeFunc argument to write the data into the file. Here is one such writeFunc:

  /**
   * Finish writing this block to the given output stream by first writing the serialized values
   * and then serializing the values from the original input iterator.
   */
  def finishWritingToStream(os: OutputStream): Unit = {
    // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
    ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
    memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
    redirectableOutputStream.setOutputStream(os)
    while (rest.hasNext) {
      serializationStream.writeObject(rest.next())(classTag)
    }
    serializationStream.close()
  }

It calls writeObject() to write the remaining values into the stream.


Reposted from blog.csdn.net/fate_killer_liu_jie/article/details/85007606