Spark Broadcast Source Code Analysis

1. Introduction to Broadcast

A Broadcast (broadcast variable) is a read-only variable whose value is cached on every node, instead of a separate copy being shipped to each Task. This reduces network overhead during computation.

Basic usage of a broadcast consists of creating it and reading it.
Creation

scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)

Reading

scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)
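
Inside a job, tasks read the broadcast value through the same value method. A minimal usage sketch (the RDD and numbers here are illustrative, not from the original article):

scala> val rdd = sc.parallelize(Seq(1, 2, 3, 4))
scala> rdd.map(x => x * broadcastVar.value.sum).collect()
res1: Array[Int] = Array(6, 12, 18, 24)
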
2. BroadcastManager Initialization

The BroadcastManager is responsible for managing broadcasts; its instance is created in the create method of SparkEnv.scala.


private def create(
      conf: SparkConf,
      executorId: String,
      bindAddress: String,
      advertiseAddress: String,
      port: Option[Int],
      isLocal: Boolean,
      numUsableCores: Int,
      ioEncryptionKey: Option[Array[Byte]],
      listenerBus: LiveListenerBus = null,
      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
    ...
    // create the BroadcastManager
    val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
    // create the MapOutputTracker
    val mapOutputTracker = if (isDriver) {
      new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
    } else {
      new MapOutputTrackerWorker(conf)
    }
    ...
  }

The BroadcastManager constructor calls the initialize method:

private def initialize() {
    synchronized {
      if (!initialized) {
        // instantiate the TorrentBroadcastFactory
        broadcastFactory = new TorrentBroadcastFactory
        // call TorrentBroadcastFactory's initialize method
        broadcastFactory.initialize(isDriver, conf, securityManager)
        initialized = true
      }
    }
  }
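
For context, initialize is invoked directly from the constructor body, so the factory is ready as soon as SparkEnv constructs the manager. A simplified excerpt (fields abbreviated, based on the Spark 2.x source):

private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager) extends Logging {

  private var initialized = false
  private var broadcastFactory: BroadcastFactory = null

  // runs as part of the constructor
  initialize()
  ...
}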

However, TorrentBroadcastFactory's initialize actually does nothing:

private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
}
3. Creating and Reading a Broadcast

Creating a broadcast

A broadcast is created by the broadcast method of SparkContext.scala, which in turn calls BroadcastManager's newBroadcast method.

def broadcast[T: ClassTag](value: T): Broadcast[T] = {
    assertNotStopped()
    require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
      "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
    val callSite = getCallSite
    logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
    cleaner.foreach(_.registerBroadcastForCleanup(bc))
    bc
  }
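
Note that broadcast also registers the new broadcast with the ContextCleaner (the cleaner.foreach(...) call above), so its blocks are removed automatically once the variable is garbage collected on the Driver. Cleanup can also be triggered explicitly; a short lifecycle sketch (the data is illustrative):

val bc = sc.broadcast(Map("a" -> 1, "b" -> 2))
// ... use bc.value inside tasks ...

// drop the cached copies on the Executors; the value can still be re-fetched
bc.unpersist()
// remove all state for this broadcast; bc.value must not be used afterwards
bc.destroy()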

newBroadcast then delegates to broadcastFactory's newBroadcast method, which in practice is TorrentBroadcastFactory's newBroadcast.

def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }
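
nextBroadcastId here is simply an AtomicLong field of BroadcastManager, which is why broadcast ids are handed out sequentially starting from 0 (the first broadcast in section 1 printed Broadcast(0)):

private val nextBroadcastId = new AtomicLong(0)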

TorrentBroadcastFactory's newBroadcast method creates a TorrentBroadcast object.

private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
...
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }
...
}

The TorrentBroadcast constructor calls the writeBlocks method, which writes the broadcast value to the Driver's BlockManager so that Executor nodes can fetch it later.

private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {
  ...
  private val numBlocks: Int = writeBlocks(obj)
  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    val blockManager = SparkEnv.get.blockManager
    // Store a copy of the value in the Driver's BlockManager, so that tasks running on the Driver do not create a new copy of the broadcast value.
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    // Serialize the object into byte blocks
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      // Store each byte block in the BlockManager
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }
}
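
blockifyObject splits the serialized value into chunks of spark.broadcast.blockSize (4 MB by default), so the number of pieces is just the serialized size divided by the block size, rounded up. A back-of-the-envelope sketch (not Spark API, illustration only):

// ceiling division: how many pieces a broadcast of a given serialized size yields
val blockSize = 4L * 1024 * 1024  // spark.broadcast.blockSize default, in bytes
def numPieces(serializedSize: Long): Long =
  (serializedSize + blockSize - 1) / blockSize

numPieces(10L * 1024 * 1024)  // a 10 MB value is split into 3 pieces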

Reading a broadcast
When the value method is called on a broadcast variable, TorrentBroadcast's getValue method is invoked, which ultimately calls readBroadcastBlock.

readBroadcastBlock proceeds in three steps (the flow diagram from the original article is omitted here): check the broadcast cache, then the local BlockManager, and finally fetch the blocks from remote nodes.


private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {
  @transient private lazy val _value: T = readBroadcastBlock()
  override protected def getValue() = {
    _value
  }

 private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
      // If the value is already cached, return it from the cache
      Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
        setConf(SparkEnv.get.conf)
        val blockManager = SparkEnv.get.blockManager
        // Otherwise, try the local BlockManager
        blockManager.getLocalValues(broadcastId) match {
          case Some(blockResult) =>
            if (blockResult.data.hasNext) {
              val x = blockResult.data.next().asInstanceOf[T]
              releaseLock(broadcastId)

              if (x != null) {
                // put the value into the local broadcast cache
                broadcastCache.put(broadcastId, x)
              }
              x
            } else {
              throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
            }
          case None =>
            logInfo("Started reading broadcast variable " + id)
            val startTimeMs = System.currentTimeMillis()
            // Fetch the blocks from remote nodes
            val blocks = readBlocks()
            logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

            try {
              val obj = TorrentBroadcast.unBlockifyObject[T](
                blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
              val storageLevel = StorageLevel.MEMORY_AND_DISK
              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
              }

              if (obj != null) {
                // put the value into the local broadcast cache
                broadcastCache.put(broadcastId, obj)
              }

              obj
            } finally {
              blocks.foreach(_.dispose())
            }
        }
      }
    }
  }
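
Each task deserializes its own TorrentBroadcast handle, but because _value is a @transient lazy val backed by the broadcastCache lookup above, the value itself is materialized at most once per Executor JVM; later tasks on the same Executor get the cached object. A small sketch (the RDD and data are illustrative):

val bc = sc.broadcast(Array.fill(1000)(1))
sc.parallelize(1 to 8, 8).map { _ =>
  // the first task on each Executor triggers readBroadcastBlock;
  // subsequent tasks in the same JVM hit the broadcast cache
  bc.value.sum
}.collect()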

The readBlocks method fetches the data pieces in a random order. Shuffling the fetch order means different Executors pull different pieces first, and each fetched piece is re-announced to the BlockManager master (tellMaster = true), so later readers can fetch it from peers; this avoids the hotspot of every Executor pulling the same pieces from the Driver at once.

 private def readBlocks(): Array[BlockData] = {
    val blocks = new Array[BlockData](numBlocks)
    val bm = SparkEnv.get.blockManager
    // fetch the pieces in a shuffled order to spread the load across nodes
    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = BroadcastBlockId(id, "piece" + pid)
      logDebug(s"Reading piece $pieceId of $broadcastId")
      bm.getLocalBytes(pieceId) match {
        case Some(block) =>
          blocks(pid) = block
          releaseLock(pieceId)
        case None =>
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              if (checksumEnabled) {
                val sum = calcChecksum(b.chunks(0))
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              blocks(pid) = new ByteBufferBlockData(b, true)
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    blocks
  }
}
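
The checksum protects remotely fetched pieces against corruption in transit. For illustration, calcChecksum computes an Adler-32 checksum over the block's bytes; a simplified sketch of the idea (Spark's version also handles array-backed buffers without copying):

import java.nio.ByteBuffer
import java.util.zip.Adler32

def calcChecksum(block: ByteBuffer): Int = {
  val adler = new Adler32()
  val bytes = new Array[Byte](block.remaining())
  block.duplicate().get(bytes)  // read the bytes without moving the original position
  adler.update(bytes)
  adler.getValue.toInt
}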

Reprinted from blog.csdn.net/cl2010abc/article/details/107526672