1. Introduction to Broadcast
A broadcast variable is a read-only variable whose value is cached on every node, instead of a copy being shipped to each task. This cuts the network overhead of a job.
Basic usage of a broadcast variable consists of creating it and reading its value.
Create:
scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
Read:
scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)
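A typical use is to broadcast a small lookup table once and reference it inside a transformation; each executor then caches a single copy instead of receiving one with every task. A minimal sketch (the lookup table and data here are made up for illustration):

scala> val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))

scala> sc.parallelize(Seq(1, 2, 3)).map(x => lookup.value.getOrElse(x, "?")).collect()
res1: Array[String] = Array(a, b, ?)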
2. BroadcastManager Initialization
BroadcastManager manages broadcast variables. Its instance is created in the create method of SparkEnv.scala:
private def create(
    conf: SparkConf,
    executorId: String,
    bindAddress: String,
    advertiseAddress: String,
    port: Option[Int],
    isLocal: Boolean,
    numUsableCores: Int,
    ioEncryptionKey: Option[Array[Byte]],
    listenerBus: LiveListenerBus = null,
    mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
  ...
  // Create the BroadcastManager
  val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
  // Create the MapOutputTracker: the master on the driver, a worker elsewhere
  val mapOutputTracker = if (isDriver) {
    new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
  } else {
    new MapOutputTrackerWorker(conf)
  }
  ...
}
The BroadcastManager constructor calls its initialize method:
private def initialize() {
  synchronized {
    if (!initialized) {
      // Instantiate the TorrentBroadcastFactory
      broadcastFactory = new TorrentBroadcastFactory
      // Call TorrentBroadcastFactory's initialize method
      broadcastFactory.initialize(isDriver, conf, securityManager)
      initialized = true
    }
  }
}
TorrentBroadcastFactory's initialize, however, does nothing at all:
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
}
3. Creating and Reading a Broadcast
Creating a broadcast
A broadcast is created by the broadcast method of SparkContext.scala, which delegates to BroadcastManager's newBroadcast method.
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
    "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  cleaner.foreach(_.registerBroadcastForCleanup(bc))
  bc
}
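Note the require guard: an RDD cannot be broadcast directly, because the value must be materialized on the driver first. A minimal sketch of the pattern the error message suggests (the RDD here is illustrative):

scala> val small = sc.parallelize(Seq(1, 2, 3))

scala> val bc = sc.broadcast(small.collect()) // materialize on the driver, then broadcast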
newBroadcast passes the call on to broadcastFactory.newBroadcast; since the factory is a TorrentBroadcastFactory, that is TorrentBroadcastFactory's newBroadcast:
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
  broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
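The id passed along here comes from nextBroadcastId, an AtomicLong counter held by BroadcastManager, so broadcasts created by one SparkContext are numbered 0, 1, 2, ... (which is why the example at the top printed Broadcast(0)). A minimal sketch of the allocation:

import java.util.concurrent.atomic.AtomicLong

val nextBroadcastId = new AtomicLong(0)
val id0 = nextBroadcastId.getAndIncrement() // 0
val id1 = nextBroadcastId.getAndIncrement() // 1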
TorrentBroadcastFactory's newBroadcast creates a TorrentBroadcast object:
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
  ...
  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }
  ...
}
The TorrentBroadcast constructor calls writeBlocks, which writes the broadcast value into the driver's BlockManager so that executors can later fetch it.
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {
  ...
  private val numBlocks: Int = writeBlocks(obj)

  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    val blockManager = SparkEnv.get.blockManager
    // Store a copy of the broadcast value on the driver so that tasks run
    // on the driver don't create their own copy of it.
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    // Serialize the object into chunks of bytes
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      // Store each chunk in the BlockManager, telling the master so other
      // nodes can locate the pieces
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }
}
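The number of pieces is driven by the serialized size of the value and the piece size, which spark.broadcast.blockSize controls (4 MB by default). A self-contained sketch of just the chunking arithmetic (illustrative only; the real blockifyObject serializes and optionally compresses the value before splitting):

// Split raw bytes into fixed-size pieces, as blockifyObject does for the
// serialized stream.
def blockify(bytes: Array[Byte], blockSize: Int): Seq[Array[Byte]] =
  bytes.grouped(blockSize).toSeq

// A 10 MB payload with the default 4 MB piece size yields 3 pieces, stored
// under block ids broadcast_<id>_piece0 through broadcast_<id>_piece2.
val numPieces = blockify(new Array[Byte](10 * 1024 * 1024), 4 * 1024 * 1024).length // 3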
Reading a broadcast
Calling value on a broadcast variable invokes TorrentBroadcast's getValue method, which ultimately calls readBroadcastBlock. readBroadcastBlock looks for the value in three places, in order: the process-level broadcast cache, the local BlockManager, and finally the BlockManagers of remote nodes:
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  @transient private lazy val _value: T = readBroadcastBlock()

  override protected def getValue() = {
    _value
  }

  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
      // Return the value straight from the cache if it is already there
      Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
        setConf(SparkEnv.get.conf)
        val blockManager = SparkEnv.get.blockManager
        // Otherwise, try the local BlockManager
        blockManager.getLocalValues(broadcastId) match {
          case Some(blockResult) =>
            if (blockResult.data.hasNext) {
              val x = blockResult.data.next().asInstanceOf[T]
              releaseLock(broadcastId)
              if (x != null) {
                // Put the value into the local cache
                broadcastCache.put(broadcastId, x)
              }
              x
            } else {
              throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
            }
          case None =>
            logInfo("Started reading broadcast variable " + id)
            val startTimeMs = System.currentTimeMillis()
            // Fetch the pieces from remote nodes
            val blocks = readBlocks()
            logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
            try {
              val obj = TorrentBroadcast.unBlockifyObject[T](
                blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
              // Store the reassembled value in the local BlockManager so later
              // tasks on this executor don't have to re-fetch it
              val storageLevel = StorageLevel.MEMORY_AND_DISK
              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
              }
              if (obj != null) {
                // Put the value into the local cache
                broadcastCache.put(broadcastId, obj)
              }
              obj
            } finally {
              blocks.foreach(_.dispose())
            }
        }
      }
    }
  }
  ...
}
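Because _value is a @transient lazy val, readBroadcastBlock runs at most once per deserialized TorrentBroadcast instance; later value calls return the already-materialized object. A minimal sketch of that lazy-once behavior (the class and counter are illustrative, not Spark code):

class LazyOnce[T](compute: () => T) extends Serializable {
  @transient private lazy val v: T = compute() // computed on first access in each JVM
  def value: T = v
}

var reads = 0
val once = new LazyOnce(() => { reads += 1; 42 })
once.value
once.value
assert(reads == 1) // the expensive read happened only once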
readBlocks fetches the pieces in a random order rather than sequentially. Since every executor shuffles the piece ids independently, and each fetched piece is re-registered with the BlockManager master, concurrent readers spread their requests across the cluster instead of all pulling the same pieces from the driver at once, avoiding a hotspot:
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {
  ...
  private def readBlocks(): Array[BlockData] = {
    val blocks = new Array[BlockData](numBlocks)
    val bm = SparkEnv.get.blockManager
    // Walk the piece ids in a random order
    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = BroadcastBlockId(id, "piece" + pid)
      logDebug(s"Reading piece $pieceId of $broadcastId")
      // A previous attempt may already have fetched some pieces, so check
      // the local BlockManager first
      bm.getLocalBytes(pieceId) match {
        case Some(block) =>
          blocks(pid) = block
          releaseLock(pieceId)
        case None =>
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              // Verify the fetched piece against the checksum recorded on write
              if (checksumEnabled) {
                val sum = calcChecksum(b.chunks(0))
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              // Re-publish the piece locally (tellMaster = true) so other
              // executors can fetch it from this one
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              blocks(pid) = new ByteBufferBlockData(b, true)
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    blocks
  }
  ...
}
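The shuffle matters because each reader draws its own order; together with the putBytes(..., tellMaster = true) re-publishing above, later readers can pick up pieces from peers instead of all queueing on the driver, BitTorrent-style. A quick illustration (not Spark code):

import scala.util.Random

// Two executors fetching 4 pieces will usually start from different pieces,
// e.g. List(2, 0, 3, 1) vs List(1, 3, 0, 2), spreading the load across sources.
val readerA = Random.shuffle(Seq.range(0, 4))
val readerB = Random.shuffle(Seq.range(0, 4))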