Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions auron-core/src/main/java/org/apache/auron/jni/AuronAdaptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.auron.jni;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.auron.configuration.AuronConfiguration;
Expand Down Expand Up @@ -84,29 +83,21 @@ public boolean isTaskRunning() {
* @return Absolute path of the created temporary file.
* @throws IOException If the temporary file cannot be created.
*/
public String getDirectWriteSpillToDiskFile() throws IOException {
File tempFile = File.createTempFile("auron-spill-", ".tmp");
tempFile.deleteOnExit();
return tempFile.getAbsolutePath();
}
public abstract String getDirectWriteSpillToDiskFile() throws IOException;

/**
* Retrieves the context classloader of the current thread.
*
* @return The context classloader of the current thread.
* @return For Spark, return TaskContext of the current thread.
*/
public Object getThreadContext() {
return Thread.currentThread().getContextClassLoader();
}
public abstract Object getThreadContext();

/**
* Sets the context classloader for the current thread.
* Sets the context for the current thread.
*
* @param context The classloader to set as the context classloader.
* @param context For spark is TaskContext.
*/
public void setThreadContext(Object context) {
Thread.currentThread().setContextClassLoader((ClassLoader) context);
}
public abstract void setThreadContext(Object context);

/**
* Retrieves the on-heap spill manager implementation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public AuronCallNativeWrapper(
* @throws RuntimeException If the native runtime encounters an error during batch processing.
*/
public boolean loadNextBatch(Consumer<VectorSchemaRoot> batchConsumer) {
checkError();
// load next batch
try {
this.batchConsumer = batchConsumer;
Expand All @@ -132,6 +133,10 @@ protected void importSchema(long ffiSchemaPtr) {
}
}

public Schema getArrowSchema() {
return arrowSchema;
}

protected void importBatch(long ffiArrayPtr) {
if (nativeRuntimePtr == 0) {
throw new RuntimeException("Native runtime is finalized");
Expand Down Expand Up @@ -172,15 +177,11 @@ protected byte[] getRawTaskDefinition() {
return taskDefinition.toByteArray();
}

private synchronized void close() {
public synchronized void close() {
if (nativeRuntimePtr != 0) {
JniBridge.finalizeNative(nativeRuntimePtr);
nativeRuntimePtr = 0;
try {
dictionaryProvider.close();
} catch (Exception e) {
LOG.error("Error closing dictionary provider", e);
}
dictionaryProvider.close();
checkError();
}
}
Expand Down
6 changes: 5 additions & 1 deletion auron-core/src/main/java/org/apache/auron/jni/JniBridge.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
*/
@SuppressWarnings("unused")
public class JniBridge {
public static final ConcurrentHashMap<String, Object> resourcesMap = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, Object> resourcesMap = new ConcurrentHashMap<>();

private static final List<BufferPoolMXBean> directMXBeans =
ManagementFactory.getPlatformMXBeans(BufferPoolMXBean.class);
Expand All @@ -60,6 +60,10 @@ public static Object getResource(String key) {
return resourcesMap.remove(key);
}

public static void putResource(String key, Object value) {
resourcesMap.put(key, value);
}

public static FSDataInputWrapper openFileAsDataInputWrapper(FileSystem fs, String path) throws Exception {
// the path is a URI string, so we need to convert it to a URI object
return FSDataInputWrapper.wrap(fs.open(new Path(new URI(path))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.auron.jni;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.auron.configuration.AuronConfiguration;
import org.apache.auron.configuration.MockAuronConfiguration;
Expand All @@ -31,6 +33,23 @@ public void loadAuronLib() {
// Mock implementation, no need to load auron library
}

@Override
public String getDirectWriteSpillToDiskFile() throws IOException {
File tempFile = File.createTempFile("auron-spill-", ".tmp");
tempFile.deleteOnExit();
return tempFile.getAbsolutePath();
}

@Override
public Object getThreadContext() {
return null;
}

@Override
public void setThreadContext(Object context) {
// Mock implementation, no need to set thread context
}

@Override
public AuronConfiguration getAuronConfiguration() {
return new MockAuronConfiguration();
Expand Down
32 changes: 16 additions & 16 deletions native-engine/auron-jni-bridge/src/jni_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,10 @@ pub struct JniBridge<'a> {
pub method_setContextClassLoader_ret: ReturnType,
pub method_getResource: JStaticMethodID,
pub method_getResource_ret: ReturnType,
pub method_getTaskContext: JStaticMethodID,
pub method_getTaskContext_ret: ReturnType,
pub method_getThreadContext: JStaticMethodID,
pub method_getThreadContext_ret: ReturnType,
pub method_setThreadContext: JStaticMethodID,
pub method_setThreadContext_ret: ReturnType,
pub method_getTaskOnHeapSpillManager: JStaticMethodID,
pub method_getTaskOnHeapSpillManager_ret: ReturnType,
pub method_isTaskRunning: JStaticMethodID,
Expand All @@ -571,14 +573,12 @@ pub struct JniBridge<'a> {
pub method_getTotalMemoryLimited_ret: ReturnType,
pub method_getDirectWriteSpillToDiskFile: JStaticMethodID,
pub method_getDirectWriteSpillToDiskFile_ret: ReturnType,
pub method_initNativeThread: JStaticMethodID,
pub method_initNativeThread_ret: ReturnType,

pub method_getAuronUDFWrapperContext: JStaticMethodID,
pub method_getAuronUDFWrapperContext_ret: ReturnType,
}
impl<'a> JniBridge<'a> {
pub const SIG_TYPE: &'static str = "org/apache/spark/sql/auron/JniBridge";
pub const SIG_TYPE: &'static str = "org/apache/auron/jni/JniBridge";

pub fn new(env: &JNIEnv<'a>) -> JniResult<JniBridge<'a>> {
let class = get_global_jclass(env, Self::SIG_TYPE)?;
Expand All @@ -602,12 +602,18 @@ impl<'a> JniBridge<'a> {
"(Ljava/lang/String;)Ljava/lang/Object;",
)?,
method_getResource_ret: ReturnType::Object,
method_getTaskContext: env.get_static_method_id(
method_getThreadContext: env.get_static_method_id(
class,
"getTaskContext",
"()Lorg/apache/spark/TaskContext;",
"getThreadContext",
"()Ljava/lang/Object;",
)?,
method_getTaskContext_ret: ReturnType::Object,
method_getThreadContext_ret: ReturnType::Object,
method_setThreadContext: env.get_static_method_id(
class,
"setThreadContext",
"(Ljava/lang/Object;)V",
)?,
method_setThreadContext_ret: ReturnType::Primitive(Primitive::Void),
method_getTaskOnHeapSpillManager: env.get_static_method_id(
class,
"getTaskOnHeapSpillManager",
Expand Down Expand Up @@ -646,12 +652,6 @@ impl<'a> JniBridge<'a> {
"()Ljava/lang/String;",
)?,
method_getDirectWriteSpillToDiskFile_ret: ReturnType::Object,
method_initNativeThread: env.get_static_method_id(
class,
"initNativeThread",
"(Ljava/lang/ClassLoader;Lorg/apache/spark/TaskContext;)V",
)?,
method_initNativeThread_ret: ReturnType::Primitive(Primitive::Void),

method_getAuronUDFWrapperContext: env.get_static_method_id(
class,
Expand Down Expand Up @@ -1408,7 +1408,7 @@ pub struct AuronCallNativeWrapper<'a> {
pub method_setError_ret: ReturnType,
}
impl<'a> AuronCallNativeWrapper<'a> {
pub const SIG_TYPE: &'static str = "org/apache/spark/sql/auron/AuronCallNativeWrapper";
pub const SIG_TYPE: &'static str = "org/apache/auron/jni/AuronCallNativeWrapper";

pub fn new(env: &JNIEnv<'a>) -> JniResult<AuronCallNativeWrapper<'a>> {
let class = get_global_jclass(env, Self::SIG_TYPE)?;
Expand Down
8 changes: 4 additions & 4 deletions native-engine/auron/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{handle_unwinded_scope, logging::init_logging, rt::NativeExecutionRun

#[allow(non_snake_case)]
#[unsafe(no_mangle)]
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_callNative(
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_callNative(
env: JNIEnv,
_: JClass,
executor_memory_overhead: i64,
Expand Down Expand Up @@ -113,7 +113,7 @@ pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_callNative(

#[allow(non_snake_case)]
#[unsafe(no_mangle)]
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_nextBatch(
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_nextBatch(
_: JNIEnv,
_: JClass,
raw_ptr: i64,
Expand All @@ -124,7 +124,7 @@ pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_nextBatch(

#[allow(non_snake_case)]
#[unsafe(no_mangle)]
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_finalizeNative(
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_finalizeNative(
_: JNIEnv,
_: JClass,
raw_ptr: i64,
Expand All @@ -135,7 +135,7 @@ pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_finalizeNative(

#[allow(non_snake_case)]
#[unsafe(no_mangle)]
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_onExit(_: JNIEnv, _: JClass) {
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(_: JNIEnv, _: JClass) {
log::info!("exiting native environment");
if MemManager::initialized() {
MemManager::get().dump_status();
Expand Down
19 changes: 11 additions & 8 deletions native-engine/auron/src/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ use arrow::{
};
use auron_jni_bridge::{
conf::{IntConf, SPARK_TASK_CPUS, TOKIO_WORKER_THREADS_PER_CPU},
is_task_running,
jni_bridge::JavaClasses,
jni_call, jni_call_static, jni_convert_byte_array, jni_exception_check, jni_exception_occurred,
jni_new_global_ref, jni_new_object, jni_new_string,
is_task_running, jni_call, jni_call_static, jni_convert_byte_array, jni_exception_check,
jni_exception_occurred, jni_new_global_ref, jni_new_object, jni_new_string,
};
use auron_serde::protobuf::TaskDefinition;
use datafusion::{
Expand Down Expand Up @@ -105,17 +103,22 @@ impl NativeExecutionRuntime {

// create tokio runtime
// propagate classloader and task context to spawned children threads
let spark_task_context = jni_call_static!(JniBridge.getTaskContext() -> JObject)?;
let spark_task_context_global = jni_new_global_ref!(spark_task_context.as_obj())?;
let thread_context = jni_call_static!(JniBridge.getThreadContext() -> JObject)?;
let thread_context_global = jni_new_global_ref!(thread_context.as_obj())?;
// classloader
let classloader = jni_call_static!(JniBridge.getContextClassLoader() -> JObject)?;
let classloader_global = jni_new_global_ref!(classloader.as_obj())?;
let mut tokio_runtime_builder = tokio::runtime::Builder::new_multi_thread();
tokio_runtime_builder
.thread_name(format!(
"auron-native-stage-{stage_id}-part-{partition_id}-tid-{tid}"
))
.on_thread_start(move || {
let classloader = JavaClasses::get().classloader;
let _ = jni_call_static!(
JniBridge.initNativeThread(classloader,spark_task_context_global.as_obj()) -> ()
JniBridge.setContextClassLoader(classloader_global.as_obj()) -> ()
);
let _ = jni_call_static!(
JniBridge.setThreadContext(thread_context_global.as_obj()) -> ()
);
THREAD_STAGE_ID.set(stage_id);
THREAD_PARTITION_ID.set(partition_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ class ShimsImpl extends Shims with Logging {

// store fetch iterator in jni resource before native compute
val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
JniBridge.resourcesMap.put(
org.apache.auron.jni.JniBridge.putResource(
jniResourceId,
() => {
reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc()
Expand Down Expand Up @@ -778,7 +778,7 @@ class ShimsImpl extends Shims with Logging {

// store fetch iterator in jni resource before native compute
val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
JniBridge.resourcesMap.put(
org.apache.auron.jni.JniBridge.putResource(
jniResourceId,
() => {
reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc()
Expand Down Expand Up @@ -871,7 +871,7 @@ class ShimsImpl extends Shims with Logging {

// store fetch iterator in jni resource before native compute
val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
JniBridge.resourcesMap.put(
org.apache.auron.jni.JniBridge.putResource(
jniResourceId,
() => {
reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc()
Expand Down
Loading
Loading