Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.auron.functions;

/**
 * Wrapper context for user-defined functions (UDFs).
 * This class bridges different engines and native UDF implementations.
 * SQL engines such as Spark and Flink should provide their respective implementations based on this.
 */
public interface AuronUDFWrapperContext {

    /**
     * Opens the UDF context with the given resource ID.
     * The Flink engine requires the FunctionContext, which can be obtained via the ResourceID, to initialize the Flink ScalarFunction.
     * The default implementation is a no-op for engines that need no initialization.
     *
     * @param resourceId identifier used to look up engine-side resources
     *                   (e.g. the FunctionContext in Flink)
     */
    default void open(String resourceId) {}

    /**
     * Evaluates the UDF with the provided input and output pointers.
     * This method is called for each invocation of the UDF during query execution.
     * NOTE(review): the pointers appear to be Arrow FFI array addresses handed
     * across JNI by the native engine — confirm against the native caller.
     *
     * @param inputPtr Native pointer to the input data
     * @param outputPtr Native pointer to the output location where results should be written
     */
    void eval(long inputPtr, long outputPtr);

    /**
     * Closes the UDF context.
     * Some UDFs may need to perform resource cleanup operations.
     * The default implementation is a no-op.
     */
    default void close() {}
}
11 changes: 11 additions & 0 deletions auron-core/src/main/java/org/apache/auron/jni/AuronAdaptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.auron.configuration.AuronConfiguration;
import org.apache.auron.functions.AuronUDFWrapperContext;
import org.apache.auron.memory.OnHeapSpillManager;

/**
Expand Down Expand Up @@ -119,4 +121,13 @@ public OnHeapSpillManager getOnHeapSpillManager() {
* Retrieves the AuronConfiguration, It bundles the corresponding engine's Config.
*/
public abstract AuronConfiguration getAuronConfiguration();

/**
 * Retrieves the UDF wrapper context. Each engine requires its own implementation.
 *
 * @param udfSerialized The serialized UDF context, used to deserialize and
 *                      initialize the engine-specific UDF.
 * @return An instance of AuronUDFWrapperContext.
 * @throws UnsupportedOperationException If the engine does not provide a UDF
 *                                       wrapper implementation.
 */
public abstract AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer udfSerialized);
}
6 changes: 6 additions & 0 deletions auron-core/src/main/java/org/apache/auron/jni/JniBridge.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import java.lang.management.BufferPoolMXBean;
import java.lang.management.ManagementFactory;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.auron.functions.AuronUDFWrapperContext;
import org.apache.auron.hadoop.fs.FSDataInputWrapper;
import org.apache.auron.hadoop.fs.FSDataOutputWrapper;
import org.apache.auron.memory.OnHeapSpillManager;
Expand Down Expand Up @@ -96,4 +98,8 @@ public static void setThreadContext(Object tc) {
public static String getDirectWriteSpillToDiskFile() throws IOException {
return AuronAdaptor.getInstance().getDirectWriteSpillToDiskFile();
}

/**
 * Returns the engine-specific UDF wrapper context by delegating to the
 * registered AuronAdaptor. This static entry point is bound and invoked by
 * the native engine over JNI.
 *
 * @param udfSerialized the serialized UDF information
 * @return the engine's AuronUDFWrapperContext implementation
 */
public static AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer udfSerialized) {
return AuronAdaptor.getInstance().getAuronUDFWrapperContext(udfSerialized);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.auron.functions;

import java.nio.ByteBuffer;

/**
 * Test double for {@link AuronUDFWrapperContext}, used by the mock adaptor.
 * Inherits the no-op default {@code open}/{@code close} methods.
 */
public class MockAuronUDFWrapperContext implements AuronUDFWrapperContext {

    public MockAuronUDFWrapperContext(ByteBuffer udfSerialized) {
        // Drain the serialized payload, mirroring how a real engine would
        // deserialize the UDF information (class name, config, ...) and
        // initialize the UDF from it.
        final byte[] payload = new byte[udfSerialized.remaining()];
        udfSerialized.get(payload);
    }

    @Override
    public void eval(long inputPtr, long outputPtr) {
        // No-op: a real implementation would read input data from inputPtr
        // and write results to outputPtr.
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
*/
package org.apache.auron.jni;

import java.nio.ByteBuffer;
import org.apache.auron.configuration.AuronConfiguration;
import org.apache.auron.configuration.MockAuronConfiguration;
import org.apache.auron.functions.AuronUDFWrapperContext;
import org.apache.auron.functions.MockAuronUDFWrapperContext;

/**
* This is a mock class for testing the AuronAdaptor.
Expand All @@ -32,4 +35,9 @@ public void loadAuronLib() {
public AuronConfiguration getAuronConfiguration() {
return new MockAuronConfiguration();
}

/**
 * Supplies the mock UDF wrapper context for tests by forwarding the
 * serialized buffer to MockAuronUDFWrapperContext.
 */
@Override
public AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer udfSerialized) {
return new MockAuronUDFWrapperContext(udfSerialized);
}
}
47 changes: 40 additions & 7 deletions native-engine/auron-jni-bridge/src/jni_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ pub struct JavaClasses<'a> {
pub cSparkFileSegment: SparkFileSegment<'a>,
pub cSparkSQLMetric: SparkSQLMetric<'a>,
pub cSparkMetricNode: SparkMetricNode<'a>,
pub cSparkUDFWrapperContext: SparkUDFWrapperContext<'a>,
pub cSparkAuronUDFWrapperContext: SparkAuronUDFWrapperContext<'a>,
pub cSparkUDAFWrapperContext: SparkUDAFWrapperContext<'a>,
pub cSparkUDTFWrapperContext: SparkUDTFWrapperContext<'a>,
pub cSparkUDAFMemTracker: SparkUDAFMemTracker<'a>,
Expand All @@ -453,6 +453,8 @@ pub struct JavaClasses<'a> {
pub cAuronFSDataInputWrapper: AuronFSDataInputWrapper<'a>,
pub cAuronFSDataOutputWrapper: AuronFSDataOutputWrapper<'a>,
pub cAuronJsonFallbackWrapper: AuronJsonFallbackWrapper<'a>,

pub cAuronUDFWrapperContext: AuronUDFWrapperContext<'a>,
}

#[allow(clippy::non_send_fields_in_send_ty)]
Expand Down Expand Up @@ -503,7 +505,7 @@ impl JavaClasses<'static> {
cSparkFileSegment: SparkFileSegment::new(env)?,
cSparkSQLMetric: SparkSQLMetric::new(env)?,
cSparkMetricNode: SparkMetricNode::new(env)?,
cSparkUDFWrapperContext: SparkUDFWrapperContext::new(env)?,
cSparkAuronUDFWrapperContext: SparkAuronUDFWrapperContext::new(env)?,
cSparkUDAFWrapperContext: SparkUDAFWrapperContext::new(env)?,
cSparkUDTFWrapperContext: SparkUDTFWrapperContext::new(env)?,
cSparkUDAFMemTracker: SparkUDAFMemTracker::new(env)?,
Expand All @@ -517,6 +519,8 @@ impl JavaClasses<'static> {
cAuronFSDataInputWrapper: AuronFSDataInputWrapper::new(env)?,
cAuronFSDataOutputWrapper: AuronFSDataOutputWrapper::new(env)?,
cAuronJsonFallbackWrapper: AuronJsonFallbackWrapper::new(env)?,

cAuronUDFWrapperContext: AuronUDFWrapperContext::new(env)?,
};
log::info!("Initializing JavaClasses finished");
Ok(java_classes)
Expand Down Expand Up @@ -573,6 +577,9 @@ pub struct JniBridge<'a> {
pub method_getDirectWriteSpillToDiskFile_ret: ReturnType,
pub method_initNativeThread: JStaticMethodID,
pub method_initNativeThread_ret: ReturnType,

pub method_getAuronUDFWrapperContext: JStaticMethodID,
pub method_getAuronUDFWrapperContext_ret: ReturnType,
}
impl<'a> JniBridge<'a> {
pub const SIG_TYPE: &'static str = "org/apache/spark/sql/auron/JniBridge";
Expand Down Expand Up @@ -657,6 +664,13 @@ impl<'a> JniBridge<'a> {
"(Ljava/lang/ClassLoader;Lorg/apache/spark/TaskContext;)V",
)?,
method_initNativeThread_ret: ReturnType::Primitive(Primitive::Void),

method_getAuronUDFWrapperContext: env.get_static_method_id(
class,
"getAuronUDFWrapperContext",
"(Ljava/nio/ByteBuffer;)Lorg/apache/auron/functions/AuronUDFWrapperContext;",
)?,
method_getAuronUDFWrapperContext_ret: ReturnType::Object,
})
}
}
Expand Down Expand Up @@ -1193,18 +1207,18 @@ impl<'a> AuronRssPartitionWriterBase<'_> {
}

#[allow(non_snake_case)]
pub struct SparkUDFWrapperContext<'a> {
pub struct SparkAuronUDFWrapperContext<'a> {
pub class: JClass<'a>,
pub ctor: JMethodID,
pub method_eval: JMethodID,
pub method_eval_ret: ReturnType,
}
impl<'a> SparkUDFWrapperContext<'a> {
pub const SIG_TYPE: &'static str = "org/apache/spark/sql/auron/SparkUDFWrapperContext";
impl<'a> SparkAuronUDFWrapperContext<'a> {
pub const SIG_TYPE: &'static str = "org/apache/auron/spark/sql/SparkAuronUDFWrapperContext";

pub fn new(env: &JNIEnv<'a>) -> JniResult<SparkUDFWrapperContext<'a>> {
pub fn new(env: &JNIEnv<'a>) -> JniResult<SparkAuronUDFWrapperContext<'a>> {
let class = get_global_jclass(env, Self::SIG_TYPE)?;
Ok(SparkUDFWrapperContext {
Ok(SparkAuronUDFWrapperContext {
class,
ctor: env.get_method_id(class, "<init>", "(Ljava/nio/ByteBuffer;)V")?,
method_eval: env.get_method_id(class, "eval", "(JJ)V")?,
Expand All @@ -1213,6 +1227,25 @@ impl<'a> SparkUDFWrapperContext<'a> {
}
}

#[allow(non_snake_case)]
/// Cached JNI handles for the engine-agnostic
/// `org.apache.auron.functions.AuronUDFWrapperContext` Java interface.
pub struct AuronUDFWrapperContext<'a> {
    pub class: JClass<'a>,
    pub method_eval: JMethodID,
    pub method_eval_ret: ReturnType,
}
impl<'a> AuronUDFWrapperContext<'a> {
    /// JVM internal name of the Java interface this struct binds to.
    pub const SIG_TYPE: &'static str = "org/apache/auron/functions/AuronUDFWrapperContext";

    /// Resolves a global class reference and the `eval(JJ)V` method id.
    /// NOTE(review): only `eval` is bound here, although the Java interface
    /// also declares default `open(String)`/`close()` methods — confirm the
    /// native side never needs to invoke them.
    pub fn new(env: &JNIEnv<'a>) -> JniResult<AuronUDFWrapperContext<'a>> {
        let class = get_global_jclass(env, Self::SIG_TYPE)?;
        Ok(AuronUDFWrapperContext {
            class,
            method_eval: env.get_method_id(class, "eval", "(JJ)V")?,
            method_eval_ret: ReturnType::Primitive(Primitive::Void),
        })
    }
}

#[allow(non_snake_case)]
pub struct SparkUDAFWrapperContext<'a> {
pub class: JClass<'a>,
Expand Down
4 changes: 2 additions & 2 deletions native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl SparkUDFWrapperExpr {
.get_or_try_init(|| {
let serialized_buf = jni_new_direct_byte_buffer!(&self.serialized)?;
let jcontext_local =
jni_new_object!(SparkUDFWrapperContext(serialized_buf.as_obj()))?;
jni_new_object!(SparkAuronUDFWrapperContext(serialized_buf.as_obj()))?;
jni_new_global_ref!(jcontext_local.as_obj())
})
.cloned()
Expand Down Expand Up @@ -213,7 +213,7 @@ fn invoke_udf(
let struct_array = StructArray::from(params_batch);
let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data());
let mut import_ffi_array = FFI_ArrowArray::empty();
jni_call!(SparkUDFWrapperContext(jcontext.as_obj()).eval(
jni_call!(SparkAuronUDFWrapperContext(jcontext.as_obj()).eval(
&mut export_ffi_array as *mut FFI_ArrowArray as i64,
&mut import_ffi_array as *mut FFI_ArrowArray as i64,
) -> ())?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,13 @@ class AuronFunctionSuite
}
}
}

// Fix: test name said "failback" — the project's term is "fallback"
// (cf. AuronJsonFallbackWrapper); exercises the JVM UDF fallback path.
test("regexp_extract function with UDF fallback") {
  withTable("t1") {
    sql("create table t1(c1 string) using parquet")
    sql("insert into t1 values('Auron Spark SQL')")
    // '^A(.*)L$' on "Auron Spark SQL": group 1 is everything between
    // the leading 'A' and the trailing 'L'.
    val df = sql("select regexp_extract(c1, '^A(.*)L$', 1) from t1")
    checkAnswer(df, Seq(Row("uron Spark SQ")))
  }
}
}
5 changes: 5 additions & 0 deletions spark-extension/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
<packaging>jar</packaging>

<dependencies>
<dependency>
<groupId>org.apache.auron</groupId>
<artifactId>auron-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.auron</groupId>
<artifactId>hadoop-shim_${scalaVersion}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import java.lang.management.BufferPoolMXBean;
import java.lang.management.ManagementFactory;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.auron.functions.AuronUDFWrapperContext;
import org.apache.auron.hadoop.fs.FSDataInputWrapper;
import org.apache.auron.hadoop.fs.FSDataInputWrapper$;
import org.apache.auron.hadoop.fs.FSDataOutputWrapper;
Expand Down Expand Up @@ -121,4 +123,8 @@ public static void initNativeThread(ClassLoader cl, TaskContext tc) {
TaskContextHelper$.MODULE$.setNativeThreadName();
TaskContextHelper$.MODULE$.setHDFSCallerContext();
}

/**
 * Retrieves the UDF wrapper context. Not supported by the Spark engine's
 * JniBridge yet — this entry point is reserved for the next-generation
 * multi-engine API.
 *
 * @param udfSerialized The serialized UDF context (unused).
 * @throws UnsupportedOperationException always; the message names the
 *         operation so callers can diagnose the failure.
 */
public static AuronUDFWrapperContext getAuronUDFWrapperContext(ByteBuffer udfSerialized) {
    // Fix: original message did not say which operation is unsupported,
    // making the resulting stack trace hard to diagnose.
    throw new UnsupportedOperationException(
            "getAuronUDFWrapperContext is not supported by the Spark engine yet; "
                    + "this API is reserved for next-generation multi-engine support.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.auron
package org.apache.auron.spark.sql

import java.nio.ByteBuffer

import org.apache.arrow.c.ArrowArray
import org.apache.arrow.c.Data
import org.apache.arrow.c.{ArrowArray, Data}
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.auron.NativeConverters
import org.apache.spark.sql.auron.util.Using
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.expressions.Nondeterministic
import org.apache.spark.sql.execution.auron.arrowio.util.ArrowUtils
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, Nondeterministic}
import org.apache.spark.sql.execution.auron.arrowio.util.{ArrowUtils, ArrowWriter}
import org.apache.spark.sql.execution.auron.arrowio.util.ArrowUtils.ROOT_ALLOCATOR
import org.apache.spark.sql.execution.auron.arrowio.util.ArrowWriter
import org.apache.spark.sql.execution.auron.columnar.ColumnarHelper
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StructField, StructType}

case class SparkUDFWrapperContext(serialized: ByteBuffer) extends Logging {
import org.apache.auron.functions.AuronUDFWrapperContext

case class SparkAuronUDFWrapperContext(serialized: ByteBuffer)
extends AuronUDFWrapperContext
with Logging {
private val (expr, javaParamsSchema) =
NativeConverters.deserializeExpression[Expression, StructType]({
val bytes = new Array[Byte](serialized.remaining())
Expand Down