Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions native-engine/datafusion-ext-plans/src/common/ipc_compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ use std::io::{BufReader, Read, Take, Write};

use arrow::{array::ArrayRef, datatypes::SchemaRef};
use blaze_jni_bridge::{conf, conf::StringConf, is_jni_bridge_inited};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use datafusion::common::Result;
use datafusion_ext_commons::{
df_execution_err,
io::{read_len, read_one_batch, write_len, write_one_batch},
io::{read_one_batch, write_one_batch},
};
use once_cell::sync::OnceCell;

Expand All @@ -40,6 +41,8 @@ unsafe impl<W: Write> Send for IpcCompressionWriter<W> {}
impl<W: Write> IpcCompressionWriter<W> {
pub fn new(output: W) -> Self {
let mut shared_buf = VecBuffer::default();
shared_buf.inner_mut().extend_from_slice(&[0u8; 4]);

let block_writer = IoCompressionWriter::new_with_configured_codec(shared_buf.writer());
Self {
output,
Expand All @@ -49,8 +52,18 @@ impl<W: Write> IpcCompressionWriter<W> {
}
}

    /// Replaces the underlying output sink with `output`.
    ///
    /// NOTE(review): the previous doc comment here ("Write a batch, returning
    /// uncompressed bytes size") described `write_batch` and was left attached
    /// to this method when `set_output` was inserted above it; rewritten to
    /// describe this method.
    ///
    /// # Panics
    /// Panics if the writer still holds buffered, unflushed block data
    /// (`block_empty` is false): switching sinks mid-block would split one
    /// compressed block across two outputs.
    pub fn set_output(&mut self, output: W) {
        assert!(
            self.block_empty,
            "IpcCompressionWriter must be empty while changing output"
        );
        self.output = output;
    }

pub fn write_batch(&mut self, num_rows: usize, cols: &[ArrayRef]) -> Result<()> {
if num_rows == 0 {
return Ok(());
}
write_one_batch(num_rows, cols, &mut self.block_writer)?;
self.block_empty = false;

Expand All @@ -67,11 +80,15 @@ impl<W: Write> IpcCompressionWriter<W> {
self.block_writer.finish()?;

// write
write_len(self.shared_buf.inner().len(), &mut self.output)?;
let block_len = self.shared_buf.inner().len() - 4;
self.shared_buf.inner_mut()[0..4]
.as_mut()
.write_u32::<LittleEndian>(block_len as u32)?;
self.output.write_all(self.shared_buf.inner())?;
self.shared_buf.inner_mut().clear();

// open next buf
self.shared_buf.inner_mut().clear();
self.shared_buf.inner_mut().extend_from_slice(&[0u8; 4]);
self.block_writer =
IoCompressionWriter::new_with_configured_codec(self.shared_buf.writer());
self.block_empty = true;
Expand Down Expand Up @@ -114,7 +131,7 @@ impl<R: Read> IpcCompressionReader<R> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match std::mem::take(&mut self.0.input) {
InputState::BlockStart(mut input) => {
let block_len = match read_len(&mut input) {
let block_len = match input.read_u32::<LittleEndian>() {
Ok(block_len) => block_len,
Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
return Ok(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ impl BufferedData {
return Ok(());
}
let mut iter = self.into_sorted_batches(partitioning)?;
let mut writer = IpcCompressionWriter::new(RssWriter::new(rss_partition_writer.clone(), 0));

while (iter.cur_part_id() as usize) < partitioning.partition_count() {
let cur_part_id = iter.cur_part_id();
let mut writer = IpcCompressionWriter::new(RssWriter::new(
writer.set_output(RssWriter::new(
rss_partition_writer.clone(),
cur_part_id as usize,
));
Expand Down
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.0.3</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
</properties>
</profile>

Expand All @@ -291,6 +292,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.1.3</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
</properties>
</profile>

Expand All @@ -305,6 +307,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.2.4</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
</properties>
</profile>

Expand All @@ -319,6 +322,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.3.4</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
</properties>
</profile>

Expand All @@ -333,6 +337,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.4.3</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
</properties>
</profile>

Expand All @@ -347,6 +352,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.5.3</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
</properties>
</profile>
</profiles>
Expand Down
6 changes: 6 additions & 0 deletions spark-extension-shims-spark3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
<artifactId>spark-sql_${scalaVersion}</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.celeborn</groupId>
<artifactId>celeborn-client-spark-3-shaded_${scalaVersion}</artifactId>
<version>${celebornVersion}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-c-data</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ import org.apache.spark.sql.execution.blaze.plan.NativeUnionExec
import org.apache.spark.sql.execution.blaze.plan.NativeWindowBase
import org.apache.spark.sql.execution.blaze.plan.NativeWindowExec
import org.apache.spark.sql.execution.blaze.plan._
import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReader
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
Expand Down Expand Up @@ -427,6 +428,20 @@ class ShimsImpl extends Shims with Logging {
override def getShuffleWriteExec(
input: pb.PhysicalPlanNode,
nativeOutputPartitioning: pb.PhysicalHashRepartition.Builder): pb.PhysicalPlanNode = {

if (SparkEnv.get.shuffleManager.isInstanceOf[BlazeCelebornShuffleManager]) {
return pb.PhysicalPlanNode
.newBuilder()
.setRssShuffleWriter(
pb.RssShuffleWriterExecNode
.newBuilder()
.setInput(input)
.setOutputPartitioning(nativeOutputPartitioning)
.buildPartial()
) // shuffleId is not set at the moment, will be set in ShuffleWriteProcessor
.build()
}

pb.PhysicalPlanNode
.newBuilder()
.setShuffleWriter(
Expand Down Expand Up @@ -604,7 +619,7 @@ class ShimsImpl extends Shims with Logging {
JniBridge.resourcesMap.put(
jniResourceId,
() => {
reader.asInstanceOf[BlazeBlockStoreShuffleReader[_, _]].readIpc()
reader.asInstanceOf[BlazeBlockStoreShuffleReaderBase[_, _]].readIpc()
})

pb.PhysicalPlanNode
Expand Down Expand Up @@ -762,7 +777,7 @@ class ShimsImpl extends Shims with Logging {
JniBridge.resourcesMap.put(
jniResourceId,
() => {
reader.asInstanceOf[BlazeBlockStoreShuffleReader[_, _]].readIpc()
reader.asInstanceOf[BlazeBlockStoreShuffleReaderBase[_, _]].readIpc()
})

pb.PhysicalPlanNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleWriter
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleWriter
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
Expand Down Expand Up @@ -133,6 +135,19 @@ case class NativeShuffleExchangeExec(
mapId,
context,
createMetricsReporter(context))

if (SparkEnv.get.shuffleManager.isInstanceOf[BlazeCelebornShuffleManager]) {
return writer
.asInstanceOf[BlazeCelebornShuffleWriter[_, _]]
.nativeRssShuffleWrite(
rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[NativeRDD],
dep,
mapId.toInt,
context,
partition,
numPartitions)
}

writer
.asInstanceOf[BlazeShuffleWriter[_, _]]
.nativeShuffleWrite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArr

import com.thoughtworks.enableIf

abstract class BlazeRssShuffleManagerBase(conf: SparkConf) extends ShuffleManager with Logging {
abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManager with Logging {
override def registerShuffle[K, V, C](
shuffleId: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright 2022 The Blaze Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.blaze.shuffle.celeborn

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.commons.lang3.reflect.FieldUtils
import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleManagerBase
import org.apache.spark.SparkConf
import org.apache.spark.TaskContext
import org.apache.spark.shuffle.ShuffleBlockResolver
import org.apache.spark.shuffle.ShuffleReader
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase
import org.apache.spark.ShuffleDependency
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.shuffle.celeborn.SparkShuffleManager
import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleReaderBase
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
import org.apache.spark.shuffle.celeborn.ExecutorShuffleIdTracker

/**
 * Blaze RSS shuffle manager backed by Apache Celeborn.
 *
 * Wraps Celeborn's `SparkShuffleManager` and delegates shuffle registration,
 * plain (non-native) readers/writers, block resolution and shutdown to it,
 * while constructing Blaze-native readers/writers for the native execution
 * path.
 *
 * NOTE(review): private fields of the wrapped manager ("celebornConf",
 * "shuffleClient", "shuffleIdTracker") are read via reflection below; this is
 * brittle across Celeborn releases — re-verify whenever `celebornVersion` is
 * bumped.
 */
class BlazeCelebornShuffleManager(conf: SparkConf, isDriver: Boolean)
    extends BlazeRssShuffleManagerBase(conf) {

  // The delegate Celeborn manager; owns the shuffle client and the
  // executor-side shuffle-id bookkeeping that the Blaze readers reuse.
  private val celebornShuffleManager: SparkShuffleManager =
    new SparkShuffleManager(conf, isDriver)

  // Registration/unregistration are pure pass-throughs to Celeborn.
  override def registerShuffle[K, V, C](
      shuffleId: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    celebornShuffleManager.registerShuffle(shuffleId, dependency)
  }

  override def unregisterShuffle(shuffleId: Int): Boolean = {
    celebornShuffleManager.unregisterShuffle(shuffleId)
  }

  /**
   * Native reader over all map outputs: delegates to the ranged overload with
   * the full map-index range [0, Int.MaxValue).
   */
  override def getBlazeRssShuffleReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): BlazeRssShuffleReaderBase[K, C] = {
    this.getBlazeRssShuffleReader(
      handle,
      0,
      Int.MaxValue,
      startPartition,
      endPartition,
      context,
      metrics)
  }

  /**
   * Builds a Blaze-native Celeborn reader for the given map-index and
   * partition ranges.
   *
   * The `CelebornConf` and `ExecutorShuffleIdTracker` needed by the reader are
   * private to `SparkShuffleManager`, so they are extracted via reflection.
   */
  override def getBlazeRssShuffleReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): BlazeRssShuffleReaderBase[K, C] = {

    val celebornHandle = handle.asInstanceOf[CelebornShuffleHandle[_, _, _]]
    // Reflective access to private delegate state — see class-level note.
    val celebornConf = FieldUtils
      .readField(celebornShuffleManager, "celebornConf", true)
      .asInstanceOf[CelebornConf]
    val shuffleIdTracker = FieldUtils
      .readField(celebornShuffleManager, "shuffleIdTracker", true)
      .asInstanceOf[ExecutorShuffleIdTracker]
    val reader = new BlazeCelebornShuffleReader(
      celebornConf,
      celebornHandle,
      startPartition,
      endPartition,
      startMapIndex = Some(startMapIndex),
      endMapIndex = Some(endMapIndex),
      context,
      metrics,
      shuffleIdTracker)
    reader.asInstanceOf[BlazeRssShuffleReaderBase[K, C]]
  }

  // Non-native readers go straight to Celeborn.
  override def getRssShuffleReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    celebornShuffleManager.getReader(handle, startPartition, endPartition, context, metrics)
  }

  override def getRssShuffleReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    celebornShuffleManager.getReaderForRange(
      handle,
      startMapIndex,
      endMapIndex,
      startPartition,
      endPartition,
      context,
      metrics)
  }

  /**
   * Builds a Blaze-native Celeborn writer.
   *
   * The delegate's `getWriter` is invoked first purely for its side effect of
   * initializing the private `shuffleClient` (its returned writer is
   * discarded); the client is then extracted via reflection and handed to the
   * native writer.
   */
  override def getBlazeRssShuffleWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): BlazeRssShuffleWriterBase[K, V] = {

    // ensure celeborn client is initialized
    assert(celebornShuffleManager.getWriter(handle, mapId, context, metrics) != null)
    val shuffleClient = FieldUtils
      .readField(celebornShuffleManager, "shuffleClient", true)
      .asInstanceOf[ShuffleClient]

    val celebornHandle = handle.asInstanceOf[CelebornShuffleHandle[_, _, _]]
    val writer = new BlazeCelebornShuffleWriter(shuffleClient, context, celebornHandle, metrics)
    writer.asInstanceOf[BlazeRssShuffleWriterBase[K, V]]
  }

  // Non-native writer: pure pass-through to Celeborn.
  override def getRssShuffleWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
    celebornShuffleManager.getWriter(handle, mapId, context, metrics)
  }

  override def shuffleBlockResolver: ShuffleBlockResolver =
    celebornShuffleManager.shuffleBlockResolver()

  override def stop(): Unit =
    celebornShuffleManager.stop()
}

object BlazeCelebornShuffleManager {

  /**
   * Packs the stage attempt number (high 16 bits) and the task attempt number
   * (low 16 bits) into a single Int.
   *
   * NOTE(review): assumes both attempt numbers fit in 16 bits — a task
   * attempt number >= 2^16 would bleed into the stage bits; confirm against
   * Celeborn's expected attempt-id encoding.
   */
  def getEncodedAttemptNumber(context: TaskContext): Int = {
    val stageBits = context.stageAttemptNumber << 16
    val taskBits = context.attemptNumber
    stageBits | taskBits
  }
}
Loading