Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,9 @@ public OnHeapSpillManager getOnHeapSpillManager() {
* @throws UnsupportedOperationException If the method is not implemented.
*/
public abstract AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer udfSerialized);

/**
* Returns the name of the current engine, such as Spark or Flink.
*/
public abstract String getEngineName();
}
4 changes: 4 additions & 0 deletions auron-core/src/main/java/org/apache/auron/jni/JniBridge.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ public static String stringConf(String confKey) {
return getConfValue(confKey);
}

public static String getEngineName() {
return AuronAdaptor.getInstance().getEngineName();
}

static <T> T getConfValue(String confKey) {
Class<? extends AuronConfiguration> confClass =
AuronAdaptor.getInstance().getAuronConfiguration().getClass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,9 @@ public AuronConfiguration getAuronConfiguration() {
public AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer udfSerialized) {
return new MockAuronUDFWrapperContext(udfSerialized);
}

@Override
public String getEngineName() {
return "Test";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,9 @@ public AuronConfiguration getAuronConfiguration() {
public AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer byteBuffer) {
throw new UnsupportedOperationException();
}

@Override
public String getEngineName() {
return "Flink";
}
}
153 changes: 130 additions & 23 deletions native-engine/auron-jni-bridge/src/jni_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,21 +442,21 @@ pub struct JavaClasses<'a> {

pub cSparkFileSegment: SparkFileSegment<'a>,
pub cSparkSQLMetric: SparkSQLMetric<'a>,
pub cSparkMetricNode: SparkMetricNode<'a>,
pub cSparkAuronUDFWrapperContext: SparkAuronUDFWrapperContext<'a>,
pub cSparkUDAFWrapperContext: SparkUDAFWrapperContext<'a>,
pub cSparkUDTFWrapperContext: SparkUDTFWrapperContext<'a>,
pub cSparkUDAFMemTracker: SparkUDAFMemTracker<'a>,
pub cAuronRssPartitionWriterBase: AuronRssPartitionWriterBase<'a>,
pub cAuronCallNativeWrapper: AuronCallNativeWrapper<'a>,
pub cAuronOnHeapSpillManager: AuronOnHeapSpillManager<'a>,
pub cAuronNativeParquetSinkUtils: AuronNativeParquetSinkUtils<'a>,
pub cAuronBlockObject: AuronBlockObject<'a>,
pub cAuronJsonFallbackWrapper: AuronJsonFallbackWrapper<'a>,

pub cAuronArrowFFIExporter: AuronArrowFFIExporter<'a>,
pub cAuronCallNativeWrapper: AuronCallNativeWrapper<'a>,
pub cAuronFSDataInputWrapper: AuronFSDataInputWrapper<'a>,
pub cAuronFSDataOutputWrapper: AuronFSDataOutputWrapper<'a>,
pub cAuronJsonFallbackWrapper: AuronJsonFallbackWrapper<'a>,

pub cMetricNode: MetricNode<'a>,
pub cAuronUDFWrapperContext: AuronUDFWrapperContext<'a>,
}

Expand All @@ -481,6 +481,61 @@ impl JavaClasses<'static> {
)?
.l()?;

let engine_name_java = env
.call_static_method_unchecked(
jni_bridge.class,
jni_bridge.method_getEngineName,
jni_bridge.method_getEngineName_ret.clone(),
&[],
)?
.l()?;
let engine_name = env
.get_string(engine_name_java.into())
.map(|s| String::from(s))
.expect("engine_name is not valid");
log::info!("Runtime engine is {engine_name}");

let (
c_spark_file_segment,
c_spark_sql_metric,
c_spark_auron_udf_wrapper_context,
c_spark_udaf_wrapper_context,
c_spark_udtf_wrapper_context,
c_spark_udaf_mem_tracker,
c_auron_rss_partition_writer_base,
c_auron_on_heap_spill_manager,
c_auron_native_parquet_sink_utils,
c_auron_block_object,
c_auron_json_fallback_wrapper,
) = match engine_name.as_str() {
"Spark" => (
SparkFileSegment::new(env)?,
SparkSQLMetric::new(env)?,
SparkAuronUDFWrapperContext::new(env)?,
SparkUDAFWrapperContext::new(env)?,
SparkUDTFWrapperContext::new(env)?,
SparkUDAFMemTracker::new(env)?,
AuronRssPartitionWriterBase::new(env)?,
AuronOnHeapSpillManager::new(env)?,
AuronNativeParquetSinkUtils::new(env)?,
AuronBlockObject::new(env)?,
AuronJsonFallbackWrapper::new(env)?,
),
_ => (
SparkFileSegment::default(),
SparkSQLMetric::default(),
SparkAuronUDFWrapperContext::default(),
SparkUDAFWrapperContext::default(),
SparkUDTFWrapperContext::default(),
SparkUDAFMemTracker::default(),
AuronRssPartitionWriterBase::default(),
AuronOnHeapSpillManager::default(),
AuronNativeParquetSinkUtils::default(),
AuronBlockObject::default(),
AuronJsonFallbackWrapper::default(),
),
};

let java_classes = JavaClasses {
jvm: env.get_java_vm()?,
classloader: get_global_ref_jobject(env, classloader)?,
Expand All @@ -505,23 +560,23 @@ impl JavaClasses<'static> {
cHadoopFileSystem: HadoopFileSystem::new(env)?,
cHadoopPath: HadoopPath::new(env)?,

cSparkFileSegment: SparkFileSegment::new(env)?,
cSparkSQLMetric: SparkSQLMetric::new(env)?,
cSparkMetricNode: SparkMetricNode::new(env)?,
cSparkAuronUDFWrapperContext: SparkAuronUDFWrapperContext::new(env)?,
cSparkUDAFWrapperContext: SparkUDAFWrapperContext::new(env)?,
cSparkUDTFWrapperContext: SparkUDTFWrapperContext::new(env)?,
cSparkUDAFMemTracker: SparkUDAFMemTracker::new(env)?,
cAuronRssPartitionWriterBase: AuronRssPartitionWriterBase::new(env)?,
cAuronCallNativeWrapper: AuronCallNativeWrapper::new(env)?,
cAuronOnHeapSpillManager: AuronOnHeapSpillManager::new(env)?,
cAuronNativeParquetSinkUtils: AuronNativeParquetSinkUtils::new(env)?,
cAuronBlockObject: AuronBlockObject::new(env)?,
cSparkFileSegment: c_spark_file_segment,
cSparkSQLMetric: c_spark_sql_metric,
cSparkAuronUDFWrapperContext: c_spark_auron_udf_wrapper_context,
cSparkUDAFWrapperContext: c_spark_udaf_wrapper_context,
cSparkUDTFWrapperContext: c_spark_udtf_wrapper_context,
cSparkUDAFMemTracker: c_spark_udaf_mem_tracker,
cAuronRssPartitionWriterBase: c_auron_rss_partition_writer_base,
cAuronOnHeapSpillManager: c_auron_on_heap_spill_manager,
cAuronNativeParquetSinkUtils: c_auron_native_parquet_sink_utils,
cAuronBlockObject: c_auron_block_object,
cAuronJsonFallbackWrapper: c_auron_json_fallback_wrapper,

cAuronArrowFFIExporter: AuronArrowFFIExporter::new(env)?,
cAuronCallNativeWrapper: AuronCallNativeWrapper::new(env)?,
cAuronFSDataInputWrapper: AuronFSDataInputWrapper::new(env)?,
cAuronFSDataOutputWrapper: AuronFSDataOutputWrapper::new(env)?,
cAuronJsonFallbackWrapper: AuronJsonFallbackWrapper::new(env)?,

cMetricNode: MetricNode::new(env)?,
cAuronUDFWrapperContext: AuronUDFWrapperContext::new(env)?,
};
log::info!("Initializing JavaClasses finished");
Expand Down Expand Up @@ -587,6 +642,8 @@ pub struct JniBridge<'a> {
pub method_booleanConf_ret: ReturnType,
pub method_stringConf: JStaticMethodID,
pub method_stringConf_ret: ReturnType,
pub method_getEngineName: JStaticMethodID,
pub method_getEngineName_ret: ReturnType,
}
impl<'a> JniBridge<'a> {
pub const SIG_TYPE: &'static str = "org/apache/auron/jni/JniBridge";
Expand Down Expand Up @@ -700,6 +757,12 @@ impl<'a> JniBridge<'a> {
"(Ljava/lang/String;)Ljava/lang/String;",
)?,
method_stringConf_ret: ReturnType::Object,
method_getEngineName: env.get_static_method_id(
class,
"getEngineName",
"()Ljava/lang/String;",
)?,
method_getEngineName_ret: ReturnType::Object,
})
}
}
Expand Down Expand Up @@ -1110,6 +1173,10 @@ impl<'a> SparkFileSegment<'a> {
method_length_ret: ReturnType::Primitive(Primitive::Long),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand All @@ -1129,22 +1196,26 @@ impl<'a> SparkSQLMetric<'a> {
method_add_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
pub struct SparkMetricNode<'a> {
pub struct MetricNode<'a> {
pub class: JClass<'a>,
pub method_getChild: JMethodID,
pub method_getChild_ret: ReturnType,
pub method_add: JMethodID,
pub method_add_ret: ReturnType,
}
impl<'a> SparkMetricNode<'a> {
pub const SIG_TYPE: &'static str = "org/apache/auron/metric/SparkMetricNode";
impl<'a> MetricNode<'a> {
pub const SIG_TYPE: &'static str = "org/apache/auron/metric/MetricNode";

pub fn new(env: &JNIEnv<'a>) -> JniResult<SparkMetricNode<'a>> {
pub fn new(env: &JNIEnv<'a>) -> JniResult<MetricNode<'a>> {
let class = get_global_jclass(env, Self::SIG_TYPE)?;
Ok(SparkMetricNode {
Ok(MetricNode {
class,
method_getChild: env.get_method_id(
class,
Expand Down Expand Up @@ -1181,6 +1252,10 @@ impl<'a> AuronRssPartitionWriterBase<'_> {
method_flush_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand All @@ -1202,6 +1277,10 @@ impl<'a> SparkAuronUDFWrapperContext<'a> {
method_eval_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -1318,6 +1397,10 @@ impl<'a> SparkUDAFWrapperContext<'a> {
method_unspill_ret: ReturnType::Object,
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -1347,6 +1430,10 @@ impl<'a> SparkUDTFWrapperContext<'a> {
method_terminateLoop_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -1380,6 +1467,10 @@ impl<'a> SparkUDAFMemTracker<'a> {
method_updateUsed_ret: ReturnType::Primitive(Primitive::Boolean),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -1470,6 +1561,10 @@ impl<'a> AuronOnHeapSpillManager<'a> {
method_releaseSpill_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -1502,6 +1597,10 @@ impl<'a> AuronNativeParquetSinkUtils<'a> {
method_completeOutput_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -1562,6 +1661,10 @@ impl<'a> AuronBlockObject<'a> {
method_throwFetchFailed_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -1648,6 +1751,10 @@ impl<'a> AuronJsonFallbackWrapper<'a> {
method_parseJsons_ret: ReturnType::Primitive(Primitive::Void),
})
}

fn default() -> Self {
unsafe { std::mem::zeroed() }
}
}

fn get_global_jclass(env: &JNIEnv<'_>, cls: &str) -> JniResult<JClass<'static>> {
Expand Down
8 changes: 4 additions & 4 deletions native-engine/auron/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use auron_jni_bridge::{jni_call, jni_new_string};
use datafusion::{common::Result, physical_plan::ExecutionPlan};
use jni::objects::JObject;

pub fn update_spark_metric_node(
pub fn update_metric_node(
metric_node: JObject,
execution_plan: Arc<dyn ExecutionPlan>,
) -> Result<()> {
Expand All @@ -42,17 +42,17 @@ pub fn update_spark_metric_node(
// update children nodes
for (i, &child_plan) in execution_plan.children().iter().enumerate() {
let child_metric_node = jni_call!(
SparkMetricNode(metric_node).getChild(i as i32) -> JObject
MetricNode(metric_node).getChild(i as i32) -> JObject
)?;
update_spark_metric_node(child_metric_node.as_obj(), child_plan.clone())?;
update_metric_node(child_metric_node.as_obj(), child_plan.clone())?;
}
Ok(())
}

fn update_metrics(metric_node: JObject, metric_values: &[(&str, i64)]) -> Result<()> {
for &(name, value) in metric_values {
let jname = jni_new_string!(&name)?;
jni_call!(SparkMetricNode(metric_node).add(jname.as_obj(), value) -> ())?;
jni_call!(MetricNode(metric_node).add(jname.as_obj(), value) -> ())?;
}
Ok(())
}
4 changes: 2 additions & 2 deletions native-engine/auron/src/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ use tokio::{runtime::Runtime, task::JoinHandle};
use crate::{
handle_unwinded_scope,
logging::{THREAD_PARTITION_ID, THREAD_STAGE_ID, THREAD_TID},
metrics::update_spark_metric_node,
metrics::update_metric_node,
};

pub struct NativeExecutionRuntime {
Expand Down Expand Up @@ -301,7 +301,7 @@ impl NativeExecutionRuntime {
let metrics = jni_call!(
AuronCallNativeWrapper(self.native_wrapper.as_obj()).getMetrics() -> JObject
)?;
update_spark_metric_node(metrics.as_obj(), self.plan.clone())?;
update_metric_node(metrics.as_obj(), self.plan.clone())?;
Ok(())
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,9 @@ public AuronConfiguration getAuronConfiguration() {
public AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer udfSerialized) {
return new SparkAuronUDFWrapperContext(udfSerialized);
}

@Override
public String getEngineName() {
return "Spark";
}
}
Loading