
Spark Source Code - Shuffle: UnsafeShuffleWriter

UnsafeShuffleWriter, one of Spark's three ShuffleWriter implementations.

Characteristics

  • Only applicable to shuffle operations that do not require map-side aggregation.

  • It manipulates serialized data through the Unsafe API instead of Java objects, which lowers memory usage and the resulting GC overhead (see the companion post on Spark's Tungsten memory management). Because of this, the Serializer in use must support relocation of serialized objects.

  • The number of reduce-side partitions must be at most 2^24, because the sort operates on record pointers that encode both the record's address and its partition ID, and the partition ID occupies 24 bits (see the sketch after this list).

  • Spilling and merging operate directly on serialized data via the Unsafe API; no deserialization is needed during the merge.

  • Spilling and merging can use a fast merge path (NIO's transferTo) for better efficiency.
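
To make the 2^24 limit concrete, here is a minimal, self-contained sketch of how a partition ID can be packed into the high 24 bits of a 64-bit pointer, in the spirit of Spark's PackedRecordPointer. The exact layout of the remaining 40 bits (page number plus in-page offset) is simplified here, and the class name is purely illustrative:

// Simplified sketch of a packed record pointer: [24-bit partition ID | 40-bit record address].
// Spark's PackedRecordPointer further splits the address bits into a page number and an
// in-page offset; that detail is omitted.
public class PackedPointerSketch {
  static final int MAX_PARTITION_ID = (1 << 24) - 1; // hence "numPartitions <= 2^24"

  static long pack(long recordAddress, int partitionId) {
    assert partitionId >= 0 && partitionId <= MAX_PARTITION_ID;
    return ((long) partitionId << 40) | (recordAddress & ((1L << 40) - 1));
  }

  static int partitionId(long packed) {
    return (int) (packed >>> 40);
  }

  public static void main(String[] args) {
    long packed = pack(123456L, 7);
    System.out.println(partitionId(packed)); // prints 7
  }
}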

High-level Flow

  1. Iterate over the records and insert them into the ShuffleExternalSorter. This does several things: each record is copied into memory while its record pointer is appended to the pointer array; when the spill threshold is reached, writeSortedFile() sorts the data (what gets sorted is the pointer array) and spills it to disk.

  2. sorter.closeAndGetSpills(): spill whatever remains in memory and return the metadata of all spill files (most importantly, the length of each partition).

  3. Merge all spill files into one big file (there are three merge strategies to choose from) and return the array of partition lengths.

  4. Generate the index file from the partition-length array.

  5. Wrap the result into a MapStatus and return it.

In short, each task produces one big file merged from its spill files (the file is only ordered by partition; data within a partition is unordered) plus one index file. The steps above are the main flow; housekeeping such as renaming temporary files is omitted.
[Figure: UnsafeShuffleWriter.jpg]

Source Code

write(…)

The write(...) method itself is fairly clean, since each step is wrapped in its own helper method:

public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
  boolean success = false;
  try {
    while (records.hasNext()) {
      // Step 1: insert each record into the sorter
      insertRecordIntoSorter(records.next());
    }
    // Remaining steps (2-5)
    closeAndWriteOutput();
    success = true;
  } finally {
    if (sorter != null) {
      try {
        sorter.cleanupResources();
      } catch (Exception e) {
        // Only throw this error if we won't be masking another error.
        if (success) {
          throw e;
        } else {
          logger.error("In addition to a failure during writing, we failed during " +
            "cleanup.", e);
        }
      }
    }
  }
}

insertRecordIntoSorter(…)

Its core is the call to sorter.insertRecord():

void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
  assert(sorter != null);
  final K key = record._1();
  final int partitionId = partitioner.getPartition(key);
  serBuffer.reset();
  // serOutputStream: serialization stream that writes the record into serBuffer
  serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
  serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
  serOutputStream.flush();

  final int serializedRecordSize = serBuffer.size();
  assert (serializedRecordSize > 0);

  // All the real work happens here
  sorter.insertRecord(
    serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
}

sorter.insertRecord(…)

Copies the record into memory while appending a record pointer to the pointer array; when the spill threshold is reached, a spill occurs:

public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
    throws IOException {

  // for tests
  assert(inMemSorter != null);
  // If the number of records in the sorter exceeds the threshold, spill
  if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
    logger.info("Spilling data because number of spilledRecords crossed the threshold " +
      numElementsForSpillThreshold);
    spill();
  }

  // Grow the pointer array inside inMemSorter if it is full
  growPointerArrayIfNecessary();
  final int uaoSize = UnsafeAlignedOffset.getUaoSize();
  // Need 4 or 8 bytes to store the record length.
  final int required = length + uaoSize;
  acquireNewPageIfNecessary(required);

  assert(currentPage != null);
  final Object base = currentPage.getBaseObject();
  final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
  UnsafeAlignedOffset.putSize(base, pageCursor, length);
  pageCursor += uaoSize;

  // Copy the record bytes into the current MemoryBlock (page)
  Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
  pageCursor += length;

  // Append the record pointer (address + partition ID) to inMemSorter's pointer array
  inMemSorter.insertRecord(recordAddress, partitionId);
}
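
To visualize the layout this produces, here is a small, self-contained sketch (plain Java, no Spark classes) of one data page holding length-prefixed serialized records plus a separate pointer array. The class name and the use of ByteBuffer are illustrative stand-ins for a Tungsten MemoryBlock:

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Illustrative only: a "page" of length-prefixed records plus the pointer array beside it.
public class PageSketch {
  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(1024); // stands in for a Tungsten MemoryBlock
    long[] pointers = new long[16];              // stands in for the in-memory sorter's array
    int numRecords = 0;

    String[] records = {"hello", "shuffle"};
    int[] partitions = {3, 1};
    for (int i = 0; i < records.length; i++) {
      byte[] bytes = records[i].getBytes(StandardCharsets.UTF_8);
      int offset = page.position();
      page.putInt(bytes.length);  // length prefix (uaoSize bytes in Spark)
      page.put(bytes);            // serialized record bytes
      // "pointer" = partition ID in the high 24 bits, offset in the low bits (cf. earlier sketch)
      pointers[numRecords++] = ((long) partitions[i] << 40) | offset;
    }
    System.out.println("records in page: " + numRecords);
  }
}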

spill()

The core of spill() is writeSortedFile(false):

public long spill(long size, MemoryConsumer trigger) throws IOException {
  if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
    return 0L;
  }

  logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
    Thread.currentThread().getId(),
    Utils.bytesToString(getMemoryUsage()),
    spills.size(),
    spills.size() > 1 ? " times" : " time");

  writeSortedFile(false);
  final long spillSize = freeMemory();
  // Spill done; reset inMemSorter
  inMemSorter.reset();
  // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
  // records. Otherwise, if the task is over allocated memory, then without freeing the memory
  // pages, we might not be able to get memory for the pointer array.
  taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
  return spillSize;
}

writeSortedFile()

Sorts the in-memory data (what gets sorted is the pointer array) and writes it to disk:

private void writeSortedFile(boolean isLastFile) {

  final ShuffleWriteMetrics writeMetricsToUse;

  if (isLastFile) {
    writeMetricsToUse = writeMetrics;
  } else {
    writeMetricsToUse = new ShuffleWriteMetrics();
  }

  // This call performs the actual sort.
  // Iterator over record pointers ordered by partition (the sort happens here; there are two
  // sorting strategies)
  final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
    inMemSorter.getSortedIterator();
  final byte[] writeBuffer = new byte[diskWriteBufferSize];

  // File name: temp_shuffle_ + UUID
  final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
    blockManager.diskBlockManager().createTempShuffleBlock();
  final File file = spilledFileInfo._2();
  final TempShuffleBlockId blockId = spilledFileInfo._1();
  final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);

  final SerializerInstance ser = DummySerializerInstance.INSTANCE;

  final DiskBlockObjectWriter writer =
    blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);

  int currentPartition = -1;
  final int uaoSize = UnsafeAlignedOffset.getUaoSize();
  while (sortedRecords.hasNext()) {
    sortedRecords.loadNext();
    final int partition = sortedRecords.packedRecordPointer.getPartitionId();
    assert (partition >= currentPartition);
    if (partition != currentPartition) {
      // Switch to the new partition
      if (currentPartition != -1) {
        final FileSegment fileSegment = writer.commitAndGet();
        // Record the length of each partition
        spillInfo.partitionLengths[currentPartition] = fileSegment.length();
      }
      currentPartition = partition;
    }

    final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
    final Object recordPage = taskMemoryManager.getPage(recordPointer);
    final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
    int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage);
    long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length
    while (dataRemaining > 0) {
      final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining);
      Platform.copyMemory(
        recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
      writer.write(writeBuffer, 0, toTransfer);
      recordReadPosition += toTransfer;
      dataRemaining -= toTransfer;
    }
    writer.recordWritten();
  }

  final FileSegment committedSegment = writer.commitAndGet();
  writer.close();
  if (currentPartition != -1) {
    spillInfo.partitionLengths[currentPartition] = committedSegment.length();
    spills.add(spillInfo);
  }

  if (!isLastFile) { // i.e. this is a spill file
    writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
    taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
  }
}
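
About the "two sorting strategies" noted above: ShuffleInMemorySorter reorders nothing but the array of packed pointers, using either a radix sort on the partition-ID prefix or a comparison sort (in Spark this is toggled by spark.shuffle.sort.useRadixSort). A minimal sketch of the comparison variant, reusing the illustrative packing from the earlier sketch; the record bytes in the pages are never moved:

import java.util.Arrays;
import java.util.Comparator;

// Illustrative only: sort packed pointers by the partition ID stored in their high 24 bits.
public class PointerSortSketch {
  public static void main(String[] args) {
    Long[] pointers = {
      (3L << 40) | 100L,  // partition 3
      (1L << 40) | 220L,  // partition 1
      (3L << 40) | 40L    // partition 3
    };
    Arrays.sort(pointers, Comparator.comparingLong(p -> p >>> 40));
    for (long p : pointers) {
      System.out.println("partition=" + (p >>> 40) + " offset=" + (p & ((1L << 40) - 1)));
    }
  }
}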

closeAndWriteOutput()

Step 1 (insertRecordIntoSorter) is finally done; closeAndWriteOutput() handles the remaining steps:

void closeAndWriteOutput() throws IOException {
  assert(sorter != null);
  updatePeakMemoryUsed();
  serBuffer = null;
  serOutputStream = null;
  // Step 2: spill whatever is still in memory and get the metadata of all spill files
  // (most importantly, the length of each partition)
  final SpillInfo[] spills = sorter.closeAndGetSpills();
  sorter = null;
  final long[] partitionLengths;
  // File name: "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
  final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
  final File tmp = Utils.tempFileWith(output);
  try {
    try {
      // Step 3: merge the spill files and return the partition lengths, used to build the index file
      partitionLengths = mergeSpills(spills, tmp);
    } finally {
      for (SpillInfo spill : spills) {
        if (spill.file.exists() && ! spill.file.delete()) {
          logger.error("Error while deleting spill file {}", spill.file.getPath());
        }
      }
    }
    // Step 4: generate the index file. This part is the same for every writer; see BypassMergeSortShuffleWriter
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
    }
  }
  // Step 5: wrap the result into a MapStatus and return it
  mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
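
For step 4, the index file is essentially the running sum of the partition-length array: to my understanding, the resolver writes numPartitions + 1 longs of cumulative byte offsets, so a reducer can locate partition i in the data file between offset[i] and offset[i + 1]. A minimal sketch of that idea (file name and helper here are illustrative, not Spark's actual IndexShuffleBlockResolver code):

import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;

// Illustrative only: turn partition lengths into the cumulative offsets an index file stores.
public class IndexFileSketch {
  static void writeIndex(String path, long[] partitionLengths) throws IOException {
    try (DataOutputStream out = new DataOutputStream(new FileOutputStream(path))) {
      long offset = 0L;
      out.writeLong(offset);                 // start offset of partition 0
      for (long length : partitionLengths) {
        offset += length;
        out.writeLong(offset);               // end of this partition / start of the next
      }
    }
  }

  public static void main(String[] args) throws IOException {
    writeIndex("shuffle_0_0_0.index.sketch", new long[] {120L, 0L, 340L});
  }
}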

mergeSpills(…)

Merges the spill files. There are three merge strategies:

  • Fast merge: compression is disabled, or the codec supports concatenation of serialized streams (Snappy, LZF, LZ4, ZStd):

    1. If NIO's transferTo is enabled and no encryption is needed, use mergeSpillsWithTransferTo(spills, outputFile)

    2. Otherwise use mergeSpillsWithFileStream(spills, outputFile, null)

  • Slow merge:

    1. Use mergeSpillsWithFileStream(spills, outputFile, compressionCodec)

private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
  // Whether compression is enabled, and which codec
  final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
  final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
  final boolean fastMergeEnabled =
    sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
  // Fast merge is supported when compression is disabled, or when the codec supports
  // concatenation of serialized streams (Snappy, LZF, LZ4, ZStd)
  final boolean fastMergeIsSupported = !compressionEnabled ||
    CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
  // Whether encryption is enabled
  final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
  try {
    if (spills.length == 0) {
      new FileOutputStream(outputFile).close(); // Create an empty file
      return new long[partitioner.numPartitions()];
    } else if (spills.length == 1) {
      Files.move(spills[0].file, outputFile);
      return spills[0].partitionLengths;
    } else {
      final long[] partitionLengths;
      if (fastMergeEnabled && fastMergeIsSupported) {
        // When NIO's transferTo is enabled and no encryption is needed, use the
        // transferTo-based fast merge
        if (transferToEnabled && !encryptionEnabled) {
          logger.debug("Using transferTo-based fast merge");
          partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
        } else {
          // Otherwise use the fileStream-based fast merge
          logger.debug("Using fileStream-based fast merge");
          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
        }
      } else {
        // Slow merge
        logger.debug("Using slow merge");
        partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
      }
      writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
      writeMetrics.incBytesWritten(outputFile.length());
      return partitionLengths;
    }
  } catch (IOException e) {
    if (outputFile.exists() && !outputFile.delete()) {
      logger.error("Unable to delete output file {}", outputFile.getPath());
    }
    throw e;
  }
}

mergeSpillsWithTransferTo(…)

Since the merge variants follow roughly the same flow, only one is shown here. Its core is a double loop that merges the data for the same partition across all spill files and lays the partitions out in partition order:

private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
  assert (spills.length >= 2);
  final int numPartitions = partitioner.numPartitions();
  final long[] partitionLengths = new long[numPartitions];
  final FileChannel[] spillInputChannels = new FileChannel[spills.length];
  final long[] spillInputChannelPositions = new long[spills.length];
  FileChannel mergedFileOutputChannel = null;

  boolean threwException = true;
  try {
    // Open an input channel for every spill file
    for (int i = 0; i < spills.length; i++) {
      spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
    }
    // Output channel for the merged file
    mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();

    long bytesWrittenToMergedFile = 0;
    // Double loop: merge the same partition across all spill files, laid out in partition order
    for (int partition = 0; partition < numPartitions; partition++) {
      for (int i = 0; i < spills.length; i++) {
        final long partitionLengthInSpill = spills[i].partitionLengths[partition];
        final FileChannel spillInputChannel = spillInputChannels[i];
        final long writeStartTime = System.nanoTime();
        // Copy this partition's data from the spill file into the merged file
        Utils.copyFileStreamNIO(
          spillInputChannel,              // input channel of the spill file
          mergedFileOutputChannel,        // output channel of the merged file
          spillInputChannelPositions[i],  // start position of this partition in the spill file
          partitionLengthInSpill);        // length of this partition in the spill file
        spillInputChannelPositions[i] += partitionLengthInSpill;
        writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
        bytesWrittenToMergedFile += partitionLengthInSpill;
        partitionLengths[partition] += partitionLengthInSpill; // total length of this partition across all spills
      }
    }
    if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
      throw new IOException(
        "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
        "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
        " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
        "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
        "to disable this NIO feature."
      );
    }
    threwException = false;
  } finally {
    for (int i = 0; i < spills.length; i++) {
      assert(spillInputChannelPositions[i] == spills[i].file.length());
      Closeables.close(spillInputChannels[i], threwException);
    }
    Closeables.close(mergedFileOutputChannel, threwException);
  }
  return partitionLengths;
}
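
Finally, tying back to the "Characteristics" list at the top: whether a shuffle may use UnsafeShuffleWriter at all is decided up front by SortShuffleManager's serialized-shuffle check (canUseSerializedShuffle), which requires a serializer that supports relocation, no map-side aggregation, and at most 2^24 output partitions. Below is a paraphrased sketch of that check; the real method lives in Scala, and the class and parameter names here are illustrative, not the actual Spark API:

// Paraphrased sketch of the eligibility check, not the actual Spark code.
public class SerializedShuffleCheckSketch {
  static final int MAX_PARTITIONS = 1 << 24; // matches the 2^24 limit discussed above

  static boolean canUseSerializedShuffle(boolean serializerSupportsRelocation,
                                         boolean hasMapSideAggregation,
                                         int numPartitions) {
    return serializerSupportsRelocation
        && !hasMapSideAggregation
        && numPartitions <= MAX_PARTITIONS;
  }

  public static void main(String[] args) {
    System.out.println(canUseSerializedShuffle(true, false, 200)); // true
    System.out.println(canUseSerializedShuffle(true, true, 200));  // false: map-side aggregation needed
  }
}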