Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions native-engine/blaze-jni-bridge/src/jni_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,8 @@ pub struct BlazeBlockObject<'a> {
pub method_getByteBuffer_ret: ReturnType,
pub method_getChannel: JMethodID,
pub method_getChannel_ret: ReturnType,
pub method_throwFetchFailed: JMethodID,
pub method_throwFetchFailed_ret: ReturnType,
}

impl<'a> BlazeBlockObject<'a> {
Expand Down Expand Up @@ -1542,6 +1544,12 @@ impl<'a> BlazeBlockObject<'a> {
"()Ljava/nio/channels/ReadableByteChannel;",
)?,
method_getChannel_ret: ReturnType::Object,
method_throwFetchFailed: env.get_method_id(
class,
"throwFetchFailed",
"(Ljava/lang/String;)V",
)?,
method_throwFetchFailed_ret: ReturnType::Primitive(Primitive::Void),
})
}
}
Expand Down
15 changes: 13 additions & 2 deletions native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ fn read_ipc(
.expect("tokio spawn_blocking error")?
} {
// get ipc reader
let mut reader = tokio::task::spawn_blocking(move || {
let block_cloned = block.clone();
let mut reader = tokio::task::spawn_blocking(|| {
let block = block_cloned;
if jni_call!(BlazeBlockObject(block.as_obj()).hasFileSegment() -> bool)? {
return get_file_reader(block.as_obj());
}
Expand All @@ -201,7 +203,16 @@ fn read_ipc(
.await
.expect("tokio spawn_blocking error")?;

while let Some((num_rows, cols)) = reader.read_batch(&exec_ctx.output_schema())? {
while let Some((num_rows, cols)) =
reader.read_batch(&exec_ctx.output_schema()).or_else(|e| {
// throw FetchFailureException
let block = block.clone();
let errmsg = jni_new_string!(e.message().as_ref())?;
jni_call!(BlazeBlockObject(block.as_obj())
.throwFetchFailed(errmsg.as_obj()) -> ())?; // always return error
Ok::<_, DataFusionError>(None)
})?
{
let (cur_staging_num_rows, cur_staging_mem_size) = {
let staging_cols_cloned = staging_cols.clone();
let mut staging_cols = staging_cols_cloned.lock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
*/
package org.apache.spark.sql.execution.blaze.shuffle

import java.io.{FileInputStream, InputStream}
import java.io.{FileInputStream, IOException, InputStream}
import java.nio.ByteBuffer
import java.nio.channels.{Channels, ReadableByteChannel}
import scala.annotation.tailrec

import org.apache.commons.lang3.reflect.{FieldUtils, MethodUtils}
import org.apache.spark.{InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -59,6 +57,9 @@ object BlazeBlockStoreShuffleReaderBase extends Logging {
override def getFileOffset: Long = offset
override def getFileLength: Long = limit
override def close(): Unit = in.close()
override def throwFetchFailed(errmsg: String): Unit = {
throwFetchFailedOnInputStream(in, errmsg)
}
}
case None =>
}
Expand All @@ -69,6 +70,9 @@ object BlazeBlockStoreShuffleReaderBase extends Logging {
override def hasByteBuffer: Boolean = true
override def getByteBuffer: ByteBuffer = buf
override def close(): Unit = in.close()
override def throwFetchFailed(errmsg: String): Unit = {
throwFetchFailedOnInputStream(in, errmsg)
}
}
case None =>
}
Expand All @@ -77,18 +81,55 @@ object BlazeBlockStoreShuffleReaderBase extends Logging {
new BlockObject {
override def getChannel: ReadableByteChannel = channel
override def close(): Unit = channel.close()
override def throwFetchFailed(errmsg: String): Unit = {
throwFetchFailedOnInputStream(in, errmsg)
}
}
}

@tailrec
private def unwrapInputStream(in: InputStream): InputStream = {
val bufferReleasingInputStreamClsName = "org.apache.spark.storage.BufferReleasingInputStream"
in match {
case in if bufferReleasingInputStreamClsName.endsWith(in.getClass.getName) =>
val inner = MethodUtils.invokeMethod(in, true, "delegate").asInstanceOf[InputStream]
unwrapInputStream(inner)
case in => in
val bufferReleasingInputStreamCls =
Class.forName("org.apache.spark.storage.BufferReleasingInputStream")
if (in.getClass != bufferReleasingInputStreamCls) {
return in
}

try {
return MethodUtils.invokeMethod(in, true, "delegate").asInstanceOf[InputStream]
} catch {
case _: ReflectiveOperationException => // passthrough
}

try {
val fallbackMethodName = "org$apache$spark$storage$BufferReleasingInputStream$$delegate"
return MethodUtils.invokeMethod(in, true, fallbackMethodName).asInstanceOf[InputStream]
} catch {
case _: ReflectiveOperationException => // passthrough
}
throw new RuntimeException("cannot unwrap BufferReleasingInputStream")
Comment on lines +94 to +109
Copy link

Copilot AI Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider removing explicit 'return' statements in unwrapInputStream to adopt a more idiomatic Scala style and improve code clarity.

Suggested change
return in
}
try {
return MethodUtils.invokeMethod(in, true, "delegate").asInstanceOf[InputStream]
} catch {
case _: ReflectiveOperationException => // passthrough
}
try {
val fallbackMethodName = "org$apache$spark$storage$BufferReleasingInputStream$$delegate"
return MethodUtils.invokeMethod(in, true, fallbackMethodName).asInstanceOf[InputStream]
} catch {
case _: ReflectiveOperationException => // passthrough
}
throw new RuntimeException("cannot unwrap BufferReleasingInputStream")
in
} else {
try {
MethodUtils.invokeMethod(in, true, "delegate").asInstanceOf[InputStream]
} catch {
case _: ReflectiveOperationException =>
try {
val fallbackMethodName = "org$apache$spark$storage$BufferReleasingInputStream$$delegate"
MethodUtils.invokeMethod(in, true, fallbackMethodName).asInstanceOf[InputStream]
} catch {
case _: ReflectiveOperationException =>
throw new RuntimeException("cannot unwrap BufferReleasingInputStream")
}
}
}

Copilot uses AI. Check for mistakes.
}

/**
 * Converts a shuffle-fetch error into a Spark FetchFailedException by reflectively
 * invoking the version-specific wrapper on the shuffle input stream, so that Spark's
 * scheduler recognizes the failure and re-schedules the map stage instead of failing
 * the task outright.
 *
 * @param in     the (unwrapped) shuffle input stream; expected to be a
 *               BufferReleasingInputStream on Spark 3.x
 * @param errmsg error message describing the fetch failure
 * @throws Throwable the FetchFailedException produced by the wrapper when available,
 *                   otherwise a plain IOException carrying errmsg
 */
private def throwFetchFailedOnInputStream(in: InputStream, errmsg: String): Unit = {
  // The reflective target is EXPECTED to throw (it wraps our IOException into a
  // FetchFailedException). MethodUtils surfaces that as an InvocationTargetException,
  // which must be unwrapped and propagated — it must NOT be swallowed as a generic
  // ReflectiveOperationException, or the FetchFailedException is lost and Spark
  // never retries the stage.
  def rethrowCause(e: java.lang.reflect.InvocationTargetException): Nothing = {
    e.getCause match {
      case null => throw e
      case cause => throw cause // typically FetchFailedException
    }
  }

  // for spark 3.x: BufferReleasingInputStream.tryOrFetchFailedException(block)
  try {
    val throwFunction: () => Object = () => throw new IOException(errmsg)
    MethodUtils.invokeMethod(in, true, "tryOrFetchFailedException", throwFunction)
    return
  } catch {
    // order matters: InvocationTargetException extends ReflectiveOperationException
    case e: java.lang.reflect.InvocationTargetException => rethrowCause(e)
    case _: ReflectiveOperationException => // method absent in this spark version
  }

  // for spark 2.x
  try {
    // NOTE(review): message is hardcoded — presumably spark 2.x matches on
    // "Stream is corrupted" to classify the failure; confirm before changing
    val throwFunction: () => Object = () => throw new IOException("Stream is corrupted")
    MethodUtils.invokeMethod(in, true, "wrapFetchFailedError", throwFunction)
    return
  } catch {
    case e: java.lang.reflect.InvocationTargetException => rethrowCause(e)
    case _: ReflectiveOperationException => // method absent in this spark version
  }

  // fallback: no wrapper available — raise the raw IOException
  throw new IOException(errmsg)
}

def getFileSegmentFromInputStream(in: InputStream): Option[(String, Long, Long)] = {
Expand Down Expand Up @@ -137,4 +178,5 @@ trait BlockObject extends AutoCloseable {
def getFileLength: Long = throw new UnsupportedOperationException
def getByteBuffer: ByteBuffer = throw new UnsupportedOperationException
def getChannel: ReadableByteChannel = throw new UnsupportedOperationException
def throwFetchFailed(errmsg: String): Unit = throw new UnsupportedOperationException
}