0%

spark源码-cache和persist

介绍spark的缓存功能

简介

复杂的任务中,某个中间转换结果可能会被多次调用,此时可以使用 spark 的缓存功能,将计算的中间过程缓存在内存或者磁盘中,以便再次使用,减少不必要的计算。

特点

  • 懒加载,只有RDD触发action时才会进行计算并且缓存
  • 5个参数控制存储级别:是否用内存缓存、 是否用磁盘缓存、是否用堆外内存缓存、是否序列化、缓存个数
  • cache() 是 persist() 也是 persist(StorageLevel.MEMORY_ONLY)
  • api: rdd.cache() or rdd.persist()rdd.unpersist()sc.getPersistentRDDs

接着,以MEMORY_ONLY模式为例,从源码中验证这一切,同时加深对 spark内存管理 的理解

流程

cache() or persist(…)

解决一个疑问:rdd1.cache() rdd1 -> rdd2 rdd2.cahce() ,此时 rdd1 和 rdd2 都会被缓存

// cache() 等同于 persist() 等同于 persist(StorageLevel.MEMORY_ONLY) ,也就是仅缓存于存储内存中。
def cache(): this.type = persist()
def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)

// 缓存级别,由5个参数组成
new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication))

def persist(newLevel: StorageLevel): this.type = {
// isLocallyCheckpointed 方法 判断该RDD是否已经标记为 checkpoint,注意不是cache
if (isLocallyCheckpointed) {
persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true)
} else {
persist(newLevel, allowOverride = false)
}
}

private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
// 可以发现当前版本并不支持改变已缓存 RDD 的 StorageLevel (注意RDD1转成RDD2后自然可以改变)
if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
throw new UnsupportedOperationException(
"Cannot change storage level of an RDD after it was already assigned a level")
}
// 仅第一次缓存时触发
if (storageLevel == StorageLevel.NONE) {
sc.cleaner.foreach(_.registerRDDForCleanup(this)) // 用于 cleanups
// 写入 persistentRdds,它是一个map(rdd.id, rdd),可调用 sc.getPersistentRDDs 得到
sc.persistRDD(this)
}
// 设置 storageLevel,之前默认为StorageLevel.NONE
storageLevel = newLevel
this
}

getOrCompute(…)

cache() 给 RDD埋了一个属性storageLevel,只有执行行动操作才会真正执行缓存

// RDD的iterator,由于RDD本身懒加载,只要行动操作才会执行
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
// 计算前先检查 storageLevel 是否不为NONE
if (storageLevel != StorageLevel.NONE) {
getOrCompute(split, context) // 核心
} else {
computeOrReadCheckpoint(split, context)
}
}

private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
val blockId = RDDBlockId(id, partition.index)
var readCachedBlock = true
// 调用blockManager 的 getOrElseUpdate方法,取出或者生成该blockId对应的block数据,返回blockResult
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
// 该函数变量仅生成block时执行
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}) match {
case Left(blockResult) =>
// 读 cache :会计数
if (readCachedBlock) {
val existingMetrics = context.taskMetrics().inputMetrics
existingMetrics.incBytesRead(blockResult.bytes)
// blockResul.data 得到迭代器,封装成 InterruptibleIterator,结束!
new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
override def next(): T = {
existingMetrics.incRecordsRead(1)
delegate.next()
}
}
} else {
// 生成 cache:compute 内有自己的计数,这里就不用处理
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
}
case Right(iter) =>
new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
}
}

getOrElseUpdate(…)

getOrElseUpdateBlockManager 中的方法

函数参数makeIterator就是我们的computeOrReadCheckpoint方法

def getOrElseUpdate[T](
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[T],
makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
// 调用 get[T](blockId) 方法,从各种 store 中找该 block
get[T](blockId)(classTag) match {
// 找到了表明有缓存,直接返回找到的 BlockResult
case Some(block) =>
return Left(block)
case _ =>
// 没缓存,继续执行下一步
}
// Initially we hold no locks on this block.
// 调用 doPutIterator,将 makeIterator 计算的结果存入Block(逻辑概念)中(其实是存入各种store中)
// 写数据时加 读锁
doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
case None =>
// doPut() 正常情况返回None,接着调用 getLocalValues 读取刚写的 block,返回 blockResult
val blockResult = getLocalValues(blockId).getOrElse {
releaseLock(blockId)
throw new SparkException(s"get() failed for block $blockId even though we held a lock")
}
// 放锁
releaseLock(blockId)
Left(blockResult)
case Some(iter) =>
Right(iter)
}
}

补充

persistentRdds

一个map :key 为 rdd.id ,value 为 rdd

private[spark] val persistentRdds = {
val map: ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]()
map.asScala
}
def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap

// 用户使用该方法得到它
val ds: collection.Map[Int, RDD[_]] = sc.getPersistentRDDs

unpersist()

取消缓存

private[spark] def unpersistRDD(rddId: Int, blocking: Boolean = true) {
// 通知blockManager删掉属于该RDD的全部block
env.blockManager.master.removeRdd(rddId, blocking)
// 从map中移掉它
persistentRdds.remove(rddId)
listenerBus.post(SparkListenerUnpersistRDD(rddId))
}

小结

当你对spark的存储有一点理解时,本节相对简单。缓存就是将RDD的storageLevel属性改写,并把该RDD加入persistentRdds这个map中。当执行到iterator时触发,如果没有缓存过,则进行计算并写入BLock中,有缓存直接从BLock中提取即可。