
Spark Source Code: Shuffle - SortShuffleWriter

SortShuffleWriter, one of Spark's three ShuffleWriters

Characteristics

  • Its biggest feature is support for map-side aggregation

  • The most basic ShuffleWriter; it is chosen only when the other two cannot be used

Overall flow

First, keep in mind that there are 3 cases: (1) with both aggregator and ordering; (2) with aggregator but without ordering; (3) with neither.

Why is there no ordering-only case? Because without an aggregator no sorting is done at all. Always remember that the sorting in the ShuffleWriter stage exists only to make aggregation easier.

  1. Choose the sorter: cases 1 and 2 pick the same kind of sorter, case 3 picks the other. I call them branch 1 and branch 2.

  2. Read the data: branch 1 reads records into a PartitionedAppendOnlyMap (aggregating records with the same key within the same partition); branch 2 reads records into a PartitionedPairBuffer (simply appending them).

  3. Spill when the in-memory data grows past a threshold: the spill file as a whole is laid out partition by partition. The difference is what the data inside each partition looks like: case 1 is sorted by ordering; case 2 is sorted by the key's hash code (this sort exists only to make aggregation easier); case 3 is not sorted at all.

  4. Merge the spill files with the data still in memory (not yet spilled) and return the array of partition lengths: case 1 merge-sorts first and then aggregates; case 2 only aggregates; case 3 does nothing.

  5. Generate the index file from the partition-length array.

  6. Wrap the information into a MapStatus and return it.

In short, each task produces one large, aggregated data file merged from its spill files, plus one index file.
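
To make the three cases concrete, here is a minimal, hedged sketch (the data and the local master are made up; note also that when there is no map-side combine and the partition count is small, Spark may actually pick BypassMergeSortShuffleWriter instead, so this only indicates which case would apply if SortShuffleWriter ends up being the writer):

import org.apache.spark.{SparkConf, SparkContext}

object SortShuffleWriterCases {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cases").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

    // reduceByKey: the ShuffleDependency carries an aggregator and mapSideCombine = true,
    // so write() hands the aggregator (but no ordering) to the sorter -> case 2.
    pairs.reduceByKey(_ + _).collect()

    // groupByKey: mapSideCombine = false, so the sorter gets neither aggregator nor
    // ordering -> case 3 (data is only grouped by partition, not sorted).
    pairs.groupByKey().collect()

    // sortByKey: the dependency has a keyOrdering but mapSideCombine = false; as the
    // official comment in write() says, the ordering is still not passed to the sorter
    // and the real sort happens on the reduce side -> also case 3.
    pairs.sortByKey().collect()

    sc.stop()
  }
}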

Since case 2 is a bit more involved, it is used as the example in the diagram:
SortShuffleWriter.jpg (flow diagram for case 2)

Source code

write(…)

override def write(records: Iterator[Product2[K, V]]): Unit = {
  // Step 1: choose a sorter depending on whether mapSideCombine is needed
  sorter = if (dep.mapSideCombine) {
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
    // care whether the keys get sorted in each partition; that will be done on the reduce side
    // if the operation being run is sortByKey.
    // Note ordering = None; the official comment above already explains why
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  // Steps 2-3: read the data and spill
  sorter.insertAll(records)

  // File name: "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  try {
    // Step 4: merge the spill files with the in-memory (not yet spilled) data and get the
    // partition-length array
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
    // Step 5: generate the index file from the partition-length array. This step is the same
    // for every writer; see the BypassMergeSortShuffleWriter article
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
    // Step 6: wrap the information into a MapStatus and return it
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}

insertAll(…)

Steps 2-3: read the data and spill

The classes to focus on are PartitionedAppendOnlyMap and PartitionedPairBuffer

  • Similarity: both implement the WritablePartitionedPairCollection trait. Internally, both store pairs in a flat Array laid out as (key0, value0, key1, value1, key2, value2, …) to implement the map logic, where the key is (partition ID, original key).

  • Difference: PartitionedAppendOnlyMap supports both inserting and updating values: it uses map.changeValue((getPartition(kv._1), kv._1), update) to insert new data or update (aggregate) existing data, while PartitionedPairBuffer only supports inserting. See the toy sketch after this list.
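
A toy sketch of that shared layout (ToyPartitionedCollection is a made-up class, not Spark's; it uses a linear scan where the real PartitionedAppendOnlyMap uses open-addressing hashing):

// Toy illustration only: pairs stored flat as (key0, value0, key1, value1, ...),
// where a "key" is the tuple (partitionId, originalKey).
class ToyPartitionedCollection[K, V](initialCapacity: Int = 8) {
  private var data = new Array[AnyRef](2 * initialCapacity)
  private var curSize = 0

  // PartitionedPairBuffer-style: append only.
  def insert(partition: Int, key: K, value: V): Unit = {
    if (2 * curSize == data.length) grow()
    data(2 * curSize) = (partition, key)                 // slot 2i holds the composite key
    data(2 * curSize + 1) = value.asInstanceOf[AnyRef]   // slot 2i + 1 holds the value
    curSize += 1
  }

  // PartitionedAppendOnlyMap-style: insert or update (aggregate) in place.
  def changeValue(partition: Int, key: K, update: (Boolean, V) => V): Unit = {
    val compositeKey = (partition, key)
    var i = 0
    while (i < curSize) {
      if (data(2 * i) == compositeKey) {
        data(2 * i + 1) = update(true, data(2 * i + 1).asInstanceOf[V]).asInstanceOf[AnyRef]
        return
      }
      i += 1
    }
    insert(partition, key, update(false, null.asInstanceOf[V]))  // hadValue = false, oldValue unused
  }

  private def grow(): Unit = {
    val bigger = new Array[AnyRef](2 * data.length)
    System.arraycopy(data, 0, bigger, 0, 2 * curSize)
    data = bigger
  }

  def iterator: Iterator[((Int, K), V)] = (0 until curSize).iterator.map { i =>
    (data(2 * i).asInstanceOf[(Int, K)], data(2 * i + 1).asInstanceOf[V])
  }
}

// e.g. word count style: c.changeValue(0, "spark", (had, old) => if (had) old + 1 else 1)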

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  // TODO: stop combining if we find that the reduction factor isn't high
  val shouldCombine = aggregator.isDefined

  // Decide whether to combine
  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    // Take createCombiner and mergeValue from the aggregator and build the update function
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      // Increment the element counter
      addElementsRead()
      kv = records.next()
      // Combine mode uses PartitionedAppendOnlyMap
      map.changeValue((getPartition(kv._1), kv._1), update)
      // Step 3: spill
      maybeSpillCollection(usingMap = true)
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      // Non-combine mode uses PartitionedPairBuffer
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      // Step 3: spill
      maybeSpillCollection(usingMap = false)
    }
  }
}
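
As a concrete illustration of the update closure above, here is a spark-shell style sketch (the key/value types and numbers are made up) of the Aggregator that an operation like reduceByKey(_ + _) carries, with V = C = Int:

import org.apache.spark.Aggregator

val agg = Aggregator[String, Int, Int](
  createCombiner = (v: Int) => v,                   // first value seen for a key becomes the combiner
  mergeValue = (c: Int, v: Int) => c + v,           // later values are folded in (map-side combine)
  mergeCombiners = (c1: Int, c2: Int) => c1 + c2)   // used later when merging spills / partitions

// The same shape of closure that insertAll builds:
var kv: (String, Int) = ("a", 3)
val update = (hadValue: Boolean, oldValue: Int) =>
  if (hadValue) agg.mergeValue(oldValue, kv._2) else agg.createCombiner(kv._2)

update(false, 0)   // key not seen before -> createCombiner -> 3
update(true, 10)   // key already mapped to 10 -> mergeValue -> 13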

maybeSpill(…)

Conditions for spilling: the memory request did not get enough memory, or the element count exceeds the configured threshold numElementsForceSpillThreshold

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
  var shouldSpill = false
  // Check only every 32 elements, and only once currentMemory has reached myMemoryThreshold
  if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
    // Claim up to double our current memory from the shuffle memory pool
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    // Request the memory
    val granted = acquireMemory(amountToRequest)
    myMemoryThreshold += granted
    // If we were granted too little memory to grow further (either tryToAcquire returned 0,
    // or we already had more memory than myMemoryThreshold), spill the current collection
    shouldSpill = currentMemory >= myMemoryThreshold
  }
  // Spill condition: the request above did not get enough memory, or the element threshold is exceeded
  shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
  // Actually spill
  if (shouldSpill) {
    _spillCount += 1
    logSpillage(currentMemory)
    // The spill happens here
    spill(collection)
    _elementsRead = 0
    _memoryBytesSpilled += currentMemory
    releaseMemory()
  }
  shouldSpill
}
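
A worked example of the doubling logic above, with made-up numbers (units are MB for readability; the real code works in bytes):

// Suppose the collection currently uses 7 MB and the threshold started at 5 MB.
var myMemoryThreshold = 5L
val currentMemory = 7L

val amountToRequest = 2 * currentMemory - myMemoryThreshold   // 2 * 7 - 5 = 9
val granted = 1L                                              // pretend the pool could only spare 1 MB
myMemoryThreshold += granted                                  // threshold is now 6

val shouldSpill = currentMemory >= myMemoryThreshold          // 7 >= 6 -> true, so we spill
// Had the full 9 MB been granted, the threshold would be 14 and no spill would happen.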

spill(…)

Sort first, then spill. Because of the two branches, there are two different sorting logics.

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
  // Sort
  val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
  // Spill
  val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
  spills += spillFile
}

PartitionedAppendOnlyMap's sorting logic: a two-level sort, first by partition ID, then by key within each partition (by ordering if one exists, otherwise by hash code)

Note this hash-based sort. Anyone who knows the difference between == and equals in Java will recognize that equal hash codes are only a necessary condition for two objects being equal. So the sort only guarantees that records with the same hash code end up next to each other; during the later aggregation the real keys still have to be compared before merging (keep this in mind, we will see it again in the source code below).

def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
  new Comparator[(Int, K)] {
    override def compare(a: (Int, K), b: (Int, K)): Int = {
      val partitionDiff = a._1 - b._1
      if (partitionDiff != 0) {
        partitionDiff
      } else {
        keyComparator.compare(a._2, b._2)
      }
    }
  }
}

// keyComparator: use ordering if one exists, otherwise compare by hash code.
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
  override def compare(a: K, b: K): Int = {
    val h1 = if (a == null) 0 else a.hashCode()
    val h2 = if (b == null) 0 else b.hashCode()
    if (h1 < h2) -1 else if (h1 == h2) 0 else 1
  }
})
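
A quick spark-shell style demonstration of that caveat: "Aa" and "BB" are different strings with the same hashCode (both 2112), so the hash-based keyComparator treats them as equal even though they are distinct keys; the aggregation code later still has to compare the real keys:

import java.util.Comparator

// The same hash-based comparator as above, specialized to String keys.
val hashComparator: Comparator[String] = new Comparator[String] {
  override def compare(a: String, b: String): Int = {
    val h1 = if (a == null) 0 else a.hashCode()
    val h2 = if (b == null) 0 else b.hashCode()
    if (h1 < h2) -1 else if (h1 == h2) 0 else 1
  }
}

"Aa".hashCode == "BB".hashCode        // true: a classic collision
hashComparator.compare("Aa", "BB")    // 0: the comparator cannot tell them apart
"Aa" == "BB"                          // false: they are still different keys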

PartitionedPairBuffer's sorting logic: compare only the partition ID

override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
    : Iterator[((Int, K), V)] = {
  val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
  new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
  iterator
}

/**
 * A comparator for (Int, K) pairs that orders them by only their partition ID.
 */
def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
  override def compare(a: (Int, K), b: (Int, K)): Int = {
    a._1 - b._1
  }
}

writePartitionedFile(…)

Step 4: merge the spill files with the in-memory (not yet spilled) data and return the partition-length array

def writePartitionedFile(
    blockId: BlockId,
    outputFile: File): Array[Long] = {

  // Track location of each range in the output file
  val lengths = new Array[Long](numPartitions)
  val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
    context.taskMetrics().shuffleWriteMetrics)

  // First note: collection.destructiveSortedWritablePartitionedIterator(comparator) is how the
  // in-memory (not yet spilled) data is read; it shows up several more times below
  if (spills.isEmpty) {
    // Case where we only have in-memory data
    // Only in-memory data: write it straight to the output file (still sorted with the same
    // rules as in the spill step, of course)
    val collection = if (aggregator.isDefined) map else buffer
    val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
    while (it.hasNext) {
      val partitionId = it.nextPartition()
      while (it.hasNext && it.nextPartition() == partitionId) {
        it.writeNext(writer)
      }
      val segment = writer.commitAndGet()
      // Record the partition length
      lengths(partitionId) = segment.length
    }
  } else {
    // The merge happens here: this.partitionedIterator internally calls merge()
    for ((id, elements) <- this.partitionedIterator) {
      if (elements.hasNext) {
        for (elem <- elements) {
          writer.write(elem._1, elem._2)
        }
        val segment = writer.commitAndGet()
        lengths(id) = segment.length
      }
    }
  }

  writer.close()
  context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
  context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
  context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

  lengths
}
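
For intuition about step 5 (covered in detail in the BypassMergeSortShuffleWriter article), the index file is essentially the running sum of this lengths array; a tiny sketch with made-up lengths:

val lengths = Array(120L, 0L, 340L, 75L)      // made-up per-partition byte lengths
val offsets = lengths.scanLeft(0L)(_ + _)     // Array(0, 120, 120, 460, 535)
// Reducer i reads bytes [offsets(i), offsets(i + 1)) of the data file.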

merge(…)

Bring the data of the same partition from the spill files and the in-memory data together and process it (the 3 cases again)

private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
    : Iterator[(Int, Iterator[Product2[K, C]])] = {
  val readers = spills.map(new SpillReader(_))
  val inMemBuffered = inMemory.buffered
  (0 until numPartitions).iterator.map { p =>
    // Clearly whoever wrote this is very comfortable with functional style.
    // The essence: gather the data of partition p from every spill file plus the in-memory data
    // and process it together (the 3 cases); it is really just the classic nested loop
    val inMemIterator = new IteratorForPartition(p, inMemBuffered)
    val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
    if (aggregator.isDefined) {
      // Perform partial aggregation across partitions
      // Aggregate (two sub-cases, with or without ordering)
      (p, mergeWithAggregation(
        iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
    } else if (ordering.isDefined) {
      // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
      // sort the elements without trying to merge them
      // Sort only: a merge sort over the iterators.
      // Honestly, I don't think SortShuffleWriter ever reaches this branch, because as stressed
      // several times, when there is no aggregator no ordering is passed in either
      (p, mergeSort(iterators, ordering.get))
    } else {
      // Neither: simply flatten the per-partition iterators
      (p, iterators.iterator.flatten)
    }
  }
}
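
mergeSort(iterators, comparator) is not shown in this excerpt; conceptually it is a k-way merge over the per-partition iterators. A minimal sketch of the idea (my own simplified version using a priority queue of buffered iterators, not Spark's actual implementation):

import java.util.Comparator
import scala.collection.mutable

// Repeatedly emit the smallest head among the iterators, according to the key comparator.
def kWayMerge[K, V](iterators: Seq[Iterator[(K, V)]],
                    comparator: Comparator[K]): Iterator[(K, V)] = {
  // PriorityQueue is a max-heap, so reverse the comparison to pop the smallest head first.
  val heap = new mutable.PriorityQueue[BufferedIterator[(K, V)]]()(
    Ordering.fromLessThan[BufferedIterator[(K, V)]](
      (a, b) => comparator.compare(a.head._1, b.head._1) > 0))
  iterators.map(_.buffered).filter(_.hasNext).foreach(it => heap.enqueue(it))

  new Iterator[(K, V)] {
    override def hasNext: Boolean = heap.nonEmpty
    override def next(): (K, V) = {
      val it = heap.dequeue()           // iterator whose current head is smallest
      val elem = it.next()
      if (it.hasNext) heap.enqueue(it)  // put it back if it still has elements
      elem
    }
  }
}

// e.g. kWayMerge(Seq(Iterator("a" -> 1, "c" -> 3), Iterator("b" -> 2)), Ordering.String).toList
// gives List((a,1), (b,2), (c,3))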

mergeWithAggregation(…)

Aggregation: two sub-cases, with or without ordering

private def mergeWithAggregation(
    iterators: Seq[Iterator[Product2[K, C]]],
    mergeCombiners: (C, C) => C,
    comparator: Comparator[K],
    totalOrder: Boolean)
    : Iterator[Product2[K, C]] =
{
  if (!totalOrder) {
    // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
    // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
    // need to read all keys considered equal by the ordering at once and compare them.
    // No ordering: the comparator is the hash comparator, and an equal hash code is only a
    // necessary condition for equal keys
    new Iterator[Iterator[Product2[K, C]]] {
      val sorted = mergeSort(iterators, comparator).buffered

      // Buffers reused across elements to decrease memory allocation
      val keys = new ArrayBuffer[K]
      val combiners = new ArrayBuffer[C]

      override def hasNext: Boolean = sorted.hasNext

      override def next(): Iterator[Product2[K, C]] = {
        if (!hasNext) {
          throw new NoSuchElementException
        }
        keys.clear()
        combiners.clear()
        val firstPair = sorted.next()
        keys += firstPair._1
        combiners += firstPair._2
        val key = firstPair._1
        // While the hash codes are equal
        while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
          val pair = sorted.next()
          var i = 0
          var foundKey = false
          while (i < keys.size && !foundKey) {
            // Note the ==: Scala compares objects with == (it delegates to equals)
            if (keys(i) == pair._1) {
              // The keys really are equal, so merge
              combiners(i) = mergeCombiners(combiners(i), pair._2)
              foundKey = true
            }
            i += 1
          }
          if (!foundKey) {
            keys += pair._1
            combiners += pair._2
          }
        }

        // Note that we return an iterator of elements since we could've had many keys marked
        // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
        keys.iterator.zip(combiners.iterator)
      }
    }.flatMap(i => i)
  } else {
    // We have a total ordering, so the objects with the same key are sequential.
    // With ordering: merge-sort first, then merge adjacent elements that share the same key
    new Iterator[Product2[K, C]] {
      // Merge sort
      val sorted = mergeSort(iterators, comparator).buffered

      override def hasNext: Boolean = sorted.hasNext

      override def next(): Product2[K, C] = {
        if (!hasNext) {
          throw new NoSuchElementException
        }
        val elem = sorted.next()
        val k = elem._1
        var c = elem._2
        // The keys are equal, so merge
        while (sorted.hasNext && sorted.head._1 == k) {
          val pair = sorted.next()
          c = mergeCombiners(c, pair._2)
        }
        (k, c)
      }
    }
  }
}
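
To tie this back to the earlier hash-collision warning, here is a toy, simplified re-enactment (it skips the per-hash-run grouping and only shows the key point) of why the inner loop must compare real keys: "Aa" and "BB" sit next to each other after the hash sort but must not be merged together:

import scala.collection.mutable.ArrayBuffer

// Pretend this stream is already merge-sorted by hash code; all three records
// share hashCode 2112, so they are adjacent.
val sortedByHash = Iterator(("Aa", 1), ("BB", 10), ("Aa", 2))
val keys = new ArrayBuffer[String]
val combiners = new ArrayBuffer[Int]

while (sortedByHash.hasNext) {
  val (k, c) = sortedByHash.next()
  val i = keys.indexOf(k)   // real key comparison, not just the hash
  if (i >= 0) combiners(i) += c else { keys += k; combiners += c }
}

keys.zip(combiners)   // ArrayBuffer((Aa,3), (BB,10)): distinct keys stay separate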