0%

spark源码-Broadcast

简介

使用spark时,针对大的、只读、大家都要用的变量,可以使用 broadcast 提高性能:

  • 大的:对其切分后再传输
  • 只读:只有broadcast.value方法
  • 大家都要用:同一个executor的task共享

流程

  • broadcast 对象在driver端写入block中,较简单。
  • 在executor端中读取时,流程复杂一点,如下图所示:
broadcast _1_.png

细节:

  • 成功读到value后,会写入cachedValues中(task共享),它可能会被垃圾回收,关注它的数据结构。
  • broadcast对象的StorageLevel 是 MEMORY_AND_DISK
  • 拉取block时,永远是拉小block(piece,默认4M)并且是乱序拉取,再合并。

sc.broadcast

使用sc.broadcast在driver端创建broadcast对象

// driver端
val listBroadcast: Broadcast[List[String]] = sc.broadcast(listToRemove)

// SparkContext.scala
// Driver-side entry point: wraps `value` in a Broadcast object via the BroadcastManager.
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
assertNotStopped()
// sc.broadcast(rdd) is rejected; sc.broadcast(rdd.collect) is the correct usage
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
val bc = env.broadcastManager.newBroadcast[T](value, isLocal) // delegates to BroadcastManager.newBroadcast
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
// register with the ContextCleaner so the broadcast's blocks are cleaned up later
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}

// BroadcastManager.scala
// Hands the value to the configured factory, assigning a fresh broadcast ID.
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
// nextBroadcastId.getAndIncrement(): every broadcast gets a unique ID (starting at 0, +1 each time)
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}

// TorrentBroadcastFactory.scala
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
// Constructs a TorrentBroadcast — this is the object ultimately returned by sc.broadcast(xxxx).
// The interesting work (writeBlocks) happens inside its constructor; see below.
new TorrentBroadcast[T](value_, id)
}

TorrentBroadcast

TorrentBroadcast 继承 Broadcast,下面关注其构造流程

private val broadcastId = BroadcastBlockId(id)
// lazy val: computed only on first access, so nothing runs here during construction
@transient private lazy val _value: T = readBroadcastBlock()

// configuration holders: blockSize (default 4 MB) and compressionCodec, filled in by setConf
@transient private var compressionCodec: Option[CompressionCodec] = _
@transient private var blockSize: Int = _
// Reads broadcast-related settings from SparkConf (compression, piece size, checksums).
private def setConf(conf: SparkConf) {
// compression is on by default (spark.broadcast.compress = true)
compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
Some(CompressionCodec.createCodec(conf))
} else {
None
}
// piece size in bytes; "spark.broadcast.blockSize" defaults to 4m
blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
}
setConf(SparkEnv.get.conf)

// broadcastId renders as "broadcast_" + id
private val broadcastId = BroadcastBlockId(id)

// key step: runs during construction on the driver — writes the value (whole and
// split into pieces) into the BlockManager and records the number of pieces
private val numBlocks: Int = writeBlocks(obj)

// checksum support, used to verify remotely fetched pieces
private var checksumEnabled: Boolean = false
private var checksums: Array[Int] = _

writeBlocks

driver端先将完整 value 写入block(StorageLevel 为 MEMORY_AND_DISK),再将 value 序列化切分成 piece 后逐个写入block(StorageLevel 为 MEMORY_AND_DISK_SER)

至此Broadcast已经准备好。

// Driver side: store the full value locally, then serialize/compress it, split it
// into pieces of blockSize bytes, store each piece, and return the piece count.
private def writeBlocks(value: T): Int = {
import StorageLevel._

// store the complete value via blockManager.putSingle
// (tellMaster = false: the driver's own full copy need not be announced to the master)
val blockManager = SparkEnv.get.blockManager
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}

// serialize and compress the value, then cut it into pieces of blockSize bytes
val blocks =
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
if (checksumEnabled) {
checksums = new Array[Int](blocks.length)
}
blocks.zipWithIndex.foreach { case (block, i) =>
// record a checksum per piece so readers can detect corrupt remote fetches
if (checksumEnabled) {
checksums(i) = calcChecksum(block)
}

// pieceId renders as "broadcast_" + id + "_piece" + i
val pieceId = BroadcastBlockId(id, "piece" + i)
val bytes = new ChunkedByteBuffer(block.duplicate())
// store each piece via blockManager.putBytes; tellMaster = true so executors
// can locate the pieces through the master
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
}
blocks.length
}

readBroadcastBlock

读取 Broadcast

// executor 执行 listBroadcast.value
val filterRDD: RDD[String] = rdd.filter(!listBroadcast.value.contains(_))

// 调用TorrentBroadcast的getValue()
def value: T = {
// fails fast if destroy() has already been called on this broadcast
assertValid()
getValue()
}

// 此时真正执行readBroadcastBlock
@transient private lazy val _value: T = readBroadcastBlock()

// readBroadcastBlock
// Executor side: resolve the broadcast value, trying in order
//   1. the process-wide cachedValues map (shared by all tasks in this executor),
//   2. the local BlockManager,
//   3. remote fetch of the pieces followed by reassembly.
// BUG FIX vs. the transcription above: the logInfo line was missing the '+'
// before Utils.getUsedTimeMs, which does not compile.
private def readBroadcastBlock(): T = Utils.tryOrIOException {
  TorrentBroadcast.synchronized {
    // cachedValues is maintained by the BroadcastManager; its values are weak
    // references, so an entry may have been reclaimed by GC (see the notes below)
    val broadcastCache = SparkEnv.get.broadcastManager.cachedValues

    // fast path: already cached in this process
    Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
      setConf(SparkEnv.get.conf)
      val blockManager = SparkEnv.get.blockManager
      blockManager.getLocalValues(broadcastId) match {
        case Some(blockResult) =>
          if (blockResult.data.hasNext) {
            // the value was stored with putSingle, so a single next() yields the whole object
            val x = blockResult.data.next().asInstanceOf[T]
            // release the read lock taken by getLocalValues
            releaseLock(broadcastId)
            if (x != null) {
              // repopulate the process-wide cache
              broadcastCache.put(broadcastId, x)
            }
            x
          } else {
            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
          }
        case None =>
          logInfo("Started reading broadcast variable " + id)
          val startTimeMs = System.currentTimeMillis()
          // not stored locally: fetch the small pieces (from driver / other executors)
          val blocks = readBlocks()
          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
          try {
            // reassemble and deserialize the pieces, then store the full value
            // locally so later tasks on this executor skip the fetch entirely
            val obj = TorrentBroadcast.unBlockifyObject[T](
              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
            val storageLevel = StorageLevel.MEMORY_AND_DISK
            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
            }

            if (obj != null) {
              // repopulate the process-wide cache
              broadcastCache.put(broadcastId, obj)
            }
            obj
          } finally {
            // free the piece buffers regardless of success
            blocks.foreach(_.dispose())
          }
      }
    }
  }
}

readBlocks

从本地 or 远程(driver和其它executor)中拉取小blocks

// Fetch all pieces of this broadcast, from local storage when available,
// otherwise from remote BlockManagers (the driver or other executors).
private def readBlocks(): Array[BlockData] = {
val blocks = new Array[BlockData](numBlocks)
val bm = SparkEnv.get.blockManager

// Random.shuffle is deliberate: each executor fetches pieces in a different
// order, spreading the load over every holder instead of hammering one source
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
logDebug(s"Reading piece $pieceId of $broadcastId")
// try locally first: a concurrently running task may already have fetched this piece
bm.getLocalBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
// getLocalBytes takes a read lock; release it now that we hold the data
releaseLock(pieceId)
case None =>
// not local: fetch the piece from a remote BlockManager
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
// verify the fetched bytes against the checksum recorded at write time
if (checksumEnabled) {
val sum = calcChecksum(b.chunks(0))
if (sum != checksums(pid)) {
throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
s" $sum != ${checksums(pid)}")
}
}
// store the fetched piece locally with tellMaster = true, so this
// executor becomes another source for executors still fetching
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
blocks(pid) = new ByteBufferBlockData(b, true)
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
}
}
blocks
}

unpersist和destroy

unpersist(true) 与 destroy(true)

true : 即 blocking 参数,表示阻塞等待删除完成后才返回

unpersist:删除executor上broadcast的缓存和block。 如果再次使用broadcast,则需要由driver重新发给executor。

destroy:听名字就知道,会干掉所有,包括driver中的。无法再次使用,调用isValid返回false。

补充知识

cachedValues

// Keys are hard references, values are weak references: a cached broadcast value
// can be reclaimed by GC, in which case it is re-read from the BlockManager on
// the next access and re-inserted here.
private[broadcast] val cachedValues = {
new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
}

cachedValues的数据结构是ReferenceMap,是Apache Commons Collections API中的对象。

基于hashmap实现,允许垃圾回收器删除映射。

StrongReference(强) > SoftReference(软) > WeakReference(弱) > PhantomReference(虚)

软引用:内存足够,GC时不回收此Object;内存OOM前,GC回收此Object。常用于Cache
弱引用:不管内存是否足够,GC时Object都会被回收

cachedValues 的 key是强引用value是weak引用。因此value在GC时会被释放。如果以后又要用它,就再从block中读,然后再写入cachedValues(参照我画的图)。

小结

broadcast 相对简单,整体上看,就是将一个所有task都要用到的对象,写入BlockManager中,提高效率,节省内存。