Skip to content
Merged
115 changes: 114 additions & 1 deletion native-engine/blaze-jni-bridge/src/jni_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,44 @@ macro_rules! jni_get_byte_array_region {
$crate::jni_bridge::THREAD_JNIENV.with(|env| {
$crate::jni_map_error_with_env!(
env,
env.get_byte_array_region($value, $start as i32, unsafe {
env.get_byte_array_region($value.cast(), $start as i32, unsafe {
std::mem::transmute::<_, &mut [i8]>($buf)
})
)
})
}};
}

/// Returns the length (in elements) of a Java primitive array as `usize`.
///
/// `$value` must be a JNI array reference; it is `.cast()` to the array type
/// expected by `get_array_length`. Runs against the thread-local `JNIEnv` and
/// converts JNI failures through `jni_map_error_with_env!`.
#[macro_export]
macro_rules! jni_get_byte_array_len {
    ($value:expr) => {{
        $crate::jni_bridge::THREAD_JNIENV.with(|env| {
            $crate::jni_map_error_with_env!(
                env,
                // GetArrayLength yields a jsize (i32); widen to usize for Rust callers.
                env.get_array_length($value.cast()).map(|s| s as usize)
            )
        })
    }};
}

/// Allocates a new Java primitive array of JNI type `$ty` (e.g. `byte`, `int`,
/// `long`) and copies the Rust slice `$value` into it.
///
/// Expands via `paste!` to `env.new_<ty>_array(len)` followed by
/// `env.set_<ty>_array_region(array, 0, value)`. The slice is reinterpreted
/// with an untyped `transmute` to the matching JNI element slice (e.g.
/// `&[u8]` -> `&[i8]` for `byte`) — assumes the caller passes a slice whose
/// element type is layout-compatible with the JNI primitive; TODO confirm at
/// each call site. On success the array is returned wrapped in a `LocalRef`
/// so the local reference is released when dropped.
#[macro_export]
macro_rules! jni_new_prim_array {
    ($ty:ident, $value:expr) => {{
        $crate::jni_bridge::THREAD_JNIENV.with(|env| {
            $crate::jni_map_error_with_env!(
                env,
                // New<Ty>Array takes a jsize (i32) element count.
                paste::paste! {env.[<new_ $ty _array>]($value.len() as i32)}
                    .and_then(|array| {
                        // Reinterpret the Rust slice as the JNI element slice;
                        // caller must guarantee identical element layout.
                        let value = unsafe { std::mem::transmute($value) };
                        paste::paste! {env.[<set_ $ty _array_region>](array, 0, value)}
                            .map(|_| array)
                    })
                    // Wrap the raw array handle as an owned local reference.
                    .map(|s| $crate::jni_bridge::LocalRef(unsafe { JObject::from_raw(s.into()) }))
            )
        })
    }};
}

#[macro_export]
macro_rules! jni_call {
($clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> JObject) => {{
Expand Down Expand Up @@ -410,6 +440,7 @@ pub struct JavaClasses<'a> {
pub cSparkSQLMetric: SparkSQLMetric<'a>,
pub cSparkMetricNode: SparkMetricNode<'a>,
pub cSparkUDFWrapperContext: SparkUDFWrapperContext<'a>,
pub cSparkUDAFWrapperContext: SparkUDAFWrapperContext<'a>,
pub cSparkUDTFWrapperContext: SparkUDTFWrapperContext<'a>,
pub cBlazeConf: BlazeConf<'a>,
pub cBlazeRssPartitionWriterBase: BlazeRssPartitionWriterBase<'a>,
Expand Down Expand Up @@ -471,6 +502,7 @@ impl JavaClasses<'static> {
cSparkSQLMetric: SparkSQLMetric::new(env)?,
cSparkMetricNode: SparkMetricNode::new(env)?,
cSparkUDFWrapperContext: SparkUDFWrapperContext::new(env)?,
cSparkUDAFWrapperContext: SparkUDAFWrapperContext::new(env)?,
cSparkUDTFWrapperContext: SparkUDTFWrapperContext::new(env)?,
cBlazeConf: BlazeConf::new(env)?,
cBlazeRssPartitionWriterBase: BlazeRssPartitionWriterBase::new(env)?,
Expand Down Expand Up @@ -1169,6 +1201,87 @@ impl<'a> SparkUDFWrapperContext<'a> {
}
}

/// Cached JNI handles for `org.apache.spark.sql.blaze.SparkUDAFWrapperContext`:
/// a global class reference plus the method id and return type of every method
/// the native engine invokes to drive UDAF evaluation on the JVM side.
#[allow(non_snake_case)]
pub struct SparkUDAFWrapperContext<'a> {
    pub class: JClass<'a>,
    // Constructor: takes the serialized UDAF expression as a ByteBuffer.
    pub ctor: JMethodID,
    // initialize(int) -> BufferRowsColumn: creates the accumulator column.
    pub method_initialize: JMethodID,
    pub method_initialize_ret: ReturnType,
    // resize(BufferRowsColumn, int): grows/shrinks the accumulator column.
    pub method_resize: JMethodID,
    pub method_resize_ret: ReturnType,
    // update(BufferRowsColumn, long, long[]): accumulates input rows.
    pub method_update: JMethodID,
    pub method_update_ret: ReturnType,
    // merge(BufferRowsColumn, BufferRowsColumn, long[]): merges partial states.
    pub method_merge: JMethodID,
    pub method_merge_ret: ReturnType,
    // eval(BufferRowsColumn, int[], long): produces final aggregation results.
    pub method_eval: JMethodID,
    pub method_eval_ret: ReturnType,
    // serializeRows(BufferRowsColumn, int[]) -> byte[]: spill/shuffle support.
    pub method_serializeRows: JMethodID,
    pub method_serializeRows_ret: ReturnType,
    // deserializeRows(ByteBuffer) -> BufferRowsColumn: inverse of serializeRows.
    pub method_deserializeRows: JMethodID,
    pub method_deserializeRows_ret: ReturnType,
    // memUsed(BufferRowsColumn) -> int: reports JVM-side memory consumption.
    pub method_memUsed: JMethodID,
    pub method_memUsed_ret: ReturnType,
}
impl<'a> SparkUDAFWrapperContext<'a> {
    /// JVM class that hosts UDAF evaluation on the Spark side.
    pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/SparkUDAFWrapperContext";

    /// Resolves the class as a global reference and caches the method id (and
    /// JNI return type) of every method called from native code.
    ///
    /// Fails if the class cannot be found or any method/signature lookup
    /// fails — i.e. when the Rust side and the Scala class drift apart.
    pub fn new(env: &JNIEnv<'a>) -> JniResult<SparkUDAFWrapperContext<'a>> {
        let class = get_global_jclass(env, Self::SIG_TYPE)?;
        Ok(SparkUDAFWrapperContext {
            class,
            // SparkUDAFWrapperContext(ByteBuffer serializedUdaf)
            ctor: env.get_method_id(class, "<init>", "(Ljava/nio/ByteBuffer;)V")?,
            // BufferRowsColumn initialize(int numRows)
            method_initialize: env.get_method_id(
                class,
                "initialize",
                "(I)Lorg/apache/spark/sql/blaze/BufferRowsColumn;",
            )?,
            method_initialize_ret: ReturnType::Object,
            // void resize(BufferRowsColumn rows, int len)
            method_resize: env.get_method_id(
                class,
                "resize",
                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;I)V",
            )?,
            method_resize_ret: ReturnType::Primitive(Primitive::Void),
            // void update(BufferRowsColumn rows, long inputPtr, long[] idx)
            method_update: env.get_method_id(
                class,
                "update",
                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;J[J)V",
            )?,
            method_update_ret: ReturnType::Primitive(Primitive::Void),
            // void merge(BufferRowsColumn rows, BufferRowsColumn other, long[] idx)
            method_merge: env.get_method_id(
                class,
                "merge",
                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;Lorg/apache/spark/sql/blaze/BufferRowsColumn;[J)V",
            )?,
            method_merge_ret: ReturnType::Primitive(Primitive::Void),
            // void eval(BufferRowsColumn rows, int[] idx, long outputPtr)
            method_eval: env.get_method_id(
                class,
                "eval",
                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[IJ)V",
            )?,
            method_eval_ret: ReturnType::Primitive(Primitive::Void),
            // byte[] serializeRows(BufferRowsColumn rows, int[] idx)
            method_serializeRows: env.get_method_id(
                class,
                "serializeRows",
                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[I)[B",
            )?,
            method_serializeRows_ret: ReturnType::Array,
            // BufferRowsColumn deserializeRows(ByteBuffer data)
            method_deserializeRows: env.get_method_id(
                class,
                "deserializeRows",
                "(Ljava/nio/ByteBuffer;)Lorg/apache/spark/sql/blaze/BufferRowsColumn;",
            )?,
            method_deserializeRows_ret: ReturnType::Object,
            // int memUsed(BufferRowsColumn rows)
            method_memUsed: env.get_method_id(
                class,
                "memUsed",
                "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;)I",
            )?,
            method_memUsed_ret: ReturnType::Primitive(Primitive::Int),
        })
    }
}

#[allow(non_snake_case)]
pub struct SparkUDTFWrapperContext<'a> {
pub class: JClass<'a>,
Expand Down
19 changes: 14 additions & 5 deletions native-engine/blaze-serde/proto/blaze.proto
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,20 @@ enum AggFunction {
BLOOM_FILTER = 9;
BRICKHOUSE_COLLECT = 1000;
BRICKHOUSE_COMBINE_UNIQUE = 1001;
UDAF = 1002;
}

// A single aggregate expression inside a physical aggregation plan node.
message PhysicalAggExprNode {
  AggFunction agg_function = 1;
  // Only meaningful when agg_function == UDAF; carries the JVM-serialized UDAF.
  // NOTE(review): field number 2 previously belonged to `children` (now 3) —
  // this breaks wire compatibility with plans encoded by older versions.
  // Assumed safe only if serialized plans never cross versions; confirm.
  AggUdaf udaf = 2;
  // Input argument expressions of the aggregate.
  repeated PhysicalExprNode children = 3;
}

// Describes a Spark user-defined aggregate function (UDAF) whose evaluation
// is delegated back to the JVM (see SparkUDAFWrapperContext on the native side).
message AggUdaf {
  // Opaque JVM-serialized representation of the UDAF expression.
  bytes serialized = 1;
  // Schema of the input rows fed to the UDAF.
  Schema input_schema = 2;
  // Arrow data type of the aggregation result.
  ArrowType return_type = 3;
  // Whether the aggregation result may be null.
  bool return_nullable = 4;
}

message PhysicalIsNull {
Expand Down Expand Up @@ -598,10 +607,10 @@ message FetchLimit {

// Describes how an exchange repartitions its input; exactly one variant is set.
// (The scraped diff duplicated the four oneof fields — the deleted and added
// sides of a re-indentation hunk; duplicate field numbers are invalid proto,
// so the block is restored to its single canonical form.)
message PhysicalRepartition {
  oneof RepartitionType {
    PhysicalSingleRepartition single_repartition = 1;
    PhysicalHashRepartition hash_repartition = 2;
    PhysicalRoundRobinRepartition round_robin_repartition = 3;
    PhysicalRangeRepartition range_repartition = 4;
  }
}
}

Expand Down
28 changes: 24 additions & 4 deletions native-engine/blaze-serde/src/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ use datafusion_ext_exprs::{
string_ends_with::StringEndsWithExpr, string_starts_with::StringStartsWithExpr,
};
use datafusion_ext_plans::{
agg::{agg::create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr},
agg::{
agg::{create_agg, create_udaf_agg},
AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr,
},
agg_exec::AggExec,
broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec,
broadcast_join_exec::BroadcastJoinExec,
Expand Down Expand Up @@ -437,13 +440,27 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.iter()
.map(|expr| try_parse_physical_expr(expr, &input_schema))
.collect::<Result<Vec<_>, _>>()?;

Ok(AggExpr {
agg: create_agg(
let agg = match AggFunction::from(agg_function) {
AggFunction::Udaf => {
let udaf = agg_node.udaf.as_ref().unwrap();
let serialized = udaf.serialized.clone();
let input_schema = Arc::new(convert_required!(udaf.input_schema)?);
create_udaf_agg(
serialized,
input_schema,
convert_required!(udaf.return_type)?,
agg_children_exprs,
)?
}
_ => create_agg(
AggFunction::from(agg_function),
&agg_children_exprs,
&input_schema,
)?,
};

Ok(AggExpr {
agg,
mode,
field_name: name.to_owned(),
})
Expand Down Expand Up @@ -556,6 +573,9 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
protobuf::AggFunction::BrickhouseCombineUnique => {
WindowFunction::Agg(AggFunction::BrickhouseCombineUnique)
}
protobuf::AggFunction::Udaf => {
WindowFunction::Agg(AggFunction::Udaf)
}
},
};
Ok::<_, Self::Error>(WindowExpr::new(window_func, children, field))
Expand Down
1 change: 1 addition & 0 deletions native-engine/blaze-serde/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ impl From<protobuf::AggFunction> for AggFunction {
protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter,
protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect,
protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique,
protobuf::AggFunction::Udaf => AggFunction::Udaf,
}
}
}
Expand Down
46 changes: 45 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion_ext_exprs::cast::TryCastExpr;

use crate::agg::{
acc::AccColumnRef, avg, bloom_filter, brickhouse, collect, first, first_ignores_null, maxmin,
sum, AggFunction,
spark_udaf_wrapper::SparkUDAFWrapper, sum, AggFunction,
};

pub trait Agg: Send + Sync + Debug {
Expand Down Expand Up @@ -76,6 +76,33 @@ impl IdxSelection<'_> {
IdxSelection::Range(begin, end) => end - begin,
}
}

/// Materializes the selected indices as a `Vec<i32>`, in selection order.
///
/// Each variant is converted with an idiomatic iterator chain instead of the
/// manual push loops; `collect` preallocates from the iterator's `size_hint`,
/// so the former `Vec::with_capacity` is preserved implicitly.
///
/// NOTE(review): indices are narrowed with `as i32`, which silently truncates
/// values above `i32::MAX` — assumed unreachable for supported batch sizes;
/// confirm against callers that hand these to Java `int[]` arrays.
pub fn to_int32_vec(&self) -> Vec<i32> {
    match self {
        IdxSelection::Single(idx) => vec![*idx as i32],
        IdxSelection::Indices(indices) => indices.iter().map(|&idx| idx as i32).collect(),
        IdxSelection::IndicesU32(indices_u32) => {
            indices_u32.iter().map(|&idx| idx as i32).collect()
        }
        IdxSelection::Range(begin, end) => (*begin..*end).map(|idx| idx as i32).collect(),
    }
}
}

#[macro_export]
Expand Down Expand Up @@ -271,5 +298,22 @@ pub fn create_agg(
arg_list_inner_type,
)?)
}
AggFunction::Udaf => {
unreachable!("UDAF should be handled in create_udaf_agg")
}
})
}

/// Builds an [`Agg`] implementation that delegates evaluation of a Spark UDAF
/// back to the JVM through `SparkUDAFWrapper`.
///
/// `serialized` is the opaque JVM-serialized UDAF expression; `input_schema`
/// and `return_type` describe its input rows and result type; `children` are
/// the physical expressions producing its arguments. Fails if the wrapper
/// cannot be constructed.
pub fn create_udaf_agg(
    serialized: Vec<u8>,
    input_schema: SchemaRef,
    return_type: DataType,
    children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn Agg>> {
    let wrapper = SparkUDAFWrapper::try_new(serialized, input_schema, return_type, children)?;
    Ok(Arc::new(wrapper))
}
2 changes: 2 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ impl AggContext {
acc_idx,
&input_arrays,
IdxSelection::Range(0, batch.num_rows()),
batch.schema(),
)?;
}

Expand Down Expand Up @@ -327,6 +328,7 @@ impl AggContext {
acc_idx: IdxSelection,
input_arrays: &[Vec<ArrayRef>],
input_idx: IdxSelection,
batch_schema: SchemaRef,
) -> Result<()> {
if self.need_partial_update {
for (agg_idx, agg) in &self.need_partial_update_aggs {
Expand Down
2 changes: 2 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub mod count;
pub mod first;
pub mod first_ignores_null;
pub mod maxmin;
mod spark_udaf_wrapper;
pub mod sum;

use std::{fmt::Debug, sync::Arc};
Expand Down Expand Up @@ -74,6 +75,7 @@ pub enum AggFunction {
BloomFilter,
BrickhouseCollect,
BrickhouseCombineUnique,
Udaf,
}

#[derive(Debug, Clone)]
Expand Down
Loading