
Spark source code - shuffle: BlockStoreShuffleReader

Spark's one and only ShuffleReader: BlockStoreShuffleReader

Introduction

As mentioned before, the shuffle write side (shuffleManager.getWriter) only shows up in ShuffleMapTask.

So where does getReader get called? Since every kind of task needs to read its input, the read step is wrapped inside the RDD's compute method. Below is the compute method of ShuffledRDD; the call to shuffleManager.getReader is what hands us the subject of this article, BlockStoreShuffleReader.

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  // (split.index, split.index + 1) may look odd here; we will see why later
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
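For context, any shuffle-producing transformation such as reduceByKey builds a ShuffledRDD, so its reduce-side tasks end up in the compute method above. A minimal driver-side demo of my own (not from the article's codebase) that exercises this read path:

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("shuffle-read-demo"))
    // reduceByKey creates a ShuffledRDD; collect() runs the reduce-side tasks,
    // each of which ends up in the compute method above and calls
    // shuffleManager.getReader(...).read()
    val counts = sc.parallelize(Seq("a", "b", "a", "c"), numSlices = 2)
      .map(w => (w, 1))
      .reduceByKey(_ + _)
    println(counts.collect().toSeq)
    sc.stop()
  }
}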

Overall flow

  1. Get the initial iterator, a ShuffleBlockFetcherIterator: Iterator[(BlockId, InputStream)]

  2. Deserialize the streams and build a key/value iterator: Iterator[(key, value)]

  3. Attach readMetrics, then wrap everything in an InterruptibleIterator

  4. Aggregate. This splits on whether an aggregator is defined, and if so, on whether map-side combine already ran

  5. Sort. This splits on whether a key ordering is defined.

What BlockStoreShuffleReader has to do is easy to understand: pull the records of one partition out of the blocks written by the different map tasks into the designated reducer, then aggregate and sort them. One reducer handles one partition.
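The whole pipeline can be pictured with a toy model in plain Scala (illustrative only; the block layout and values are invented): each map task has written one slice per reduce partition, and the reducer that owns partition 0 fetches its slice from every map output, then aggregates and sorts.

object ReducerSideSketch {
  def main(args: Array[String]): Unit = {
    // blocks((mapId, reduceId)) = the records map task mapId wrote for partition reduceId
    val blocks: Map[(Int, Int), Seq[(String, Int)]] = Map(
      (0, 0) -> Seq("a" -> 1, "b" -> 1), (1, 0) -> Seq("a" -> 1, "c" -> 1),
      (0, 1) -> Seq("x" -> 1),           (1, 1) -> Seq("y" -> 1, "x" -> 1))
    val numMaps = 2
    val myPartition = 0 // this reducer owns partition 0

    // steps 1-3: fetch partition 0's records from every map output
    val fetched = (0 until numMaps).flatMap(mapId => blocks((mapId, myPartition)))
    // step 4: aggregate by key
    val aggregated = fetched.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }
    // step 5: sort by key within the partition
    println(aggregated.toSeq.sortBy(_._1)) // (a,2), (b,1), (c,1)
  }
}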

Source code

Getting the initial iterator: ShuffleBlockFetcherIterator

// build the initial iterator, ShuffleBlockFetcherIterator
val wrappedStreams = new ShuffleBlockFetcherIterator(
  context,
  blockManager.shuffleClient, // NettyBlockTransferService by default; ExternalShuffleClient when the external shuffle service is used
  blockManager,
  // yields 2-tuples: (BlockManagerId, Seq[(BlockId, BlockSize)])
  mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
  serializerManager.wrapStream,
  // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
  SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
  SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
  SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
  SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
  SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
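The fetch-throttling knobs seen above are ordinary Spark configs and can be tuned on the SparkConf; for example (the values here are purely illustrative, not recommendations):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // max total bytes of shuffle data in flight per reduce task (default 48m)
  .set("spark.reducer.maxSizeInFlight", "96m")
  // max number of concurrent fetch requests per reduce task
  .set("spark.reducer.maxReqsInFlight", "64")
  // max number of blocks being fetched from any single remote address at once
  .set("spark.reducer.maxBlocksInFlightPerAddress", "16")
  // detect corrupted fetched blocks early (default true)
  .set("spark.shuffle.detectCorrupt", "true")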

Here is the key step inside mapOutputTracker.getMapSizesByExecutorId:

def convertMapStatuses(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int,
    statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
  assert (statuses != null)
  val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
  // the number of map tasks determines how many elements each Seq[(BlockId, BlockSize)] holds
  for ((status, mapId) <- statuses.iterator.zipWithIndex) {
    if (status == null) {
      val errorMessage = s"Missing an output location for shuffle $shuffleId"
      logError(errorMessage)
      throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
    } else {
      // half-open interval [startPartition, endPartition), so the split.index + 1
      // we saw earlier just makes this loop cover a single partition
      for (part <- startPartition until endPartition) {
        val size = status.getSizeForBlock(part)
        if (size != 0) {
          splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
            ((ShuffleBlockId(shuffleId, mapId, part), size))
        }
      }
    }
  }
  splitsByAddress.iterator
}
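To make the returned shape concrete, here is what a hypothetical shuffle (two map tasks on two executors, our reducer reading only partition 3; all ids and sizes invented) could look like:

import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}

// hypothetical result of getMapSizesByExecutorId(shuffleId = 0, startPartition = 3, endPartition = 4):
// one entry per executor that holds map output, one ShuffleBlockId per non-empty block of partition 3
val hypothetical: Iterator[(BlockManagerId, Seq[(ShuffleBlockId, Long)])] = Iterator(
  BlockManagerId("exec-1", "host-a", 7337) -> Seq(ShuffleBlockId(0, 0, 3) -> 1024L),
  BlockManagerId("exec-2", "host-b", 7337) -> Seq(ShuffleBlockId(0, 1, 3) -> 2048L))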

The actual implementation of ShuffleBlockFetcherIterator is fairly involved; in short (a sketch of the local/remote split follows the list below):

  • Blocks to fetch are split into local blocks and remote blocks

  • Local blocks are read directly via BlockManager.getBlockData

  • Remote blocks are fetched over the network with Netty
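A minimal sketch of that split, assuming the caller passes in the blocks-by-address sequence from above and the local BlockManagerId (schematic only; the real iterator also batches remote blocks into size-capped fetch requests):

import org.apache.spark.storage.{BlockId, BlockManagerId}

// schematic split: blocks on this executor vs. blocks that need a network fetch
def splitLocalRemoteBlocks(
    blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
    localBlockManagerId: BlockManagerId)
  : (Seq[BlockId], Seq[(BlockManagerId, Seq[(BlockId, Long)])]) = {
  val (local, remote) = blocksByAddress.partition { case (address, _) =>
    address.executorId == localBlockManagerId.executorId
  }
  (local.flatMap(_._2.map(_._1)), remote)
}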

Below is getBlockData from IndexShuffleBlockResolver; this is where the index file (indexFile) finally earns its keep.

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

  val channel = Files.newByteChannel(indexFile.toPath)
  // seek to this reduce partition's entry in the index
  channel.position(blockId.reduceId * 8L)
  val in = new DataInputStream(Channels.newInputStream(channel))
  try {
    val offset = in.readLong()
    val nextOffset = in.readLong()
    val actualPosition = channel.position()
    val expectedPosition = blockId.reduceId * 8L + 16
    if (actualPosition != expectedPosition) {
      throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
        s"expected $expectedPosition but actual position was $actualPosition.")
    }
    new FileSegmentManagedBuffer(
      transportConf,
      getDataFile(blockId.shuffleId, blockId.mapId),
      offset,
      nextOffset - offset)
  } finally {
    in.close()
  }
}
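The index file written by IndexShuffleBlockResolver is essentially an array of (numPartitions + 1) longs: entry i is the byte offset where partition i starts in the data file, and the last entry is the data-file length. A small worked example of why offset/nextOffset above give the segment boundaries:

// partition lengths produced by one map task (in bytes)
val partitionLengths = Array(100L, 150L, 0L, 50L)
// index-file contents: cumulative offsets, starting at 0
val offsets = partitionLengths.scanLeft(0L)(_ + _)   // Array(0, 100, 250, 250, 300)

val reduceId = 1
val offset = offsets(reduceId)                       // 100 -> the first readLong() above
val length = offsets(reduceId + 1) - offset          // 150 -> nextOffset - offset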

Deserializing the streams into a key/value iterator

val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
  serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}

Attaching readMetrics and wrapping in an InterruptibleIterator

readMetrics just holds bookkeeping counters used for monitoring and UI display

InterruptibleIterator adds support for task cancellation

// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
  recordIter.map { record =>
    readMetrics.incRecordsRead(1) // the record count shown in the Spark UI
    record
  },
  context.taskMetrics().mergeShuffleReadMetrics())

val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
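CompletionIterator is a small Spark utility that runs a callback once the wrapped iterator is exhausted; here the callback folds the temporary read metrics into the task metrics. A stripped-down sketch of the same pattern (not Spark's actual implementation):

// delegate to `sub`, and run `completion` exactly once when it is exhausted
class SimpleCompletionIterator[A](sub: Iterator[A], completion: => Unit) extends Iterator[A] {
  private var completed = false
  override def next(): A = sub.next()
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) {
      completed = true
      completion
    }
    more
  }
}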

Aggregation

This splits on whether an aggregator is defined, and when it is, on whether map-side combine already ran in the map tasks

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
  if (dep.mapSideCombine) {
    // We are reading values that are already combined
    // with mapSideCombine, the reader merges already-combined (K, C) pairs
    val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
    dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
  } else {
    // without mapSideCombine, the reader combines raw (K, V) pairs
    val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
    dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
  }
} else {
  // no aggregation
  interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
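The two aggregator calls correspond to the two halves of dep.aggregator, an Aggregator[K, V, C] built from createCombiner, mergeValue and mergeCombiners: combineValuesByKey starts from raw values (createCombiner/mergeValue), while combineCombinersByKey only merges already-built combiners (mergeCombiners). A standalone sketch of that difference in plain Scala, using a (sum, count) combiner (my example, not Spark's Aggregator class):

// C = (sum, count); V = Int
val createCombiner: Int => (Int, Int) = v => (v, 1)
val mergeValue: ((Int, Int), Int) => (Int, Int) = { case ((s, c), v) => (s + v, c + 1) }
val mergeCombiners: ((Int, Int), (Int, Int)) => (Int, Int) = { case ((s1, c1), (s2, c2)) => (s1 + s2, c1 + c2) }

// no map-side combine: the reader sees raw (K, V) pairs -> combineValuesByKey path
val fromValues = Iterator("a" -> 1, "a" -> 3, "b" -> 2)
  .foldLeft(Map.empty[String, (Int, Int)]) { case (acc, (k, v)) =>
    acc.updated(k, acc.get(k).map(c => mergeValue(c, v)).getOrElse(createCombiner(v)))
  }

// map-side combine already ran: the reader sees (K, C) pairs -> combineCombinersByKey path
val fromCombiners = Iterator("a" -> (4, 2), "a" -> (5, 1), "b" -> (2, 1))
  .foldLeft(Map.empty[String, (Int, Int)]) { case (acc, (k, c)) =>
    acc.updated(k, acc.get(k).map(c0 => mergeCombiners(c0, c)).getOrElse(c))
  }

println(fromValues)    // Map(a -> (4,2), b -> (2,1))
println(fromCombiners) // Map(a -> (9,3), b -> (2,1))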

Sorting

This splits on whether sorting is required. If it is, an ExternalSorter sorts the data within the partition

val resultIter = dep.keyOrdering match {
  case Some(keyOrd: Ordering[K]) =>
    // Create an ExternalSorter to sort the data.
    val sorter =
      new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
    sorter.insertAll(aggregatedIter)
    context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
    // Use completion callback to stop sorter if task was finished/cancelled.
    context.addTaskCompletionListener[Unit](_ => {
      sorter.stop()
    })
    CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
  case None =>
    aggregatedIter
}
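For completeness: keyOrdering is only set for sort-style shuffles. sortByKey builds its ShuffledRDD with setKeyOrdering, so its reducers take the ExternalSorter branch, while reduceByKey sets an aggregator but no ordering and returns aggregatedIter directly. A small driver-side illustration (assuming a local SparkContext):

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("key-ordering-demo"))
val pairs = sc.parallelize(Seq("b" -> 2, "a" -> 1, "c" -> 3))

// sortByKey sets keyOrdering on the shuffle dependency,
// so each reducer's BlockStoreShuffleReader runs the ExternalSorter branch above
val sorted = pairs.sortByKey().collect()   // (a,1), (b,2), (c,3)

// reduceByKey defines an aggregator but no keyOrdering,
// so the reader returns aggregatedIter as-is
val reduced = pairs.reduceByKey(_ + _).collect()
sc.stop()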