Skip to content

Commit 499632c

Browse files
Tartarus0zmguihuawen
authored andcommitted
[AURON apache#1498] Enable the JniBridge in the auron-core module (apache#1499)
* [AURON apache#1498] Enable the JniBridge in the auron-core module * fix * fix checkstyle * fix * fix * fix * fix * fix compile * opts NativeHelper * opts
1 parent 00f4e18 commit 499632c

18 files changed

Lines changed: 262 additions & 71 deletions

File tree

auron-core/src/main/java/org/apache/auron/jni/AuronAdaptor.java

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
*/
1717
package org.apache.auron.jni;
1818

19-
import java.io.File;
2019
import java.io.IOException;
2120
import java.nio.ByteBuffer;
2221
import org.apache.auron.configuration.AuronConfiguration;
@@ -84,29 +83,21 @@ public boolean isTaskRunning() {
8483
* @return Absolute path of the created temporary file.
8584
* @throws IOException If the temporary file cannot be created.
8685
*/
87-
public String getDirectWriteSpillToDiskFile() throws IOException {
88-
File tempFile = File.createTempFile("auron-spill-", ".tmp");
89-
tempFile.deleteOnExit();
90-
return tempFile.getAbsolutePath();
91-
}
86+
public abstract String getDirectWriteSpillToDiskFile() throws IOException;
9287

9388
/**
9489
* Retrieves the context classloader of the current thread.
9590
*
96-
* @return The context classloader of the current thread.
91+
* @return For Spark, return TaskContext of the current thread.
9792
*/
98-
public Object getThreadContext() {
99-
return Thread.currentThread().getContextClassLoader();
100-
}
93+
public abstract Object getThreadContext();
10194

10295
/**
103-
* Sets the context classloader for the current thread.
96+
* Sets the context for the current thread.
10497
*
105-
* @param context The classloader to set as the context classloader.
98+
* @param context For spark is TaskContext.
10699
*/
107-
public void setThreadContext(Object context) {
108-
Thread.currentThread().setContextClassLoader((ClassLoader) context);
109-
}
100+
public abstract void setThreadContext(Object context);
110101

111102
/**
112103
* Retrieves the on-heap spill manager implementation.

auron-core/src/main/java/org/apache/auron/jni/AuronCallNativeWrapper.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ public AuronCallNativeWrapper(
106106
* @throws RuntimeException If the native runtime encounters an error during batch processing.
107107
*/
108108
public boolean loadNextBatch(Consumer<VectorSchemaRoot> batchConsumer) {
109+
checkError();
109110
// load next batch
110111
try {
111112
this.batchConsumer = batchConsumer;
@@ -132,6 +133,10 @@ protected void importSchema(long ffiSchemaPtr) {
132133
}
133134
}
134135

136+
public Schema getArrowSchema() {
137+
return arrowSchema;
138+
}
139+
135140
protected void importBatch(long ffiArrayPtr) {
136141
if (nativeRuntimePtr == 0) {
137142
throw new RuntimeException("Native runtime is finalized");
@@ -172,15 +177,11 @@ protected byte[] getRawTaskDefinition() {
172177
return taskDefinition.toByteArray();
173178
}
174179

175-
private synchronized void close() {
180+
public synchronized void close() {
176181
if (nativeRuntimePtr != 0) {
177182
JniBridge.finalizeNative(nativeRuntimePtr);
178183
nativeRuntimePtr = 0;
179-
try {
180-
dictionaryProvider.close();
181-
} catch (Exception e) {
182-
LOG.error("Error closing dictionary provider", e);
183-
}
184+
dictionaryProvider.close();
184185
checkError();
185186
}
186187
}

auron-core/src/main/java/org/apache/auron/jni/JniBridge.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
*/
3636
@SuppressWarnings("unused")
3737
public class JniBridge {
38-
public static final ConcurrentHashMap<String, Object> resourcesMap = new ConcurrentHashMap<>();
38+
private static final ConcurrentHashMap<String, Object> resourcesMap = new ConcurrentHashMap<>();
3939

4040
private static final List<BufferPoolMXBean> directMXBeans =
4141
ManagementFactory.getPlatformMXBeans(BufferPoolMXBean.class);
@@ -60,6 +60,10 @@ public static Object getResource(String key) {
6060
return resourcesMap.remove(key);
6161
}
6262

63+
public static void putResource(String key, Object value) {
64+
resourcesMap.put(key, value);
65+
}
66+
6367
public static FSDataInputWrapper openFileAsDataInputWrapper(FileSystem fs, String path) throws Exception {
6468
// the path is a URI string, so we need to convert it to a URI object
6569
return FSDataInputWrapper.wrap(fs.open(new Path(new URI(path))));

auron-core/src/test/java/org/apache/auron/jni/MockAuronAdaptor.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.apache.auron.jni;
1818

19+
import java.io.File;
20+
import java.io.IOException;
1921
import java.nio.ByteBuffer;
2022
import org.apache.auron.configuration.AuronConfiguration;
2123
import org.apache.auron.configuration.MockAuronConfiguration;
@@ -31,6 +33,23 @@ public void loadAuronLib() {
3133
// Mock implementation, no need to load auron library
3234
}
3335

36+
@Override
37+
public String getDirectWriteSpillToDiskFile() throws IOException {
38+
File tempFile = File.createTempFile("auron-spill-", ".tmp");
39+
tempFile.deleteOnExit();
40+
return tempFile.getAbsolutePath();
41+
}
42+
43+
@Override
44+
public Object getThreadContext() {
45+
return null;
46+
}
47+
48+
@Override
49+
public void setThreadContext(Object context) {
50+
// Mock implementation, no need to set thread context
51+
}
52+
3453
@Override
3554
public AuronConfiguration getAuronConfiguration() {
3655
return new MockAuronConfiguration();

native-engine/auron-jni-bridge/src/jni_bridge.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -555,8 +555,10 @@ pub struct JniBridge<'a> {
555555
pub method_setContextClassLoader_ret: ReturnType,
556556
pub method_getResource: JStaticMethodID,
557557
pub method_getResource_ret: ReturnType,
558-
pub method_getTaskContext: JStaticMethodID,
559-
pub method_getTaskContext_ret: ReturnType,
558+
pub method_getThreadContext: JStaticMethodID,
559+
pub method_getThreadContext_ret: ReturnType,
560+
pub method_setThreadContext: JStaticMethodID,
561+
pub method_setThreadContext_ret: ReturnType,
560562
pub method_getTaskOnHeapSpillManager: JStaticMethodID,
561563
pub method_getTaskOnHeapSpillManager_ret: ReturnType,
562564
pub method_isTaskRunning: JStaticMethodID,
@@ -571,14 +573,12 @@ pub struct JniBridge<'a> {
571573
pub method_getTotalMemoryLimited_ret: ReturnType,
572574
pub method_getDirectWriteSpillToDiskFile: JStaticMethodID,
573575
pub method_getDirectWriteSpillToDiskFile_ret: ReturnType,
574-
pub method_initNativeThread: JStaticMethodID,
575-
pub method_initNativeThread_ret: ReturnType,
576576

577577
pub method_getAuronUDFWrapperContext: JStaticMethodID,
578578
pub method_getAuronUDFWrapperContext_ret: ReturnType,
579579
}
580580
impl<'a> JniBridge<'a> {
581-
pub const SIG_TYPE: &'static str = "org/apache/spark/sql/auron/JniBridge";
581+
pub const SIG_TYPE: &'static str = "org/apache/auron/jni/JniBridge";
582582

583583
pub fn new(env: &JNIEnv<'a>) -> JniResult<JniBridge<'a>> {
584584
let class = get_global_jclass(env, Self::SIG_TYPE)?;
@@ -602,12 +602,18 @@ impl<'a> JniBridge<'a> {
602602
"(Ljava/lang/String;)Ljava/lang/Object;",
603603
)?,
604604
method_getResource_ret: ReturnType::Object,
605-
method_getTaskContext: env.get_static_method_id(
605+
method_getThreadContext: env.get_static_method_id(
606606
class,
607-
"getTaskContext",
608-
"()Lorg/apache/spark/TaskContext;",
607+
"getThreadContext",
608+
"()Ljava/lang/Object;",
609609
)?,
610-
method_getTaskContext_ret: ReturnType::Object,
610+
method_getThreadContext_ret: ReturnType::Object,
611+
method_setThreadContext: env.get_static_method_id(
612+
class,
613+
"setThreadContext",
614+
"(Ljava/lang/Object;)V",
615+
)?,
616+
method_setThreadContext_ret: ReturnType::Primitive(Primitive::Void),
611617
method_getTaskOnHeapSpillManager: env.get_static_method_id(
612618
class,
613619
"getTaskOnHeapSpillManager",
@@ -646,12 +652,6 @@ impl<'a> JniBridge<'a> {
646652
"()Ljava/lang/String;",
647653
)?,
648654
method_getDirectWriteSpillToDiskFile_ret: ReturnType::Object,
649-
method_initNativeThread: env.get_static_method_id(
650-
class,
651-
"initNativeThread",
652-
"(Ljava/lang/ClassLoader;Lorg/apache/spark/TaskContext;)V",
653-
)?,
654-
method_initNativeThread_ret: ReturnType::Primitive(Primitive::Void),
655655

656656
method_getAuronUDFWrapperContext: env.get_static_method_id(
657657
class,
@@ -1408,7 +1408,7 @@ pub struct AuronCallNativeWrapper<'a> {
14081408
pub method_setError_ret: ReturnType,
14091409
}
14101410
impl<'a> AuronCallNativeWrapper<'a> {
1411-
pub const SIG_TYPE: &'static str = "org/apache/spark/sql/auron/AuronCallNativeWrapper";
1411+
pub const SIG_TYPE: &'static str = "org/apache/auron/jni/AuronCallNativeWrapper";
14121412

14131413
pub fn new(env: &JNIEnv<'a>) -> JniResult<AuronCallNativeWrapper<'a>> {
14141414
let class = get_global_jclass(env, Self::SIG_TYPE)?;

native-engine/auron/src/exec.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ use crate::{handle_unwinded_scope, logging::init_logging, rt::NativeExecutionRun
3939

4040
#[allow(non_snake_case)]
4141
#[unsafe(no_mangle)]
42-
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_callNative(
42+
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_callNative(
4343
env: JNIEnv,
4444
_: JClass,
4545
executor_memory_overhead: i64,
@@ -113,7 +113,7 @@ pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_callNative(
113113

114114
#[allow(non_snake_case)]
115115
#[unsafe(no_mangle)]
116-
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_nextBatch(
116+
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_nextBatch(
117117
_: JNIEnv,
118118
_: JClass,
119119
raw_ptr: i64,
@@ -124,7 +124,7 @@ pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_nextBatch(
124124

125125
#[allow(non_snake_case)]
126126
#[unsafe(no_mangle)]
127-
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_finalizeNative(
127+
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_finalizeNative(
128128
_: JNIEnv,
129129
_: JClass,
130130
raw_ptr: i64,
@@ -135,7 +135,7 @@ pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_finalizeNative(
135135

136136
#[allow(non_snake_case)]
137137
#[unsafe(no_mangle)]
138-
pub extern "system" fn Java_org_apache_spark_sql_auron_JniBridge_onExit(_: JNIEnv, _: JClass) {
138+
pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(_: JNIEnv, _: JClass) {
139139
log::info!("exiting native environment");
140140
if MemManager::initialized() {
141141
MemManager::get().dump_status();

native-engine/auron/src/rt.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@ use arrow::{
2626
};
2727
use auron_jni_bridge::{
2828
conf::{IntConf, SPARK_TASK_CPUS, TOKIO_WORKER_THREADS_PER_CPU},
29-
is_task_running,
30-
jni_bridge::JavaClasses,
31-
jni_call, jni_call_static, jni_convert_byte_array, jni_exception_check, jni_exception_occurred,
32-
jni_new_global_ref, jni_new_object, jni_new_string,
29+
is_task_running, jni_call, jni_call_static, jni_convert_byte_array, jni_exception_check,
30+
jni_exception_occurred, jni_new_global_ref, jni_new_object, jni_new_string,
3331
};
3432
use auron_serde::protobuf::TaskDefinition;
3533
use datafusion::{
@@ -105,17 +103,22 @@ impl NativeExecutionRuntime {
105103

106104
// create tokio runtime
107105
// propagate classloader and task context to spawned children threads
108-
let spark_task_context = jni_call_static!(JniBridge.getTaskContext() -> JObject)?;
109-
let spark_task_context_global = jni_new_global_ref!(spark_task_context.as_obj())?;
106+
let thread_context = jni_call_static!(JniBridge.getThreadContext() -> JObject)?;
107+
let thread_context_global = jni_new_global_ref!(thread_context.as_obj())?;
108+
// classloader
109+
let classloader = jni_call_static!(JniBridge.getContextClassLoader() -> JObject)?;
110+
let classloader_global = jni_new_global_ref!(classloader.as_obj())?;
110111
let mut tokio_runtime_builder = tokio::runtime::Builder::new_multi_thread();
111112
tokio_runtime_builder
112113
.thread_name(format!(
113114
"auron-native-stage-{stage_id}-part-{partition_id}-tid-{tid}"
114115
))
115116
.on_thread_start(move || {
116-
let classloader = JavaClasses::get().classloader;
117117
let _ = jni_call_static!(
118-
JniBridge.initNativeThread(classloader,spark_task_context_global.as_obj()) -> ()
118+
JniBridge.setContextClassLoader(classloader_global.as_obj()) -> ()
119+
);
120+
let _ = jni_call_static!(
121+
JniBridge.setThreadContext(thread_context_global.as_obj()) -> ()
119122
);
120123
THREAD_STAGE_ID.set(stage_id);
121124
THREAD_PARTITION_ID.set(partition_id);

spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ class ShimsImpl extends Shims with Logging {
685685

686686
// store fetch iterator in jni resource before native compute
687687
val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
688-
JniBridge.resourcesMap.put(
688+
org.apache.auron.jni.JniBridge.putResource(
689689
jniResourceId,
690690
() => {
691691
reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc()
@@ -778,7 +778,7 @@ class ShimsImpl extends Shims with Logging {
778778

779779
// store fetch iterator in jni resource before native compute
780780
val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
781-
JniBridge.resourcesMap.put(
781+
org.apache.auron.jni.JniBridge.putResource(
782782
jniResourceId,
783783
() => {
784784
reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc()
@@ -871,7 +871,7 @@ class ShimsImpl extends Shims with Logging {
871871

872872
// store fetch iterator in jni resource before native compute
873873
val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
874-
JniBridge.resourcesMap.put(
874+
org.apache.auron.jni.JniBridge.putResource(
875875
jniResourceId,
876876
() => {
877877
reader.asInstanceOf[AuronBlockStoreShuffleReaderBase[_, _]].readIpc()

0 commit comments

Comments
 (0)