From f8f20978e0374beebeeb75c763e3a8d974fe9afe Mon Sep 17 00:00:00 2001 From: guoying06 Date: Fri, 24 Jan 2025 14:11:42 +0800 Subject: [PATCH 01/17] init DeclarativeAggregate udaf Resolving conflicted_file ArrowFFIExporter --- .../blaze-jni-bridge/src/jni_bridge.rs | 90 ++++ native-engine/blaze-serde/proto/blaze.proto | 11 +- native-engine/blaze-serde/src/from_proto.rs | 29 +- native-engine/blaze-serde/src/lib.rs | 1 + .../datafusion-ext-plans/src/agg/agg.rs | 54 ++- .../datafusion-ext-plans/src/agg/agg_ctx.rs | 10 +- .../datafusion-ext-plans/src/agg/avg.rs | 19 +- .../src/agg/bloom_filter.rs | 2 + .../src/agg/brickhouse/collect.rs | 3 + .../src/agg/brickhouse/combine_unique.rs | 3 + .../datafusion-ext-plans/src/agg/collect.rs | 1 + .../datafusion-ext-plans/src/agg/count.rs | 1 + .../datafusion-ext-plans/src/agg/first.rs | 1 + .../src/agg/first_ignores_null.rs | 1 + .../datafusion-ext-plans/src/agg/maxmin.rs | 1 + .../datafusion-ext-plans/src/agg/mod.rs | 2 + .../src/agg/spark_hdaf_wrapper.rs | 431 ++++++++++++++++++ .../datafusion-ext-plans/src/agg/sum.rs | 1 + .../src/window/processors/agg_processor.rs | 1 + .../sql/blaze/UnsafeRowsWrapperUtils.java | 40 ++ .../spark/sql/blaze/NativeConverters.scala | 41 ++ .../sql/blaze/SparkUDAFWrapperContext.scala | 203 +++++++++ .../spark/sql/blaze/UnsafeRowsWrapper.scala | 188 ++++++++ 23 files changed, 1122 insertions(+), 12 deletions(-) create mode 100644 native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs create mode 100644 spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java create mode 100644 spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala create mode 100644 spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index 260e76231..24c8e4e6f 100644 --- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -410,10 +410,12 @@ pub struct JavaClasses<'a> { pub cSparkSQLMetric: SparkSQLMetric<'a>, pub cSparkMetricNode: SparkMetricNode<'a>, pub cSparkUDFWrapperContext: SparkUDFWrapperContext<'a>, + pub cSparkUDAFWrapperContext: SparkUDAFWrapperContext<'a>, pub cSparkUDTFWrapperContext: SparkUDTFWrapperContext<'a>, pub cBlazeConf: BlazeConf<'a>, pub cBlazeRssPartitionWriterBase: BlazeRssPartitionWriterBase<'a>, pub cBlazeCallNativeWrapper: BlazeCallNativeWrapper<'a>, + pub cBlazeUnsafeRowsWrapperUtils: BlazeUnsafeRowsWrapperUtils<'a>, pub cBlazeOnHeapSpillManager: BlazeOnHeapSpillManager<'a>, pub cBlazeNativeParquetSinkUtils: BlazeNativeParquetSinkUtils<'a>, pub cBlazeBlockObject: BlazeBlockObject<'a>, @@ -471,8 +473,10 @@ impl JavaClasses<'static> { cSparkSQLMetric: SparkSQLMetric::new(env)?, cSparkMetricNode: SparkMetricNode::new(env)?, cSparkUDFWrapperContext: SparkUDFWrapperContext::new(env)?, + cSparkUDAFWrapperContext: SparkUDAFWrapperContext::new(env)?, cSparkUDTFWrapperContext: SparkUDTFWrapperContext::new(env)?, cBlazeConf: BlazeConf::new(env)?, + cBlazeUnsafeRowsWrapperUtils: BlazeUnsafeRowsWrapperUtils::new(env)?, cBlazeRssPartitionWriterBase: BlazeRssPartitionWriterBase::new(env)?, cBlazeCallNativeWrapper: BlazeCallNativeWrapper::new(env)?, cBlazeOnHeapSpillManager: BlazeOnHeapSpillManager::new(env)?, @@ -1169,6 +1173,45 @@ impl<'a> SparkUDFWrapperContext<'a> { } } +#[allow(non_snake_case)] +pub struct SparkUDAFWrapperContext<'a> { + pub class: JClass<'a>, + pub ctor: 
JMethodID,
+    pub method_update: JMethodID,
+    pub method_update_ret: ReturnType,
+    pub method_merge: JMethodID,
+    pub method_merge_ret: ReturnType,
+    pub method_eval: JMethodID,
+    pub method_eval_ret: ReturnType,
+}
+impl<'a> SparkUDAFWrapperContext<'a> {
+    pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/SparkUDAFWrapperContext";
+
+    pub fn new(env: &JNIEnv<'a>) -> JniResult<SparkUDAFWrapperContext<'a>> {
+        let class = get_global_jclass(env, Self::SIG_TYPE)?;
+        Ok(SparkUDAFWrapperContext {
+            class,
+            ctor: env.get_method_id(class, "<init>", "(Ljava/nio/ByteBuffer;)V")?,
+            method_update: env.get_method_id(
+                class,
+                "update",
+                "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ;)[Lorg/apache/spark/sql/catalyst/InternalRow;")?,
+            method_update_ret: ReturnType::Object,
+            method_merge: env.get_method_id(
+                class,
+                "merge",
+                "([Lorg/apache/spark/sql/catalyst/InternalRow;[Lorg/apache/spark/sql/catalyst/InternalRow;J;)[Lorg/apache/spark/sql/catalyst/InternalRow;")?,
+            method_merge_ret: ReturnType::Object,
+            method_eval: env.get_method_id(
+                class,
+                "eval",
+                "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ;)V")?,
+            method_eval_ret: ReturnType::Primitive(Primitive::Void),
+
+        })
+    }
+}
+
 #[allow(non_snake_case)]
 pub struct SparkUDTFWrapperContext<'a> {
     pub class: JClass<'a>,
@@ -1194,6 +1237,53 @@ impl<'a> SparkUDTFWrapperContext<'a> {
     }
 }
 
+#[allow(non_snake_case)]
+pub struct BlazeUnsafeRowsWrapperUtils<'a> {
+    pub class: JClass<'a>,
+    pub method_serialize: JStaticMethodID,
+    pub method_serialize_ret: ReturnType,
+    pub method_deserialize: JStaticMethodID,
+    pub method_deserialize_ret: ReturnType,
+    pub method_num: JStaticMethodID,
+    pub method_num_ret: ReturnType,
+    pub method_create: JStaticMethodID,
+    pub method_create_ret: ReturnType,
+}
+impl<'a> BlazeUnsafeRowsWrapperUtils<'a> {
+    pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils;";
+
+    pub fn new(env: &JNIEnv<'a>) -> JniResult<BlazeUnsafeRowsWrapperUtils<'a>> {
+        let class = get_global_jclass(env, Self::SIG_TYPE)?;
+        Ok(BlazeUnsafeRowsWrapperUtils {
+            class,
+            method_serialize: env.get_static_method_id(
+                class,
+                "serialize",
+                "([Lorg/apache/spark/sql/catalyst/InternalRow;IJJ)V",
+            )?,
+            method_serialize_ret: ReturnType::Primitive(Primitive::Void),
+            method_deserialize: env.get_static_method_id(
+                class,
+                "deserialize",
+                "(IJJ;)[Lorg/apache/spark/sql/catalyst/InternalRow;",
+            )?,
+            method_deserialize_ret: ReturnType::Object,
+            method_num: env.get_static_method_id(
+                class,
+                "getRowNum",
+                "([Lorg/apache/spark/sql/catalyst/InternalRow;)I;",
+            )?,
+            method_num_ret: ReturnType::Primitive(Primitive::Int),
+            method_create: env.get_static_method_id(
+                class,
+                "getEmptyObject",
+                "(I;)[Lorg/apache/spark/sql/catalyst/InternalRow;",
+            )?,
+            method_create_ret: ReturnType::Object,
+        })
+    }
+}
+
 #[allow(non_snake_case)]
 pub struct BlazeCallNativeWrapper<'a> {
     pub class: JClass<'a>,
diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto
index 283e53ff2..4d9efb802 100644
--- a/native-engine/blaze-serde/proto/blaze.proto
+++ b/native-engine/blaze-serde/proto/blaze.proto
@@ -141,11 +141,20 @@ enum AggFunction {
   BLOOM_FILTER = 9;
 
   BRICKHOUSE_COLLECT = 1000;
   BRICKHOUSE_COMBINE_UNIQUE = 1001;
+  DECLARATIVE = 1002;
 }
 
 message PhysicalAggExprNode {
   AggFunction agg_function = 1;
-  repeated PhysicalExprNode children = 2;
+  AggUdaf udaf = 2;
+  repeated PhysicalExprNode children = 3;
+}
+
+message AggUdaf {
+  bytes serialized = 1;
+  Schema agg_buffer_schema = 2;
+  ArrowType return_type = 3;
+  bool return_nullable = 4;
 }
 
 message PhysicalIsNull {
diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs
index 3cf83c109..dcc6d133d 100644
--- a/native-engine/blaze-serde/src/from_proto.rs
+++ b/native-engine/blaze-serde/src/from_proto.rs
@@ -59,7 +59,10 @@ use datafusion_ext_exprs::{
     string_ends_with::StringEndsWithExpr, string_starts_with::StringStartsWithExpr,
 };
 use datafusion_ext_plans::{
-    agg::{agg::create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr},
+    agg::{
+        agg::{create_agg, create_declarative_agg},
+        AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr,
+    },
     agg_exec::AggExec,
     broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec,
     broadcast_join_exec::BroadcastJoinExec,
@@ -437,13 +440,28 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                             .iter()
                             .map(|expr| try_parse_physical_expr(expr, &input_schema))
                             .collect::<Result<Vec<_>, _>>()?;
-
-                        Ok(AggExpr {
-                            agg: create_agg(
+                        let agg = match AggFunction::from(agg_function) {
+                            AggFunction::Declarative => {
+                                let udaf = agg_node.udaf.as_ref().unwrap();
+                                let serialized = udaf.serialized.clone();
+                                let agg_buffer_schema =
+                                    Arc::new(convert_required!(udaf.agg_buffer_schema)?);
+                                create_declarative_agg(
+                                    serialized,
+                                    agg_buffer_schema,
+                                    convert_required!(udaf.return_type)?,
+                                    agg_children_exprs,
+                                )?
+                            }
+                            _ => create_agg(
                                 AggFunction::from(agg_function),
                                 &agg_children_exprs,
                                 &input_schema,
                             )?,
+                        };
+
+                        Ok(AggExpr {
+                            agg,
                             mode,
                             field_name: name.to_owned(),
                         })
@@ -556,6 +574,9 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     protobuf::AggFunction::BrickhouseCombineUnique => {
                         WindowFunction::Agg(AggFunction::BrickhouseCombineUnique)
                     }
+                    protobuf::AggFunction::Declarative => {
+                        WindowFunction::Agg(AggFunction::Declarative)
+                    }
                 },
             };
             Ok::<_, Self::Error>(WindowExpr::new(window_func, children, field))
diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs
index 223496bb1..b64167a42 100644
--- a/native-engine/blaze-serde/src/lib.rs
+++ b/native-engine/blaze-serde/src/lib.rs
@@ -138,6 +138,7 @@ impl From<protobuf::AggFunction> for AggFunction {
             protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter,
             protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect,
             protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique,
+            protobuf::AggFunction::Declarative => AggFunction::Declarative,
         }
     }
 }
diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs
index 234b97db2..c430e83c3 100644
--- a/native-engine/datafusion-ext-plans/src/agg/agg.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs
@@ -15,7 +15,7 @@
 use std::{any::Any, fmt::Debug, sync::Arc};
 
 use arrow::{
-    array::{ArrayRef, AsArray, RecordBatch},
+    array::{Array, ArrayRef, AsArray, Int64Array, Int64Builder, RecordBatch},
     datatypes::{DataType, Int64Type, Schema, SchemaRef},
 };
 use datafusion::{common::Result, physical_expr::PhysicalExpr};
@@ -24,7 +24,7 @@ use datafusion_ext_exprs::cast::TryCastExpr;
 
 use crate::agg::{
     acc::AccColumnRef, avg, bloom_filter, brickhouse, collect, first, first_ignores_null, maxmin,
-    sum, AggFunction,
+    spark_hdaf_wrapper::SparkUDAFWrapper, sum, AggFunction,
 };
 
 pub trait Agg: Send + Sync + Debug {
@@ -46,6 +46,7 @@ pub trait Agg: Send + Sync + Debug {
         acc_idx: IdxSelection<'_>,
         partial_args: &[ArrayRef],
         partial_arg_idx: IdxSelection<'_>,
+        batch_schema: SchemaRef,
     ) -> Result<()>;
 
     fn partial_merge(
@@ -76,6 +77,38 @@ impl IdxSelection<'_> {
             IdxSelection::Range(begin, end) => end - begin,
         }
     }
+
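+    // Note (added commentary): this collects the selected accumulator indices
+    // into a single Int64Array so a selection can cross the JNI boundary as a
+    // plain Arrow array, which the JVM side reads back as a vector of longs.
+    // `Range` is iterated inclusively here, while `len()` above treats the
+    // range as exclusive.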
+    pub fn to_int64_array(&self) -> Int64Array {
+        let mut builder = Int64Builder::with_capacity(self.len());
+
+        match self {
+            IdxSelection::Single(idx) => {
+                builder.append_value(*idx as i64);
+            }
+
+            IdxSelection::Indices(indices) => {
+                for &idx in *indices {
+                    builder.append_value(idx as i64);
+                }
+            }
+            IdxSelection::IndicesU32(indices_u32) => {
+                for &idx in *indices_u32 {
+                    builder.append_value(idx as i64);
+                }
+            }
+            IdxSelection::Range(start, end) => {
+                for idx in *start..=*end {
+                    builder.append_value(idx as i64);
+                }
+            }
+        }
+        let primitive_array = builder.finish();
+        primitive_array
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .cloned()
+            .unwrap()
+    }
 }
 
 #[macro_export]
@@ -271,5 +304,22 @@ pub fn create_agg(
                 arg_list_inner_type,
             )?)
         }
+        AggFunction::Declarative => {
+            unreachable!("UDAF should be handled in create_declarative_agg")
+        }
     })
 }
+
+pub fn create_declarative_agg(
+    serialized: Vec<u8>,
+    buffer_schema: SchemaRef,
+    return_type: DataType,
+    children: Vec<Arc<dyn PhysicalExpr>>,
+) -> Result<Arc<dyn Agg>> {
+    Ok(Arc::new(SparkUDAFWrapper::try_new(
+        serialized,
+        buffer_schema,
+        return_type,
+        children,
+    )?))
+}
diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs b/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs
index b3b41e04d..8d39dd1db 100644
--- a/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs
@@ -249,6 +249,7 @@ impl AggContext {
                 acc_idx,
                 &input_arrays,
                 IdxSelection::Range(0, batch.num_rows()),
+                batch.schema(),
             )?;
         }
 
@@ -327,11 +328,18 @@ impl AggContext {
         acc_idx: IdxSelection,
         input_arrays: &[Vec<ArrayRef>],
         input_idx: IdxSelection,
+        batch_schema: SchemaRef,
     ) -> Result<()> {
         if self.need_partial_update {
             for (agg_idx, agg) in &self.need_partial_update_aggs {
                 let acc_col = &mut acc_table.cols_mut()[*agg_idx];
-                agg.partial_update(acc_col, acc_idx, &input_arrays[*agg_idx], input_idx)?;
+                agg.partial_update(
+                    acc_col,
+                    acc_idx,
+                    &input_arrays[*agg_idx],
+                    input_idx,
+                    batch_schema.clone(),
+                )?;
             }
         }
         Ok(())
diff --git a/native-engine/datafusion-ext-plans/src/agg/avg.rs b/native-engine/datafusion-ext-plans/src/agg/avg.rs
index 7e2496cbe..58fe571f8 100644
--- a/native-engine/datafusion-ext-plans/src/agg/avg.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/avg.rs
@@ -110,12 +110,23 @@ impl Agg for AggAvg {
         acc_idx: IdxSelection<'_>,
         partial_args: &[ArrayRef],
         partial_arg_idx: IdxSelection<'_>,
+        batch_schema: SchemaRef,
     ) -> Result<()> {
         let accs = downcast_any!(accs, mut AccAvgColumn).unwrap();
-        self.agg_sum
-            .partial_update(&mut accs.sum, acc_idx, partial_args, partial_arg_idx)?;
-        self.agg_count
-            .partial_update(&mut accs.count, acc_idx, partial_args, partial_arg_idx)?;
+        self.agg_sum.partial_update(
+            &mut accs.sum,
+            acc_idx,
+            partial_args,
+            partial_arg_idx,
+            batch_schema.clone(),
+        )?;
+        self.agg_count.partial_update(
+            &mut accs.count,
+            acc_idx,
+            partial_args,
+            partial_arg_idx,
+            batch_schema.clone(),
+        )?;
         Ok(())
     }
 
diff --git a/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs b/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs
index f4bd34580..9064e611c 100644
--- a/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs
@@ -23,6 +23,7 @@ use arrow::{
     array::{ArrayRef, AsArray, BinaryBuilder},
     datatypes::{DataType, Int64Type},
 };
+use arrow_schema::SchemaRef;
 use byteorder::{ReadBytesExt, WriteBytesExt};
 use datafusion::{common::Result, physical_expr::PhysicalExpr};
 use datafusion_ext_commons::{
@@ -113,6 +114,7 @@ impl Agg for
AggBloomFilter { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccBloomFilterColumn).unwrap(); let bloom_filter = match acc_idx { diff --git a/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs b/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs index c74cf5dd7..5216df746 100644 --- a/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs +++ b/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs @@ -22,6 +22,7 @@ use arrow::{ array::{Array, ArrayRef, AsArray}, datatypes::DataType, }; +use arrow_schema::SchemaRef; use datafusion::{common::Result, physical_expr::PhysicalExpr}; use crate::{ @@ -86,6 +87,7 @@ impl Agg for AggCollect { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let list = partial_args[0].as_list::(); @@ -99,6 +101,7 @@ impl Agg for AggCollect { IdxSelection::Single(acc_idx), &[values], IdxSelection::Range(0, values_len), + batch_schema.clone(), )?; } } diff --git a/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs b/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs index 1b8b8246f..900f2b43c 100644 --- a/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs +++ b/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs @@ -22,6 +22,7 @@ use arrow::{ array::{Array, ArrayRef, AsArray}, datatypes::DataType, }; +use arrow_schema::SchemaRef; use datafusion::{common::Result, physical_expr::PhysicalExpr}; use crate::{ @@ -86,6 +87,7 @@ impl Agg for AggCombineUnique { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let list = partial_args[0].as_list::(); @@ -99,6 +101,7 @@ impl Agg for AggCombineUnique { IdxSelection::Single(acc_idx), &[values], IdxSelection::Range(0, values_len), + batch_schema.clone(), )?; } } diff --git a/native-engine/datafusion-ext-plans/src/agg/collect.rs b/native-engine/datafusion-ext-plans/src/agg/collect.rs index de2dca821..07e90a3b0 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect.rs @@ -114,6 +114,7 @@ impl Agg for AggGenericCollect { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut C).unwrap(); idx_for_zipped! 
{ diff --git a/native-engine/datafusion-ext-plans/src/agg/count.rs b/native-engine/datafusion-ext-plans/src/agg/count.rs index e90ff2fcd..19d8a4d90 100644 --- a/native-engine/datafusion-ext-plans/src/agg/count.rs +++ b/native-engine/datafusion-ext-plans/src/agg/count.rs @@ -92,6 +92,7 @@ impl Agg for AggCount { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccCountColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/first.rs b/native-engine/datafusion-ext-plans/src/agg/first.rs index a7083b53b..0190c618b 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first.rs @@ -90,6 +90,7 @@ impl Agg for AggFirst { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let partial_arg = &partial_args[0]; let accs = downcast_any!(accs, mut AccFirstColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs index ca898a5b1..e211066ee 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs @@ -86,6 +86,7 @@ impl Agg for AggFirstIgnoresNull { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let partial_arg = &partial_args[0]; let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs index 24800b1c4..e1571d708 100644 --- a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs +++ b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs @@ -93,6 +93,7 @@ impl Agg for AggMaxMin

{
         acc_idx: IdxSelection<'_>,
         partial_args: &[ArrayRef],
         partial_arg_idx: IdxSelection<'_>,
+        batch_schema: SchemaRef,
     ) -> Result<()> {
         let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
         let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);
diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs
index f4a93133b..ba4d5c0ca 100644
--- a/native-engine/datafusion-ext-plans/src/agg/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs
@@ -25,6 +25,7 @@ pub mod count;
 pub mod first;
 pub mod first_ignores_null;
 pub mod maxmin;
+mod spark_hdaf_wrapper;
 pub mod sum;
 
 use std::{fmt::Debug, sync::Arc};
@@ -74,6 +75,7 @@ pub enum AggFunction {
     BloomFilter,
     BrickhouseCollect,
     BrickhouseCombineUnique,
+    Declarative,
 }
 
 #[derive(Debug, Clone)]
diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs
new file mode 100644
index 000000000..77ce80805
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs
@@ -0,0 +1,431 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter},
+    sync::Arc,
+};
+
+use arrow::{
+    array::{
+        as_struct_array, make_array, Array, ArrayAccessor, ArrayRef, AsArray, BinaryArray, Datum,
+        Int32Array, StructArray,
+    },
+    buffer::NullBuffer,
+    datatypes::{DataType, Field, Schema, SchemaRef},
+    ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema},
+    record_batch::{RecordBatch, RecordBatchOptions},
+};
+use arrow_schema::Fields;
+use blaze_jni_bridge::{
+    jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object,
+};
+use datafusion::{
+    common::{DataFusionError, Result},
+    physical_expr::PhysicalExpr,
+};
+use datafusion_ext_commons::downcast_any;
+use jni::objects::{GlobalRef, JObject};
+use once_cell::sync::OnceCell;
+
+use crate::{
+    agg::{
+        acc::{AccColumn, AccColumnRef},
+        agg::{Agg, IdxSelection},
+    },
+    memmgr::spill::{SpillCompressedReader, SpillCompressedWriter},
+};
+
+pub struct SparkUDAFWrapper {
+    serialized: Vec<u8>,
+    pub buffer_schema: SchemaRef,
+    pub return_type: DataType,
+    child: Vec<Arc<dyn PhysicalExpr>>,
+    import_schema: SchemaRef,
+    params_schema: OnceCell<SchemaRef>,
+    jcontext: OnceCell<GlobalRef>,
+}
+
+impl SparkUDAFWrapper {
+    pub fn try_new(
+        serialized: Vec<u8>,
+        buffer_schema: SchemaRef,
+        return_type: DataType,
+        child: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Self> {
+        Ok(Self {
+            serialized,
+            buffer_schema,
+            return_type: return_type.clone(),
+            child,
+            import_schema: Arc::new(Schema::new(vec![Field::new("", return_type, true)])),
+            params_schema: OnceCell::new(),
+            jcontext: OnceCell::new(),
+        })
+    }
+
+    fn jcontext(&self) -> Result<GlobalRef> {
+        self.jcontext
+            .get_or_try_init(|| {
+                let serialized_buf = jni_new_direct_byte_buffer!(&self.serialized)?;
+                let jcontext_local =
+                    jni_new_object!(SparkUDAFWrapperContext(serialized_buf.as_obj()))?;
+                jni_new_global_ref!(jcontext_local.as_obj())
+            })
+            .cloned()
+    }
+}
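+
+// Implementation note (added commentary): the accumulator state for a
+// DeclarativeAggregate lives on the JVM side as an array of Spark
+// `InternalRow`s held through a `GlobalRef`, so the native engine never
+// touches the aggregation buffers directly. Each partial_update /
+// partial_merge / final_merge call below ships the selected accumulator
+// indices (plus, for updates, the argument batch) to the JVM as Arrow struct
+// arrays via FFI, and the JVM evaluates the aggregate's update / merge /
+// evaluate projections against the referenced rows.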
+
+impl Display for SparkUDAFWrapper {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "SparkUDAFWrapper")
+    }
+}
+
+impl Debug for SparkUDAFWrapper {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "SparkUDAFWrapper({:?})", self.child)
+    }
+}
+
+impl Agg for SparkUDAFWrapper {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        self.child.clone()
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.return_type
+    }
+
+    fn nullable(&self) -> bool {
+        true
+    }
+
+    fn create_acc_column(&self, num_rows: usize) -> AccColumnRef {
+        // num_rows
+        let rows = jni_call_static!(
+            BlazeUnsafeRowsWrapperUtils.create(num_rows as i32)-> JObject)
+        .unwrap();
+        let obj = jni_new_global_ref!(rows.as_obj()).unwrap();
+        Box::new(AccUnsafeRowsColumn {
+            obj,
+            num_fields: self.buffer_schema.fields.len(),
+        })
+    }
+
+    fn with_new_exprs(&self, exprs: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn Agg>> {
+        Ok(Arc::new(Self::try_new(
+            self.serialized.clone(),
+            self.buffer_schema.clone(),
+            self.return_type.clone(),
+            self.child.clone(),
+        )?))
+    }
+
+    // todo: implemented prepare_partial_args
+    // fn prepare_partial_args(&self, partial_inputs: &[ArrayRef]) ->
+    // Result<Vec<ArrayRef>> {     // cast arg1 to target data type
+    //     Ok(vec![datafusion_ext_commons::arrow::cast::cast(
+    //         &partial_inputs[0],
+    //         &self.return_type,
+    //     )?])
+    // }
+
+    fn partial_update(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        partial_args: &[ArrayRef],
+        partial_arg_idx: IdxSelection<'_>,
+        batch_schema: SchemaRef,
+    ) -> Result<()> {
+        let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap();
+
+        let params = partial_args.to_vec();
+        let params_schema = self
+            .params_schema
+            .get_or_try_init(|| -> Result<SchemaRef> {
+                let mut param_fields = Vec::with_capacity(self.child.len());
+                for child in &self.child {
+                    param_fields.push(Field::new(
+                        "",
+                        child.data_type(batch_schema.as_ref())?,
+                        child.nullable(batch_schema.as_ref())?,
+                    ));
+                }
+                Ok(Arc::new(Schema::new(param_fields)))
+            })?;
+        let params_batch = RecordBatch::try_new_with_options(
+            params_schema.clone(),
+            params.clone(),
+            &RecordBatchOptions::new().with_row_count(Some(params[0].len())),
+        )?;
+
+        accs.obj = partial_update_udaf(
+            self.jcontext()?,
+            params_batch,
+            accs.obj.clone(),
+            acc_idx,
+            partial_arg_idx,
+        )
+        .unwrap();
+
+        Ok(())
+    }
+
+    fn partial_merge(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        merging_accs: &mut AccColumnRef,
+        merging_acc_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap();
+        let merging_accs = downcast_any!(merging_accs, mut AccUnsafeRowsColumn).unwrap();
+
+        accs.obj = partial_merge_udaf(
+            self.jcontext()?,
+            accs.obj.clone(),
+            merging_accs.obj.clone(),
+            acc_idx,
+            merging_acc_idx,
+        )
+        .unwrap();
+
+        Ok(())
+    }
+
+    fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
+        let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap();
+        final_merge_udaf(
+            self.jcontext()?,
+            accs.obj.clone(),
+            acc_idx,
+            self.import_schema.clone(),
+        )
+    }
+}
+
+struct AccUnsafeRowsColumn {
+    obj: GlobalRef,
+    num_fields: usize,
+}
+
+impl AccColumn for AccUnsafeRowsColumn {
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn resize(&mut self, len: usize) {
+        unimplemented!()
+    }
+
+    fn shrink_to_fit(&mut self) {
+        unimplemented!()
+    }
+
+    fn num_records(&self) -> usize {
+        match jni_call_static!(
+            BlazeUnsafeRowsWrapperUtils.num(self.obj.as_obj())
+            -> i32)
+        {
+            Ok(row_num) => row_num as usize,
+            Err(_) => 0,
+        }
+    }
+
+    fn mem_used(&self) -> usize {
+        0
+    }
+
+    fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec<u8>]) -> Result<()> {
+        let field = Arc::new(Field::new("", DataType::Int64, false));
+        let idx64 = idx.to_int64_array().into_data();
+        let struct_array = StructArray::from(vec![(field, make_array(idx64))]);
+        let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data());
+        let mut import_ffi_array = FFI_ArrowArray::empty();
+        jni_call_static!(
+            BlazeUnsafeRowsWrapperUtils.serialize(
+                self.obj.as_obj(),
+                self.num_fields as i32,
+                &mut export_ffi_array as *mut FFI_ArrowArray as i64,
+                &mut import_ffi_array as *mut FFI_ArrowArray as i64,)
+            -> ())?;
+        // import output from context
+        let field = Field::new("", DataType::Binary, false);
+        let schema = Schema::new(vec![field]);
+        let import_ffi_schema = FFI_ArrowSchema::try_from(schema)?;
+        let import_struct_array =
+            make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? });
+        let result_struct = import_struct_array.as_struct();
+
+        let binary_array = result_struct
+            .column(1)
+            .as_any()
+            .downcast_ref::<BinaryArray>()
+            .ok_or_else(|| DataFusionError::Execution("Expected a BinaryArray".to_string()))?;
+
+        for i in 0..binary_array.len() {
+            if binary_array.is_valid(i) {
+                let bytes = binary_array.value(i).to_vec();
+                array[i] = bytes;
+            } else {
+                log::warn!("AccUnsafeRowsColumn::freeze_to_rows : binary_array null error")
+            }
+        }
+        Ok(())
+    }
+
+    fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
+        let fields = Fields::from(vec![
+            Field::new("", DataType::Binary, false),
+            Field::new("", DataType::Int64, false),
+        ]);
+        let binary_values = array.iter().map(|&data| data).collect();
+        let offsets_i32 = offsets
+            .iter()
+            .map(|data| *data as i32)
+            .collect::<Vec<i32>>();
+        let offsets_array = Int32Array::from(offsets_i32);
+        let binary_array = BinaryArray::from_vec(binary_values);
+        let nulls = Some(NullBuffer::from_iter([false, false]));
+        let values = vec![Arc::new(binary_array) as _, Arc::new(offsets_array) as _];
+        let struct_array = StructArray::new(fields, values, nulls);
+        let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data());
+        let mut import_ffi_array = FFI_ArrowArray::empty();
+        let rows = jni_call_static!(
+            BlazeUnsafeRowsWrapperUtils.deserialize(
+                self.num_fields as i32,
+                &mut export_ffi_array as *mut FFI_ArrowArray as i64,
+                &mut import_ffi_array as *mut FFI_ArrowArray as i64,)
+            -> JObject)?;
+        self.obj = jni_new_global_ref!(rows.as_obj())?;
+
+        // update offsets
+        // import output from context
+        let field = Field::new("", DataType::Int32, false);
+        let schema = Schema::new(vec![field]);
+        let import_ffi_schema = FFI_ArrowSchema::try_from(schema)?;
+        let import_struct_array =
+            make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? });
+        let result_struct = import_struct_array.as_struct();
+
+        let int32array = result_struct
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .ok_or_else(|| DataFusionError::Execution("Expected a Int32Array".to_string()))?;
+
+        assert_eq!(int32array.len(), array.len());
+
+        for i in 0..int32array.len() {
+            offsets[i] = int32array.value(i) as usize;
+        }
+
+        Ok(())
+    }
+
+    fn spill(&self, idx: IdxSelection<'_>, buf: &mut SpillCompressedWriter) -> Result<()> {
+        unimplemented!()
+    }
+
+    fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
+        unimplemented!()
+    }
+}
+
+fn partial_update_udaf(
+    jcontext: GlobalRef,
+    params_batch: RecordBatch,
+    accs: GlobalRef,
+    acc_idx: IdxSelection<'_>,
+    partial_arg_idx: IdxSelection<'_>,
+) -> Result<GlobalRef> {
+    let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int64, false));
+    let partial_arg_idx_field = Arc::new(Field::new("partial_arg_idx", DataType::Int64, false));
+    let acc_idx = acc_idx.to_int64_array().into_data();
+    let partial_arg_idx = partial_arg_idx.to_int64_array().into_data();
+    let struct_array = StructArray::from(vec![
+        (acc_idx_field.clone(), make_array(acc_idx)),
+        (partial_arg_idx_field.clone(), make_array(partial_arg_idx)),
+    ]);
+    let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data());
+
+    let struct_array = StructArray::from(params_batch);
+    let mut export_ffi_batch_array = FFI_ArrowArray::new(&struct_array.to_data());
+
+    let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).update(
+        accs.as_obj(),
+        &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64,
+        &mut export_ffi_batch_array as *mut FFI_ArrowArray as i64,
+    )-> JObject)?;
+
+    jni_new_global_ref!(rows.as_obj())
+}
+
+fn partial_merge_udaf(
+    jcontext: GlobalRef,
+    accs: GlobalRef,
+    merging_accs: GlobalRef,
+    acc_idx: IdxSelection<'_>,
+    merging_acc_idx: IdxSelection<'_>,
+) -> Result<GlobalRef> {
+    let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int64, false));
+    let merging_acc_idx_field = Arc::new(Field::new("merging_acc_idx", DataType::Int64, false));
+    let acc_idx = acc_idx.to_int64_array().into_data();
+    let merging_acc_idx = merging_acc_idx.to_int64_array().into_data();
+    let struct_array = StructArray::from(vec![
+        (acc_idx_field.clone(), make_array(acc_idx)),
+        (merging_acc_idx_field.clone(), make_array(merging_acc_idx)),
+    ]);
+    let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data());
+
+    let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).merge(
+        accs.as_obj(),
+        merging_accs.as_obj(),
+        &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64,
+    )-> JObject)?;
+
+    jni_new_global_ref!(rows.as_obj())
+}
+
+fn final_merge_udaf(
+    jcontext: GlobalRef,
+    accs: GlobalRef,
+    acc_idx: IdxSelection<'_>,
+    result_schema: SchemaRef,
+) -> Result<ArrayRef> {
+    let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int64, false));
+    let acc_idx = acc_idx.to_int64_array().into_data();
+    let struct_array = StructArray::from(vec![(acc_idx_field.clone(), make_array(acc_idx))]);
+    let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data());
+    let mut import_ffi_array = FFI_ArrowArray::empty();
+    let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).eval(
+        accs.as_obj(),
+        &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64,
+        &mut import_ffi_array as *mut FFI_ArrowArray as i64,
+    )-> ())?;
+
+    // import output from context
+    let import_ffi_schema = FFI_ArrowSchema::try_from(result_schema.as_ref())?;
+    let import_struct_array =
+        make_array(unsafe {
from_ffi(import_ffi_array, &import_ffi_schema)? }); + let import_array = as_struct_array(&import_struct_array).column(0).clone(); + Ok(import_array) +} diff --git a/native-engine/datafusion-ext-plans/src/agg/sum.rs b/native-engine/datafusion-ext-plans/src/agg/sum.rs index f0b0f4f32..49686f78d 100644 --- a/native-engine/datafusion-ext-plans/src/agg/sum.rs +++ b/native-engine/datafusion-ext-plans/src/agg/sum.rs @@ -91,6 +91,7 @@ impl Agg for AggSum { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, + batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs index 89d2b56ba..af2f75943 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs @@ -78,6 +78,7 @@ impl WindowFunctionProcessor for AggProcessor { IdxSelection::Single(0), &children_cols, IdxSelection::Single(row_idx), + batch.schema(), )?; output.push( self.agg diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java new file mode 100644 index 000000000..a3c2c76b8 --- /dev/null +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java @@ -0,0 +1,40 @@ +/* + * Copyright 2022 The Blaze Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.blaze; + +import org.apache.spark.sql.catalyst.InternalRow; + +// for jni_bridge usage + +public class UnsafeRowsWrapperUtils { + + public static void serialize( + InternalRow[] unsafeRows, int numFields, Long importFFIArrayPtr, Long exportFFIArrayPtr) { + UnsafeRowsWrapper$.MODULE$.serialize(unsafeRows, numFields, importFFIArrayPtr, exportFFIArrayPtr); + } + + public static InternalRow[] deserialize(int numFields, Long importFFIArrayPtr, Long exportFFIArrayPtr) { + return UnsafeRowsWrapper$.MODULE$.deserialize(numFields, importFFIArrayPtr, exportFFIArrayPtr); + } + + public static int getRowNum(InternalRow[] unsafeRows) { + return UnsafeRowsWrapper$.MODULE$.getRowNum(unsafeRows); + } + + public static InternalRow[] getEmptyObject(int rowNum) { + return UnsafeRowsWrapper$.MODULE$.getNullObject(rowNum); + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index a7d4d9f49..c4011ac27 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.expressions.aggregate.Max import org.apache.spark.sql.catalyst.expressions.aggregate.Min import org.apache.spark.sql.catalyst.expressions.aggregate.Sum +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext @@ -1164,6 +1165,46 @@ object NativeConverters extends Logging { defaultValue = true) => aggBuilder.setAggFunction(pb.AggFunction.BRICKHOUSE_COMBINE_UNIQUE) aggBuilder.addChildren(convertExpr(udaf.children.head)) + // other DeclarativeAggregate + case declarative + if classOf[DeclarativeAggregate].isAssignableFrom(e.aggregateFunction.getClass) => + def fallbackToError: Expression => pb.PhysicalExprNode = { e => + throw new NotImplementedError(s"unsupported declarative expression: (${e.getClass}) $e") + } + aggBuilder.setAggFunction(pb.AggFunction.DECLARATIVE) + val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]() + val bound = declarative.mapChildren(_.transformDown { + case p: Literal => p + case p => + try { + val convertedChild = + convertExprWithFallback(p, isPruningExpr = false, fallbackToError) + val nextBindIndex = convertedChildren.size + convertedChildren.getOrElseUpdate( + convertedChild, + BoundReference(nextBindIndex, p.dataType, p.nullable)) + } catch { + case _: Exception | _: NotImplementedError => p + } + }) + val paramsSchema = StructType( + convertedChildren.values + .map(ref => StructField("", ref.dataType, ref.nullable)) + .toSeq) + + val serialized = + serializeExpression( + bound.asInstanceOf[DeclarativeAggregate with Serializable], + paramsSchema) + + aggBuilder.setUdaf( + pb.AggUdaf + .newBuilder() + .setSerialized(ByteString.copyFrom(serialized)) + .setAggBufferSchema(NativeConverters.convertSchema(declarative.aggBufferSchema)) + .setReturnType(convertDataType(bound.dataType)) + .setReturnNullable(bound.nullable)) + aggBuilder.addAllChildren(convertedChildren.keys.asJava) case _ => Shims.get.convertMoreAggregateExpr(e) match { diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala new file mode 100644 index 000000000..6d173b77a --- /dev/null +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -0,0 +1,203 @@ +/* + * Copyright 2022 The Blaze Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.blaze + +import scala.collection.JavaConverters._ +import org.apache.arrow.c.{ArrowArray, Data} +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.dictionary.DictionaryProvider +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.blaze.util.Using +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, UnsafeProjection} +import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper +import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.types.{LongType, StructField, StructType} +import com.google.flatbuffers.LongVector +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, MutableProjection} +import org.apache.spark.sql.catalyst.expressions.AttributeReference + +import java.nio.ByteBuffer + +case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { + private val (expr, javaParamsSchema) = + NativeConverters.deserializeExpression[DeclarativeAggregate]({ + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + bytes + }) + + val inputAttributes: Seq[Attribute] = javaParamsSchema.fields.map { field => + AttributeReference(field.name, field.dataType, field.nullable)() + } + + // initialize all nondeterministic children exprs + expr.foreach { + case nondeterministic: Nondeterministic => + nondeterministic.initialize(TaskContext.get.partitionId()) + case _ => + } + + private lazy val initializer = MutableProjection.create(expr.initialValues) + + private lazy val updater = + MutableProjection.create(expr.updateExpressions, expr.aggBufferAttributes ++ inputAttributes) + + private lazy val merger = MutableProjection.create( + expr.mergeExpressions, + expr.aggBufferAttributes ++ expr.inputAggBufferAttributes) + + private lazy val evaluator = + MutableProjection.create(expr.evaluateExpression :: Nil, expr.aggBufferAttributes) + + private def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() + + private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() + + private val inputSchema = ArrowUtils.toArrowSchema(javaParamsSchema) + + { + val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable))) + ArrowUtils.toArrowSchema(schema) + } + + 
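+  // Note (added commentary): each native call passes row indices as a
+  // two-column Arrow struct (accumulator row index, input/merge row index),
+  // one pair per (buffer, value) combination to process; this is the schema
+  // those index batches are imported with.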
private val indexSchema = { + val schema = StructType(Seq(StructField("", LongType), StructField("", LongType))) + ArrowUtils.toArrowSchema(schema) + } + + { + val toUnsafe = UnsafeProjection.create(javaParamsSchema) + toUnsafe.initialize(Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)) + toUnsafe + } + + def update(values: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = values.foldLeft(initialize()) { (buffer, input) => + updater(joiner(buffer, input)) + } + buffer.copy() + } + + def update( + rows: Array[InternalRow], + importIdxFFIArrayPtr: Long, + importBatchFFIArrayPtr: Long): Array[InternalRow] = { + Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => + Using.resources( + VectorSchemaRoot.create(inputSchema, batchAllocator), + VectorSchemaRoot.create(indexSchema, batchAllocator), + ArrowArray.wrap(importIdxFFIArrayPtr), + ArrowArray.wrap(importBatchFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => + // import into params root + Data.importIntoVectorSchemaRoot(batchAllocator, inputArray, inputRoot, dictionaryProvider) + val batch = ColumnarHelper.rootAsBatch(inputRoot) + val inputRows = ColumnarHelper.batchAsRowIter(batch).toArray + + Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) + val fieldVectors = idxRoot.getFieldVectors.asScala + val rowIdxVector = fieldVectors.head.asInstanceOf[LongVector] + val inputIdxVector = fieldVectors(1).asInstanceOf[LongVector] + + assert( + rowIdxVector.length() == inputIdxVector.length(), + s"Error: SparkUDAFWrapperContext update error Vectors have different lengths.") + + for (i <- 0 until rowIdxVector.length()) { + val row = rows(rowIdxVector.get(i).toInt) + val input = inputRows(inputIdxVector.get(i).toInt) + val joiner = new JoinedRow + if (row.numFields == 0) { + rows(rowIdxVector.get(i).toInt) = updater(joiner(initialize(), input)) + } else { + rows(rowIdxVector.get(i).toInt) = updater(joiner(row, input)) + } + } + + rows + } + } + } + + def merge( + rows: Array[InternalRow], + mergeRows: Array[InternalRow], + importIdxFFIArrayPtr: Long): Array[InternalRow] = { + Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => + Using.resources( + VectorSchemaRoot.create(indexSchema, batchAllocator), + ArrowArray.wrap(importIdxFFIArrayPtr)) { (idxRoot, idxArray) => + Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) + val fieldVectors = idxRoot.getFieldVectors.asScala + val rowIdxVector = fieldVectors.head.asInstanceOf[LongVector] + val mergeIdxVector = fieldVectors(1).asInstanceOf[LongVector] + + assert( + rowIdxVector.length() == mergeIdxVector.length(), + s"Error: SparkUDAFWrapperContext update error Vectors have different lengths.") + + for (i <- 0 until rowIdxVector.length()) { + val row = rows(rowIdxVector.get(i).toInt) + val mergeRow = mergeRows(mergeIdxVector.get(i).toInt) + val joiner = new JoinedRow + if (row.numFields == 0) { + rows(rowIdxVector.get(i).toInt) = merger(joiner(initialize(), mergeRow)) + } else { + rows(rowIdxVector.get(i).toInt) = merger(joiner(row, mergeRow)) + } + } + + rows + } + } + } + + def eval( + rows: Array[InternalRow], + importIdxFFIArrayPtr: Long, + exportFFIArrayPtr: Long): Unit = { + Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => + Using.resources( + VectorSchemaRoot.create(indexSchema, batchAllocator), + VectorSchemaRoot.create(inputSchema, batchAllocator), + ArrowArray.wrap(importIdxFFIArrayPtr), 
+ ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => + Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) + val fieldVectors = idxRoot.getFieldVectors.asScala + val rowIdxVector = fieldVectors.head.asInstanceOf[LongVector] + + // evaluate expression and write to output root + val outputWriter = ArrowWriter.create(outputRoot) + for (i <- 0 until rowIdxVector.length()) { + val row = rows(rowIdxVector.get(i).toInt) + outputWriter.write(evaluator(row)) + } + outputWriter.finish() + + // export to output using root allocator + Data.exportVectorSchemaRoot( + ArrowUtils.rootAllocator, + outputRoot, + dictionaryProvider, + exportArray) + } + } + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala new file mode 100644 index 000000000..a535b6493 --- /dev/null +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -0,0 +1,188 @@ +/* + * Copyright 2022 The Blaze Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.blaze + +import org.apache.arrow.c.{ArrowArray, Data} +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.dictionary.DictionaryProvider +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.util.Utils +import org.apache.spark.internal.Logging +import org.apache.spark.sql.blaze.util.Using +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.UnsafeRowSerializer +import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, LongType, StructField, StructType} +import org.apache.arrow.flatbuf +import com.google.flatbuffers.{IntVector, LongVector} +import org.apache.spark.sql.Row + +import scala.collection.JavaConverters._ +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.lang.reflect.Field + +object UnsafeRowsWrapper extends Logging { + + private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() + private val idxSchema = { + val schema = StructType(Seq(StructField("", LongType, nullable = false))) + ArrowUtils.toArrowSchema(schema) + } + + private val byteSchema = { + val schema = StructType(Seq(StructField("", BinaryType, nullable = false))) + ArrowUtils.toArrowSchema(schema) + } + + private val deserializeSchema = { + val schema = StructType( + Seq( + StructField("", BinaryType, nullable = false), + StructField("", IntegerType, nullable = false))) + ArrowUtils.toArrowSchema(schema) + } + + private val offsetSchema = { + val schema = StructType(Seq(StructField("", IntegerType, nullable = false))) + ArrowUtils.toArrowSchema(schema) + } + + private def 
toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { + val converter = unsafeRowConverter(schema) + converter(row) + } + + private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = { + val converter = UnsafeProjection.create(schema) + (row: Row) => { + converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]) + } + } + + def serialize( + unsafeRows: Array[InternalRow], + numFields: Int, + importFFIArrayPtr: Long, + exportFFIArrayPtr: Long): Unit = { + Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => + Using.resources( + VectorSchemaRoot.create(byteSchema, batchAllocator), + VectorSchemaRoot.create(idxSchema, batchAllocator), + ArrowArray.wrap(importFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { + (outputRoot, paramsRoot, importArray, exportArray) => + // import into params root + Data.importIntoVectorSchemaRoot( + batchAllocator, + importArray, + paramsRoot, + dictionaryProvider) + val idxArray = paramsRoot.getFieldVectors.asScala.head.asInstanceOf[LongVector]; + + val serializer = new UnsafeRowSerializer(numFields).newInstance() + val outputWriter = ArrowWriter.create(outputRoot) + for (idx <- 0 until idxArray.length()) { + val internalRow = unsafeRows(idx) + Utils.tryWithResource(new ByteArrayOutputStream()) { baos => + val serializerStream = serializer.serializeStream(baos) + serializerStream.writeValue(internalRow) + val bytes = baos.toByteArray + outputWriter.write(toUnsafeRow(Row(bytes), Array(BinaryType))) + } + } + + outputWriter.finish() + + // export to output using root allocator + Data.exportVectorSchemaRoot( + ArrowUtils.rootAllocator, + outputRoot, + dictionaryProvider, + exportArray) + } + } + } + + def deserialize( + numFields: Int, + importFFIArrayPtr: Long, + exportFFIArrayPtr: Long): Array[InternalRow] = { + + Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => + Using.resources( + VectorSchemaRoot.create(deserializeSchema, batchAllocator), + VectorSchemaRoot.create(offsetSchema, batchAllocator), + ArrowArray.wrap(importFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { + (paramsRoot, outputRoot, importArray, exportArray) => + Data.importIntoVectorSchemaRoot( + batchAllocator, + importArray, + paramsRoot, + dictionaryProvider) + val fieldVectors = paramsRoot.getFieldVectors.asScala + val binaryVector = fieldVectors.head.asInstanceOf[flatbuf.Binary.Vector]; + val intVector = fieldVectors(1).asInstanceOf[IntVector] + + assert( + binaryVector.length() == intVector.length(), + s"Error: UnsafeRowsWrapper deserialize error Vectors have different lengths.") + + val deserializer = new UnsafeRowSerializer(numFields).newInstance() + val internalRowsArray = new Array[InternalRow](binaryVector.length()) + val outputWriter = ArrowWriter.create(outputRoot) + for (i <- 0 until binaryVector.length()) { + val binaryRow = binaryVector.get(i) + val offset = intVector.get(i) + val bytes = binaryRow.getByteBuffer.array() + val internalRow: InternalRow = Utils.tryWithResource( + new ByteArrayInputStream(bytes, offset, bytes.length - offset)) { bais => + val unsafeRow = + deserializer.deserializeStream(bais).readValue().asInstanceOf[UnsafeRow] + // get offset use reflect + val field: Field = classOf[ByteArrayInputStream].getDeclaredField("pos") + field.setAccessible(true) + val position = field.getInt(bais) + outputWriter.write(toUnsafeRow(Row(position), Array(IntegerType))) + unsafeRow + } + internalRowsArray(i) = internalRow + } + + outputWriter.finish() + + // 
export to output using root allocator + Data.exportVectorSchemaRoot( + ArrowUtils.rootAllocator, + outputRoot, + dictionaryProvider, + exportArray) + + internalRowsArray + } + } + } + + def getRowNum(unsafeRows: Array[InternalRow]): Int = { + unsafeRows.length + } + + def getNullObject(rowNum: Int): Array[InternalRow] = { + Array.fill(rowNum)(InternalRow.empty) + } + +} From 3dd8678b10063e16e017f5a8da5c7e9aa20caa52 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Wed, 5 Feb 2025 15:37:43 +0800 Subject: [PATCH 02/17] update DeclarativeAggregate udaf --- .../blaze-jni-bridge/src/jni_bridge.rs | 14 +-- .../datafusion-ext-plans/src/agg/agg.rs | 16 +-- .../src/agg/spark_hdaf_wrapper.rs | 104 ++++++++++++++---- .../sql/blaze/UnsafeRowsWrapperUtils.java | 5 +- .../sql/blaze/SparkUDAFWrapperContext.scala | 61 ++++------ .../spark/sql/blaze/UnsafeRowsWrapper.scala | 22 ++-- 6 files changed, 125 insertions(+), 97 deletions(-) diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index 24c8e4e6f..31fa2b858 100644 --- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -1195,17 +1195,17 @@ impl<'a> SparkUDAFWrapperContext<'a> { method_update: env.get_method_id( class, "update", - "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ;)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, + "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, method_update_ret: ReturnType::Object, method_merge: env.get_method_id( class, "merge", - "([Lorg/apache/spark/sql/catalyst/InternalRow;[Lorg/apache/spark/sql/catalyst/InternalRow;J;)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, + "([Lorg/apache/spark/sql/catalyst/InternalRow;[Lorg/apache/spark/sql/catalyst/InternalRow;J)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, method_merge_ret: ReturnType::Object, method_eval: env.get_method_id( class, "eval", - "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ;)V")?, + "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ)V")?, method_eval_ret: ReturnType::Primitive(Primitive::Void), }) @@ -1250,7 +1250,7 @@ pub struct BlazeUnsafeRowsWrapperUtils<'a> { pub method_create_ret: ReturnType, } impl<'a> BlazeUnsafeRowsWrapperUtils<'a> { - pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils;"; + pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils"; pub fn new(env: &JNIEnv<'a>) -> JniResult> { let class = get_global_jclass(env, Self::SIG_TYPE)?; @@ -1265,19 +1265,19 @@ impl<'a> BlazeUnsafeRowsWrapperUtils<'a> { method_deserialize: env.get_static_method_id( class, "deserialize", - "(IJJ;)[Lorg/apache/spark/sql/catalyst/InternalRow;", + "(IJJ)[Lorg/apache/spark/sql/catalyst/InternalRow;", )?, method_deserialize_ret: ReturnType::Object, method_num: env.get_static_method_id( class, "getRowNum", - "([Lorg/apache/spark/sql/catalyst/InternalRow;)I;", + "([Lorg/apache/spark/sql/catalyst/InternalRow;)I", )?, method_num_ret: ReturnType::Primitive(Primitive::Int), method_create: env.get_static_method_id( class, "getEmptyObject", - "(I;)[Lorg/apache/spark/sql/catalyst/InternalRow;", + "(I)[Lorg/apache/spark/sql/catalyst/InternalRow;", )?, method_create_ret: ReturnType::Object, }) diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index c430e83c3..68521e352 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ 
b/native-engine/datafusion-ext-plans/src/agg/agg.rs
@@ -15,7 +15,7 @@
 use std::{any::Any, fmt::Debug, sync::Arc};
 
 use arrow::{
-    array::{Array, ArrayRef, AsArray, Int64Array, Int64Builder, RecordBatch},
+    array::{Array, ArrayRef, AsArray, Int32Array, Int32Builder, RecordBatch},
     datatypes::{DataType, Int64Type, Schema, SchemaRef},
 };
 use datafusion::{common::Result, physical_expr::PhysicalExpr};
@@ -78,34 +78,34 @@ impl IdxSelection<'_> {
         }
     }
 
-    pub fn to_int64_array(&self) -> Int64Array {
-        let mut builder = Int64Builder::with_capacity(self.len());
+    pub fn to_int32_array(&self) -> Int32Array {
+        let mut builder = Int32Builder::with_capacity(self.len());
 
         match self {
             IdxSelection::Single(idx) => {
-                builder.append_value(*idx as i64);
+                builder.append_value(*idx as i32);
             }
 
             IdxSelection::Indices(indices) => {
                 for &idx in *indices {
-                    builder.append_value(idx as i64);
+                    builder.append_value(idx as i32);
                 }
             }
             IdxSelection::IndicesU32(indices_u32) => {
                 for &idx in *indices_u32 {
-                    builder.append_value(idx as i64);
+                    builder.append_value(idx as i32);
                 }
             }
             IdxSelection::Range(start, end) => {
                 for idx in *start..=*end {
-                    builder.append_value(idx as i64);
+                    builder.append_value(idx as i32);
                 }
             }
         }
         let primitive_array = builder.finish();
         primitive_array
             .as_any()
-            .downcast_ref::<Int64Array>()
+            .downcast_ref::<Int32Array>()
             .cloned()
             .unwrap()
     }
diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs
index 77ce80805..2489e8fc7 100644
--- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs
@@ -21,7 +21,7 @@ use std::{
 use arrow::{
     array::{
         as_struct_array, make_array, Array, ArrayAccessor, ArrayRef, AsArray, BinaryArray, Datum,
-        Int32Array, StructArray,
+        Int32Array, Int32Builder, StructArray,
     },
     buffer::NullBuffer,
     datatypes::{DataType, Field, Schema, SchemaRef},
@@ -45,6 +45,7 @@ use crate::{
         acc::{AccColumn, AccColumnRef},
         agg::{Agg, IdxSelection},
     },
+    idx_for_zipped,
     memmgr::spill::{SpillCompressedReader, SpillCompressedWriter},
 };
 
@@ -177,6 +178,29 @@ impl Agg for SparkUDAFWrapper {
             &RecordBatchOptions::new().with_row_count(Some(params[0].len())),
         )?;
 
+        let max_len = std::cmp::max(acc_idx.len(), partial_arg_idx.len());
+        let mut acc_idx_builder = Int32Builder::with_capacity(max_len);
+        let mut partial_arg_idx_builder = Int32Builder::with_capacity(max_len);
+        idx_for_zipped! {
{
+            ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                acc_idx_builder.append_value(acc_idx as i32);
+                merging_acc_idx_builder.append_value(merging_acc_idx as i32);
+            }
+        }
+        let acc_idx = acc_idx_builder
+            .finish()
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .cloned()
+            .unwrap();
+
+        let merging_acc_idx = merging_acc_idx_builder
+            .finish()
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .cloned()
+            .unwrap();
+
         accs.obj = partial_merge_udaf(
             self.jcontext()?,
             accs.obj.clone(),
@@ -255,9 +302,9 @@ impl AccColumn for AccUnsafeRowsColumn {
     }

     fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec<u8>]) -> Result<()> {
-        let field = Arc::new(Field::new("", DataType::Int64, false));
-        let idx64 = idx.to_int64_array().into_data();
-        let struct_array = StructArray::from(vec![(field, make_array(idx64))]);
+        let field = Arc::new(Field::new("", DataType::Int32, false));
+        let idx32 = idx.to_int32_array().into_data();
+        let struct_array = StructArray::from(vec![(field, make_array(idx32))]);
         let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data());
         let mut import_ffi_array = FFI_ArrowArray::empty();
         jni_call_static!(
@@ -276,7 +323,7 @@ impl AccColumn for AccUnsafeRowsColumn {

         let result_struct = import_struct_array.as_struct();
         let binary_array = result_struct
-            .column(1)
+            .column(0)
             .as_any()
             .downcast_ref::<BinaryArray>()
             .ok_or_else(|| DataFusionError::Execution("Expected a BinaryArray".to_string()))?;
@@ -295,7 +342,7 @@ impl AccColumn for AccUnsafeRowsColumn {
     fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
         let fields = Fields::from(vec![
             Field::new("", DataType::Binary, false),
-            Field::new("", DataType::Int64, false),
+            Field::new("", DataType::Int32, false),
         ]);
         let binary_values = array.iter().map(|&data| data).collect();
         let offsets_i32 = offsets
             .iter()
             .map(|data| *data as i32)
             .collect::<Vec<i32>>();
@@ -354,16 +401,23 @@ fn partial_update_udaf(
     jcontext: GlobalRef,
     params_batch: RecordBatch,
     accs: GlobalRef,
-    acc_idx: IdxSelection<'_>,
-    partial_arg_idx: IdxSelection<'_>,
+    acc_idx: Int32Array,
+    partial_arg_idx: Int32Array,
 ) -> Result<GlobalRef> {
-    let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int64, false));
-    let partial_arg_idx_field = Arc::new(Field::new("partial_arg_idx", DataType::Int64, false));
-    let acc_idx = acc_idx.to_int64_array().into_data();
-    let partial_arg_idx = partial_arg_idx.to_int64_array().into_data();
+    let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int32, false));
+    let partial_arg_idx_field = Arc::new(Field::new("partial_arg_idx", DataType::Int32, false));
+
+    log::info!(
+        "acc_idx length {} partial_arg_idx len {}",
+        acc_idx.len(),
+        partial_arg_idx.len()
+    );
     let struct_array = StructArray::from(vec![
-        (acc_idx_field.clone(), make_array(acc_idx)),
-        (partial_arg_idx_field.clone(), make_array(partial_arg_idx)),
+        (acc_idx_field.clone(), make_array(acc_idx.into_data())),
+        (
+            partial_arg_idx_field.clone(),
+            make_array(partial_arg_idx.into_data()),
+        ),
     ]);

     let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data());
@@ -383,16 +437,18 @@ fn partial_merge_udaf(
     jcontext: GlobalRef,
     accs: GlobalRef,
     merging_accs: GlobalRef,
-    acc_idx: IdxSelection<'_>,
-    merging_acc_idx: IdxSelection<'_>,
+    acc_idx: Int32Array,
+    merging_acc_idx: Int32Array,
 ) -> Result<GlobalRef> {
-    let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int64, false));
-    let merging_acc_idx_field = Arc::new(Field::new("merging_acc_idx", DataType::Int64, false));
-    let acc_idx = acc_idx.to_int64_array().into_data();
-    let merging_acc_idx =
merging_acc_idx.to_int64_array().into_data(); + let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int32, false)); + let merging_acc_idx_field = Arc::new(Field::new("merging_acc_idx", DataType::Int32, false)); + let struct_array = StructArray::from(vec![ - (acc_idx_field.clone(), make_array(acc_idx)), - (merging_acc_idx_field.clone(), make_array(merging_acc_idx)), + (acc_idx_field.clone(), make_array(acc_idx.into_data())), + ( + merging_acc_idx_field.clone(), + make_array(merging_acc_idx.into_data()), + ), ]); let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); @@ -411,8 +467,8 @@ fn final_merge_udaf( acc_idx: IdxSelection<'_>, result_schema: SchemaRef, ) -> Result { - let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int64, false)); - let acc_idx = acc_idx.to_int64_array().into_data(); + let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int32, false)); + let acc_idx = acc_idx.to_int32_array().into_data(); let struct_array = StructArray::from(vec![(acc_idx_field.clone(), make_array(acc_idx))]); let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); let mut import_ffi_array = FFI_ArrowArray::empty(); diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java index a3c2c76b8..f51423393 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java @@ -17,16 +17,15 @@ import org.apache.spark.sql.catalyst.InternalRow; -// for jni_bridge usage public class UnsafeRowsWrapperUtils { public static void serialize( - InternalRow[] unsafeRows, int numFields, Long importFFIArrayPtr, Long exportFFIArrayPtr) { + InternalRow[] unsafeRows, int numFields, long importFFIArrayPtr, long exportFFIArrayPtr) { UnsafeRowsWrapper$.MODULE$.serialize(unsafeRows, numFields, importFFIArrayPtr, exportFFIArrayPtr); } - public static InternalRow[] deserialize(int numFields, Long importFFIArrayPtr, Long exportFFIArrayPtr) { + public static InternalRow[] deserialize(int numFields, long importFFIArrayPtr, long exportFFIArrayPtr) { return UnsafeRowsWrapper$.MODULE$.deserialize(numFields, importFFIArrayPtr, exportFFIArrayPtr); } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 6d173b77a..af632afc3 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.blaze import scala.collection.JavaConverters._ import org.apache.arrow.c.{ArrowArray, Data} -import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.{IntVector, VectorSchemaRoot} import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider import org.apache.spark.TaskContext @@ -25,11 +25,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.blaze.util.Using import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, UnsafeProjection} +import 
org.apache.spark.sql.catalyst.expressions.Nondeterministic import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types.{LongType, StructField, StructType} -import com.google.flatbuffers.LongVector import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, MutableProjection} import org.apache.spark.sql.catalyst.expressions.AttributeReference @@ -82,20 +81,6 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { ArrowUtils.toArrowSchema(schema) } - { - val toUnsafe = UnsafeProjection.create(javaParamsSchema) - toUnsafe.initialize(Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)) - toUnsafe - } - - def update(values: InternalRow*): InternalRow = { - val joiner = new JoinedRow - val buffer = values.foldLeft(initialize()) { (buffer, input) => - updater(joiner(buffer, input)) - } - buffer.copy() - } - def update( rows: Array[InternalRow], importIdxFFIArrayPtr: Long, @@ -113,21 +98,17 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors.head.asInstanceOf[LongVector] - val inputIdxVector = fieldVectors(1).asInstanceOf[LongVector] + val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] + val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] - assert( - rowIdxVector.length() == inputIdxVector.length(), - s"Error: SparkUDAFWrapperContext update error Vectors have different lengths.") - - for (i <- 0 until rowIdxVector.length()) { - val row = rows(rowIdxVector.get(i).toInt) - val input = inputRows(inputIdxVector.get(i).toInt) + for (i <- 0 until idxRoot.getRowCount) { + val row = rows(rowIdxVector.get(i)) + val input = inputRows(inputIdxVector.get(i)) val joiner = new JoinedRow if (row.numFields == 0) { - rows(rowIdxVector.get(i).toInt) = updater(joiner(initialize(), input)) + rows(rowIdxVector.get(i)) = updater(joiner(initialize(), input)) } else { - rows(rowIdxVector.get(i).toInt) = updater(joiner(row, input)) + rows(rowIdxVector.get(i)) = updater(joiner(row, input)) } } @@ -146,21 +127,17 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { ArrowArray.wrap(importIdxFFIArrayPtr)) { (idxRoot, idxArray) => Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors.head.asInstanceOf[LongVector] - val mergeIdxVector = fieldVectors(1).asInstanceOf[LongVector] - - assert( - rowIdxVector.length() == mergeIdxVector.length(), - s"Error: SparkUDAFWrapperContext update error Vectors have different lengths.") + val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] + val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] - for (i <- 0 until rowIdxVector.length()) { - val row = rows(rowIdxVector.get(i).toInt) - val mergeRow = mergeRows(mergeIdxVector.get(i).toInt) + for (i <- 0 until idxRoot.getRowCount) { + val row = rows(rowIdxVector.get(i)) + val mergeRow = mergeRows(mergeIdxVector.get(i)) val joiner = new JoinedRow if (row.numFields == 0) { - rows(rowIdxVector.get(i).toInt) = merger(joiner(initialize(), mergeRow)) + rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)) } else { - rows(rowIdxVector.get(i).toInt) = merger(joiner(row, mergeRow)) + 
rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)) } } @@ -181,12 +158,12 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors.head.asInstanceOf[LongVector] + val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] // evaluate expression and write to output root val outputWriter = ArrowWriter.create(outputRoot) - for (i <- 0 until rowIdxVector.length()) { - val row = rows(rowIdxVector.get(i).toInt) + for (i <- 0 until idxRoot.getRowCount) { + val row = rows(rowIdxVector.get(i)) outputWriter.write(evaluator(row)) } outputWriter.finish() diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index a535b6493..dfd49a80b 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -16,7 +16,7 @@ package org.apache.spark.sql.blaze import org.apache.arrow.c.{ArrowArray, Data} -import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.{IntVector, VectorSchemaRoot} import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -28,7 +28,6 @@ import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, LongType, StructField, StructType} import org.apache.arrow.flatbuf -import com.google.flatbuffers.{IntVector, LongVector} import org.apache.spark.sql.Row import scala.collection.JavaConverters._ @@ -91,12 +90,11 @@ object UnsafeRowsWrapper extends Logging { importArray, paramsRoot, dictionaryProvider) - val idxArray = paramsRoot.getFieldVectors.asScala.head.asInstanceOf[LongVector]; - + val idxArray = paramsRoot.getFieldVectors.asScala.head.asInstanceOf[IntVector] val serializer = new UnsafeRowSerializer(numFields).newInstance() val outputWriter = ArrowWriter.create(outputRoot) - for (idx <- 0 until idxArray.length()) { - val internalRow = unsafeRows(idx) + for (idx <- 0 until paramsRoot.getRowCount) { + val internalRow = unsafeRows(idxArray.get(idx)) Utils.tryWithResource(new ByteArrayOutputStream()) { baos => val serializerStream = serializer.serializeStream(baos) serializerStream.writeValue(internalRow) @@ -138,14 +136,10 @@ object UnsafeRowsWrapper extends Logging { val binaryVector = fieldVectors.head.asInstanceOf[flatbuf.Binary.Vector]; val intVector = fieldVectors(1).asInstanceOf[IntVector] - assert( - binaryVector.length() == intVector.length(), - s"Error: UnsafeRowsWrapper deserialize error Vectors have different lengths.") - val deserializer = new UnsafeRowSerializer(numFields).newInstance() - val internalRowsArray = new Array[InternalRow](binaryVector.length()) + val internalRowsArray = new Array[InternalRow](paramsRoot.getRowCount) val outputWriter = ArrowWriter.create(outputRoot) - for (i <- 0 until binaryVector.length()) { + for (i <- 0 until paramsRoot.getRowCount) { val binaryRow = binaryVector.get(i) val offset = 
intVector.get(i) val bytes = binaryRow.getByteBuffer.array() @@ -182,7 +176,9 @@ object UnsafeRowsWrapper extends Logging { } def getNullObject(rowNum: Int): Array[InternalRow] = { - Array.fill(rowNum)(InternalRow.empty) + Array.fill(rowNum) { + new UnsafeRow(0) + } } } From 9d069e524dfa58c9996d98b78375c23784cc7a52 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Thu, 6 Feb 2025 16:26:39 +0800 Subject: [PATCH 03/17] update DeclarativeAggregate udaf --- .../src/agg/spark_hdaf_wrapper.rs | 44 ++++++++++------- .../sql/blaze/UnsafeRowsWrapperUtils.java | 1 - .../sql/blaze/SparkUDAFWrapperContext.scala | 27 ++++++----- .../spark/sql/blaze/UnsafeRowsWrapper.scala | 47 ++++++++++++------- 4 files changed, 70 insertions(+), 49 deletions(-) diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs index 2489e8fc7..22c431e3e 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs @@ -23,12 +23,10 @@ use arrow::{ as_struct_array, make_array, Array, ArrayAccessor, ArrayRef, AsArray, BinaryArray, Datum, Int32Array, Int32Builder, StructArray, }, - buffer::NullBuffer, datatypes::{DataType, Field, Schema, SchemaRef}, ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use arrow_schema::Fields; use blaze_jni_bridge::{ jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, }; @@ -156,6 +154,7 @@ impl Agg for SparkUDAFWrapper { partial_arg_idx: IdxSelection<'_>, batch_schema: SchemaRef, ) -> Result<()> { + log::info!("start partial update"); let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); let params = partial_args.to_vec(); @@ -220,6 +219,7 @@ impl Agg for SparkUDAFWrapper { merging_accs: &mut AccColumnRef, merging_acc_idx: IdxSelection<'_>, ) -> Result<()> { + log::info!("start partial merge"); let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); let merging_accs = downcast_any!(merging_accs, mut AccUnsafeRowsColumn).unwrap(); @@ -259,6 +259,7 @@ impl Agg for SparkUDAFWrapper { } fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result { + log::info!("start final merge"); let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); final_merge_udaf( self.jcontext()?, @@ -332,6 +333,7 @@ impl AccColumn for AccUnsafeRowsColumn { if binary_array.is_valid(i) { let bytes = binary_array.value(i).to_vec(); array[i] = bytes; + log::info!("freeze arrary {} : {:?}", i, array[i]); } else { log::warn!("AccUnsafeRowsColumn::freeze_to_rows : binary_array null error") } @@ -340,20 +342,33 @@ impl AccColumn for AccUnsafeRowsColumn { } fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> { - let fields = Fields::from(vec![ - Field::new("", DataType::Binary, false), - Field::new("", DataType::Int32, false), - ]); + + log::info!("unfreeze array {:?}", array.clone()); + log::info!("unfreeze offsets {:?}", offsets); let binary_values = array.iter().map(|&data| data).collect(); let offsets_i32 = offsets .iter() .map(|data| *data as i32) .collect::>(); - let offsets_array = Int32Array::from(offsets_i32); - let binary_array = BinaryArray::from_vec(binary_values); - let nulls = Some(NullBuffer::from_iter([false, false])); - let values = vec![Arc::new(binary_array) as _, Arc::new(offsets_array) as _]; - let struct_array = StructArray::new(fields, values, nulls); + let 
offsets_array = Int32Array::from(offsets_i32)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .cloned()
+            .unwrap();
+        let binary_array = BinaryArray::from_vec(binary_values)
+            .as_any()
+            .downcast_ref::<BinaryArray>()
+            .cloned()
+            .unwrap();
+
+        let binary_field = Arc::new(Field::new("", DataType::Binary, false));
+        let offsets_field = Arc::new(Field::new("", DataType::Int32, false));
+
+        let struct_array = StructArray::from(vec![
+            (binary_field.clone(), make_array(binary_array.into_data())),
+            (offsets_field.clone(), make_array(offsets_array.into_data())),
+        ]);
+
         let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data());
         let mut import_ffi_array = FFI_ArrowArray::empty();
         let rows = jni_call_static!(
@@ -374,7 +389,7 @@ impl AccColumn for AccUnsafeRowsColumn {

         let result_struct = import_struct_array.as_struct();
         let int32array = result_struct
-            .column(1)
+            .column(0)
             .as_any()
             .downcast_ref::<Int32Array>()
             .ok_or_else(|| DataFusionError::Execution("Expected a Int32Array".to_string()))?;
@@ -407,11 +422,6 @@ fn partial_update_udaf(
     let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int32, false));
     let partial_arg_idx_field = Arc::new(Field::new("partial_arg_idx", DataType::Int32, false));

-    log::info!(
-        "acc_idx length {} partial_arg_idx len {}",
-        acc_idx.len(),
-        partial_arg_idx.len()
-    );
     let struct_array = StructArray::from(vec![
         (acc_idx_field.clone(), make_array(acc_idx.into_data())),
         (
diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java
index f51423393..8fa075f6a 100644
--- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java
+++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java
@@ -17,7 +17,6 @@

 import org.apache.spark.sql.catalyst.InternalRow;

-
 public class UnsafeRowsWrapperUtils {

   public static void serialize(
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
index af632afc3..80629ed32 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
@@ -25,12 +25,10 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.blaze.util.Using
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
-import org.apache.spark.sql.catalyst.expressions.Nondeterministic
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, JoinedRow, Nondeterministic, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper
 import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter}
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, MutableProjection}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -53,17 +51,17 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging {
     case _ =>
   }

-  private lazy val initializer = MutableProjection.create(expr.initialValues)
+  private lazy val initializer = UnsafeProjection.create(expr.initialValues)
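  // A minimal sketch of how these four generated projections drive a
  // DeclarativeAggregate, and why UnsafeProjection is used here instead of
  // MutableProjection: a MutableProjection reuses a single target row across
  // calls, so a result stored into a long-lived Array[InternalRow] would be
  // silently overwritten by the next call, while UnsafeProjection output
  // survives once copied. Assuming `expr: DeclarativeAggregate` and
  // `inputAttributes` as defined in this class (illustrative only):
  //
  //   val init   = UnsafeProjection.create(expr.initialValues)
  //   val update = UnsafeProjection.create(
  //     expr.updateExpressions, expr.aggBufferAttributes ++ inputAttributes)
  //   var buffer = init(InternalRow.empty).copy()
  //   def step(input: InternalRow): Unit = {
  //     // JoinedRow presents (buffer ++ input) without copying either side.
  //     buffer = update(new JoinedRow(buffer, input)).copy()
  //   }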
private lazy val updater = - MutableProjection.create(expr.updateExpressions, expr.aggBufferAttributes ++ inputAttributes) + UnsafeProjection.create(expr.updateExpressions, expr.aggBufferAttributes ++ inputAttributes) - private lazy val merger = MutableProjection.create( + private lazy val merger = UnsafeProjection.create( expr.mergeExpressions, expr.aggBufferAttributes ++ expr.inputAggBufferAttributes) private lazy val evaluator = - MutableProjection.create(expr.evaluateExpression :: Nil, expr.aggBufferAttributes) + UnsafeProjection.create(expr.evaluateExpression :: Nil, expr.aggBufferAttributes) private def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() @@ -77,7 +75,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } private val indexSchema = { - val schema = StructType(Seq(StructField("", LongType), StructField("", LongType))) + val schema = StructType(Seq(StructField("", IntegerType), StructField("", IntegerType))) ArrowUtils.toArrowSchema(schema) } @@ -85,12 +83,13 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { rows: Array[InternalRow], importIdxFFIArrayPtr: Long, importBatchFFIArrayPtr: Long): Array[InternalRow] = { + logInfo("start partial update in scalar!") Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(inputSchema, batchAllocator), VectorSchemaRoot.create(indexSchema, batchAllocator), - ArrowArray.wrap(importIdxFFIArrayPtr), - ArrowArray.wrap(importBatchFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => + ArrowArray.wrap(importBatchFFIArrayPtr), + ArrowArray.wrap(importIdxFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => // import into params root Data.importIntoVectorSchemaRoot(batchAllocator, inputArray, inputRoot, dictionaryProvider) val batch = ColumnarHelper.rootAsBatch(inputRoot) @@ -111,7 +110,8 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { rows(rowIdxVector.get(i)) = updater(joiner(row, input)) } } - + logInfo(s"update rows num: ${rows.length}, rows.fieldnum:${rows(0).numFields}") + logInfo(s"row 0: ${rows(0).toString}") rows } } @@ -121,6 +121,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { rows: Array[InternalRow], mergeRows: Array[InternalRow], importIdxFFIArrayPtr: Long): Array[InternalRow] = { + logInfo("start merge in scalar!!") Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(indexSchema, batchAllocator), @@ -140,7 +141,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)) } } - + logInfo("finish merge in scalar!!") rows } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index dfd49a80b..33f1dce4b 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -16,7 +16,7 @@ package org.apache.spark.sql.blaze import org.apache.arrow.c.{ArrowArray, Data} -import org.apache.arrow.vector.{IntVector, VectorSchemaRoot} +import org.apache.arrow.vector.{VarBinaryVector, IntVector, VectorSchemaRoot} import org.apache.arrow.vector.dictionary.DictionaryProvider import 
org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -26,8 +26,7 @@ import org.apache.spark.sql.blaze.util.Using import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, LongType, StructField, StructType} -import org.apache.arrow.flatbuf +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StructField, StructType} import org.apache.spark.sql.Row import scala.collection.JavaConverters._ @@ -38,7 +37,7 @@ object UnsafeRowsWrapper extends Logging { private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() private val idxSchema = { - val schema = StructType(Seq(StructField("", LongType, nullable = false))) + val schema = StructType(Seq(StructField("", IntegerType, nullable = false))) ArrowUtils.toArrowSchema(schema) } @@ -93,12 +92,17 @@ object UnsafeRowsWrapper extends Logging { val idxArray = paramsRoot.getFieldVectors.asScala.head.asInstanceOf[IntVector] val serializer = new UnsafeRowSerializer(numFields).newInstance() val outputWriter = ArrowWriter.create(outputRoot) + logInfo(s"freeze unsaferows num: ${unsafeRows.length}") + logInfo(s"freeze idxArray $idxArray") for (idx <- 0 until paramsRoot.getRowCount) { val internalRow = unsafeRows(idxArray.get(idx)) + logInfo(s"freeze unsafe row : ${internalRow.toString}") Utils.tryWithResource(new ByteArrayOutputStream()) { baos => val serializerStream = serializer.serializeStream(baos) serializerStream.writeValue(internalRow) + serializerStream.close() val bytes = baos.toByteArray + logInfo(s"write bytes : ${java.util.Arrays.toString(bytes)}") outputWriter.write(toUnsafeRow(Row(bytes), Array(BinaryType))) } } @@ -133,28 +137,35 @@ object UnsafeRowsWrapper extends Logging { paramsRoot, dictionaryProvider) val fieldVectors = paramsRoot.getFieldVectors.asScala - val binaryVector = fieldVectors.head.asInstanceOf[flatbuf.Binary.Vector]; + val binaryVector = fieldVectors.head.asInstanceOf[VarBinaryVector]; val intVector = fieldVectors(1).asInstanceOf[IntVector] val deserializer = new UnsafeRowSerializer(numFields).newInstance() val internalRowsArray = new Array[InternalRow](paramsRoot.getRowCount) val outputWriter = ArrowWriter.create(outputRoot) for (i <- 0 until paramsRoot.getRowCount) { - val binaryRow = binaryVector.get(i) + val bytes = binaryVector.get(i) val offset = intVector.get(i) - val bytes = binaryRow.getByteBuffer.array() - val internalRow: InternalRow = Utils.tryWithResource( - new ByteArrayInputStream(bytes, offset, bytes.length - offset)) { bais => - val unsafeRow = - deserializer.deserializeStream(bais).readValue().asInstanceOf[UnsafeRow] - // get offset use reflect - val field: Field = classOf[ByteArrayInputStream].getDeclaredField("pos") - field.setAccessible(true) - val position = field.getInt(bais) - outputWriter.write(toUnsafeRow(Row(position), Array(IntegerType))) - unsafeRow + logInfo(s"Reading bytes from offset: $offset, bytes length: ${bytes.length}") + if (bytes.length - offset > 0) { + val internalRow: InternalRow = Utils.tryWithResource( + new ByteArrayInputStream(bytes, offset, bytes.length - offset)) { bais => + val unsafeRow = + deserializer.deserializeStream(bais).readValue().asInstanceOf[UnsafeRow] + // get offset use reflect + val field: Field = 
classOf[ByteArrayInputStream].getDeclaredField("pos") + field.setAccessible(true) + val position = field.getInt(bais) + outputWriter.write(toUnsafeRow(Row(position), Array(IntegerType))) + logInfo(s"unsafe row numfield ${unsafeRow.numFields()}") + unsafeRow + } + internalRowsArray(i) = internalRow + } else { + internalRowsArray(i) = new UnsafeRow(0) + outputWriter.write(toUnsafeRow(Row(offset), Array(IntegerType))) } - internalRowsArray(i) = internalRow + } outputWriter.finish() From 36ce963145d3f89d0683dd28bcd14765437a9965 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Fri, 7 Feb 2025 15:53:28 +0800 Subject: [PATCH 04/17] update DeclarativeAggregate udaf --- .../src/agg/spark_hdaf_wrapper.rs | 8 +++- .../sql/blaze/SparkUDAFWrapperContext.scala | 48 ++++++++++++------- .../spark/sql/blaze/UnsafeRowsWrapper.scala | 22 +++++---- 3 files changed, 52 insertions(+), 26 deletions(-) diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs index 22c431e3e..45ecb038f 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs @@ -27,6 +27,7 @@ use arrow::{ ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, record_batch::{RecordBatch, RecordBatchOptions}, }; +use arrow::array::Float64Array; use blaze_jni_bridge::{ jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, }; @@ -431,7 +432,12 @@ fn partial_update_udaf( ]); let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); - let struct_array = StructArray::from(params_batch); + let struct_array = StructArray::from(params_batch.clone()); + log::info!("input struct_array {:?}", struct_array.fields()); + log::info!("batch{:?}", params_batch.column(0).into_data()); + let column0= params_batch.column(0).as_any().downcast_ref::().unwrap().values().to_vec(); + log::info!("column0 {:?}", column0); + let mut export_ffi_batch_array = FFI_ArrowArray::new(&struct_array.to_data()); let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).update( diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 80629ed32..675b201c4 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, JoinedRow, Nondeterministic, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} import java.nio.ByteBuffer @@ -69,16 +69,21 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { private val inputSchema = ArrowUtils.toArrowSchema(javaParamsSchema) - { - val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable))) - ArrowUtils.toArrowSchema(schema) - } private val indexSchema = { val schema = StructType(Seq(StructField("", IntegerType), StructField("", IntegerType))) 
ArrowUtils.toArrowSchema(schema) } + private val evalIndexSchema = { + val schema = StructType(Seq(StructField("", IntegerType))) + ArrowUtils.toArrowSchema(schema) + } + + val dataTypes: Seq[DataType] = expr.aggBufferAttributes.map(_.dataType) + + val inputTypes: Seq[DataType] = javaParamsSchema.map(_.dataType) + def update( rows: Array[InternalRow], importIdxFFIArrayPtr: Long, @@ -102,12 +107,21 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { for (i <- 0 until idxRoot.getRowCount) { val row = rows(rowIdxVector.get(i)) + logInfo(s"javaParamsSchema $javaParamsSchema") + logInfo(s"inputIdxVector: ${inputIdxVector}") + logInfo(s"inputRows.length ${inputRows.length}") val input = inputRows(inputIdxVector.get(i)) - val joiner = new JoinedRow - if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = updater(joiner(initialize(), input)) - } else { - rows(rowIdxVector.get(i)) = updater(joiner(row, input)) + logInfo(s"is unsafe row: ${input.isInstanceOf[UnsafeRow]}") + logInfo(s"input numField ${input.numFields} $inputSchema") + if (input.numFields > 0) { + logInfo(s"input row: ${input.toSeq(inputTypes)}") + val joiner = new JoinedRow + if (row.numFields == 0) { + rows(rowIdxVector.get(i)) = updater(joiner(initialize(), input)) + } else { + rows(rowIdxVector.get(i)) = updater(joiner(row, input)) + } + logInfo(s"temp row 0: ${rows(0).toSeq(dataTypes)}") } } logInfo(s"update rows num: ${rows.length}, rows.fieldnum:${rows(0).numFields}") @@ -134,11 +148,13 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { for (i <- 0 until idxRoot.getRowCount) { val row = rows(rowIdxVector.get(i)) val mergeRow = mergeRows(mergeIdxVector.get(i)) - val joiner = new JoinedRow - if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)) - } else { - rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)) + if (mergeRow.numFields > 0) { + val joiner = new JoinedRow + if (row.numFields == 0) { + rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)) + } else { + rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)) + } } } logInfo("finish merge in scalar!!") @@ -153,7 +169,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { exportFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( - VectorSchemaRoot.create(indexSchema, batchAllocator), + VectorSchemaRoot.create(evalIndexSchema, batchAllocator), VectorSchemaRoot.create(inputSchema, batchAllocator), ArrowArray.wrap(importIdxFFIArrayPtr), ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index 33f1dce4b..102300ea8 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -96,14 +96,18 @@ object UnsafeRowsWrapper extends Logging { logInfo(s"freeze idxArray $idxArray") for (idx <- 0 until paramsRoot.getRowCount) { val internalRow = unsafeRows(idxArray.get(idx)) - logInfo(s"freeze unsafe row : ${internalRow.toString}") - Utils.tryWithResource(new ByteArrayOutputStream()) { baos => - val serializerStream = serializer.serializeStream(baos) - serializerStream.writeValue(internalRow) - serializerStream.close() - val bytes = 
baos.toByteArray - logInfo(s"write bytes : ${java.util.Arrays.toString(bytes)}") - outputWriter.write(toUnsafeRow(Row(bytes), Array(BinaryType))) + if (internalRow.numFields == 0) { + outputWriter.write(toUnsafeRow(Row(Array.empty[Byte]), Array(BinaryType))) + } else { + logInfo(s"freeze unsafe row : ${internalRow.toString}") + Utils.tryWithResource(new ByteArrayOutputStream()) { baos => + val serializerStream = serializer.serializeStream(baos) + serializerStream.writeValue(internalRow) + serializerStream.close() + val bytes = baos.toByteArray + logInfo(s"write bytes : ${java.util.Arrays.toString(bytes)}") + outputWriter.write(toUnsafeRow(Row(bytes), Array(BinaryType))) + } } } @@ -157,7 +161,7 @@ object UnsafeRowsWrapper extends Logging { field.setAccessible(true) val position = field.getInt(bais) outputWriter.write(toUnsafeRow(Row(position), Array(IntegerType))) - logInfo(s"unsafe row numfield ${unsafeRow.numFields()}") + logInfo(s"unfreeze row ${unsafeRow.toString}") unsafeRow } internalRowsArray(i) = internalRow From 045c86c6d8388e35c4a991b141bb18d04725529e Mon Sep 17 00:00:00 2001 From: guoying06 Date: Tue, 11 Feb 2025 20:30:27 +0800 Subject: [PATCH 05/17] fix BindReference and deserialize --- .../datafusion-ext-plans/src/agg/agg.rs | 2 +- .../src/agg/spark_hdaf_wrapper.rs | 23 ++--- .../sql/blaze/UnsafeRowsWrapperUtils.java | 2 +- .../spark/sql/blaze/NativeConverters.scala | 23 ++--- .../sql/blaze/SparkUDAFWrapperContext.scala | 86 +++++++++++++------ .../spark/sql/blaze/UnsafeRowsWrapper.scala | 27 +++--- 6 files changed, 91 insertions(+), 72 deletions(-) diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index 68521e352..4d2166acf 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -97,7 +97,7 @@ impl IdxSelection<'_> { } } IdxSelection::Range(start, end) => { - for idx in *start..=*end { + for idx in *start..*end { builder.append_value(idx as i32); } } diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs index 45ecb038f..65efad23c 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs @@ -27,7 +27,6 @@ use arrow::{ ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use arrow::array::Float64Array; use blaze_jni_bridge::{ jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, }; @@ -157,6 +156,7 @@ impl Agg for SparkUDAFWrapper { ) -> Result<()> { log::info!("start partial update"); let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); + log::info!("update before accs.num {}", accs.num_records()); let params = partial_args.to_vec(); let params_schema = self @@ -209,7 +209,7 @@ impl Agg for SparkUDAFWrapper { partial_arg_idx, ) .unwrap(); - + log::info!("update after accs.num {}", accs.num_records()); Ok(()) } @@ -282,12 +282,13 @@ impl AccColumn for AccUnsafeRowsColumn { } fn resize(&mut self, len: usize) { - unimplemented!() + let rows = jni_call_static!( + BlazeUnsafeRowsWrapperUtils.create(len as i32)-> JObject) + .unwrap(); + self.obj = jni_new_global_ref!(rows.as_obj()).unwrap(); } - fn shrink_to_fit(&mut self) { - unimplemented!() - } + fn shrink_to_fit(&mut self) {} fn num_records(&self) -> usize { match jni_call_static!( @@ -333,8 +334,7 @@ impl AccColumn for 
AccUnsafeRowsColumn { for i in 0..binary_array.len() { if binary_array.is_valid(i) { let bytes = binary_array.value(i).to_vec(); - array[i] = bytes; - log::info!("freeze arrary {} : {:?}", i, array[i]); + array[i].extend_from_slice(&bytes); } else { log::warn!("AccUnsafeRowsColumn::freeze_to_rows : binary_array null error") } @@ -343,8 +343,7 @@ impl AccColumn for AccUnsafeRowsColumn { } fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> { - - log::info!("unfreeze array {:?}", array.clone()); + log::info!("unfreeze array {:?}", array); log::info!("unfreeze offsets {:?}", offsets); let binary_values = array.iter().map(|&data| data).collect(); let offsets_i32 = offsets @@ -433,10 +432,6 @@ fn partial_update_udaf( let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); let struct_array = StructArray::from(params_batch.clone()); - log::info!("input struct_array {:?}", struct_array.fields()); - log::info!("batch{:?}", params_batch.column(0).into_data()); - let column0= params_batch.column(0).as_any().downcast_ref::().unwrap().values().to_vec(); - log::info!("column0 {:?}", column0); let mut export_ffi_batch_array = FFI_ArrowArray::new(&struct_array.to_data()); diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java index 8fa075f6a..ac4ef3957 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java @@ -33,6 +33,6 @@ public static int getRowNum(InternalRow[] unsafeRows) { } public static InternalRow[] getEmptyObject(int rowNum) { - return UnsafeRowsWrapper$.MODULE$.getNullObject(rowNum); + return UnsafeRowsWrapper$.MODULE$.getEmptyObject(rowNum); } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index c4011ac27..3dd03f35f 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -1173,20 +1173,15 @@ object NativeConverters extends Logging { } aggBuilder.setAggFunction(pb.AggFunction.DECLARATIVE) val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]() - val bound = declarative.mapChildren(_.transformDown { - case p: Literal => p - case p => - try { - val convertedChild = - convertExprWithFallback(p, isPruningExpr = false, fallbackToError) - val nextBindIndex = convertedChildren.size - convertedChildren.getOrElseUpdate( - convertedChild, - BoundReference(nextBindIndex, p.dataType, p.nullable)) - } catch { - case _: Exception | _: NotImplementedError => p - } - }) + + val bound = declarative.mapChildren { p => + val convertedChild = convertExpr(p) + val nextBindIndex = convertedChildren.size + declarative.inputAggBufferAttributes.length + convertedChildren.getOrElseUpdate( + convertedChild, + BoundReference(nextBindIndex, p.dataType, p.nullable)) + } + val paramsSchema = StructType( convertedChildren.values .map(ref => StructField("", ref.dataType, ref.nullable)) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 675b201c4..9e7cd9648 100644 --- 
a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -68,6 +68,11 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() private val inputSchema = ArrowUtils.toArrowSchema(javaParamsSchema) + private val paramsToUnsafe = { + val toUnsafe = UnsafeProjection.create(javaParamsSchema) + toUnsafe.initialize(Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)) + toUnsafe + } private val indexSchema = { @@ -81,6 +86,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } val dataTypes: Seq[DataType] = expr.aggBufferAttributes.map(_.dataType) + val dataName: Seq[String] = expr.aggBufferAttributes.map(_.name) val inputTypes: Seq[DataType] = javaParamsSchema.map(_.dataType) @@ -97,35 +103,45 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { ArrowArray.wrap(importIdxFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => // import into params root Data.importIntoVectorSchemaRoot(batchAllocator, inputArray, inputRoot, dictionaryProvider) - val batch = ColumnarHelper.rootAsBatch(inputRoot) - val inputRows = ColumnarHelper.batchAsRowIter(batch).toArray + val inputRows = ColumnarHelper.rootAsBatch(inputRoot) Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] + logInfo(s"inputRows.num: ${inputRows.numRows()}") + logInfo(s"rows.num: ${rows.length}") + logInfo(s"Idx length ${idxRoot.getRowCount}") + logInfo(s"inputIdxVector $inputIdxVector") + logInfo(s"rowIdxVector $rowIdxVector") for (i <- 0 until idxRoot.getRowCount) { - val row = rows(rowIdxVector.get(i)) - logInfo(s"javaParamsSchema $javaParamsSchema") - logInfo(s"inputIdxVector: ${inputIdxVector}") - logInfo(s"inputRows.length ${inputRows.length}") - val input = inputRows(inputIdxVector.get(i)) - logInfo(s"is unsafe row: ${input.isInstanceOf[UnsafeRow]}") - logInfo(s"input numField ${input.numFields} $inputSchema") - if (input.numFields > 0) { - logInfo(s"input row: ${input.toSeq(inputTypes)}") - val joiner = new JoinedRow - if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = updater(joiner(initialize(), input)) - } else { - rows(rowIdxVector.get(i)) = updater(joiner(row, input)) + if ( inputIdxVector.get(i) < inputRows.numRows() ) { + if (rowIdxVector.get(i) < rows.length) { + val row = rows(rowIdxVector.get(i)) + val input = inputRows.getRow(inputIdxVector.get(i)) + val joiner = new JoinedRow + if (row.numFields == 0) { + rows(rowIdxVector.get(i)) = updater(joiner(initialize(), paramsToUnsafe(input))).copy() + } else { + // logInfo(s"row: ${row.toSeq(dataTypes)}") + // logInfo(s"input: ${input.toSeq(inputTypes)}") + // logInfo(s"is row unsafe ${row.isInstanceOf[UnsafeRow]}") + rows(rowIdxVector.get(i)) = updater(joiner(row, paramsToUnsafe(input))).copy() + } + // logInfo(s"temp row 0: ${rows(0).toSeq(dataTypes)}") + } + else { + logInfo(s"wow $i rowIdx:${rowIdxVector.get(i)}") + } } - logInfo(s"temp row 0: ${rows(0).toSeq(dataTypes)}") + + else { + logInfo(s"wow update i $i inputIdxVector:${inputIdxVector.get(i)}") } } logInfo(s"update rows num: ${rows.length}, rows.fieldnum:${rows(0).numFields}") - logInfo(s"row 0: 
${rows(0).toString}") +// logInfo(s"row 0: ${rows(0).toString}") rows } } @@ -145,16 +161,32 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] + logInfo(s"rows.num: ${rows.length}, mergeRows.num ${mergeRows.length} , idx.len: ${idxRoot.getRowCount}") + logInfo(s"mergeIdxVector $mergeIdxVector") + logInfo(s"rowIdxVector $rowIdxVector") for (i <- 0 until idxRoot.getRowCount) { - val row = rows(rowIdxVector.get(i)) - val mergeRow = mergeRows(mergeIdxVector.get(i)) - if (mergeRow.numFields > 0) { - val joiner = new JoinedRow - if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)) - } else { - rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)) + if (mergeIdxVector.get(i) < mergeRows.length) { + if (rowIdxVector.get(i) < rows.length) { + logInfo(s"i: $i, mergeIdxVector.get(i) ${mergeIdxVector.get(i)}, rowIdxVector.get(i) ${rowIdxVector.get(i)}") + val row = rows(rowIdxVector.get(i)) + val mergeRow = mergeRows(mergeIdxVector.get(i)) + val joiner = new JoinedRow + if (row.numFields == 0) { + rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)).copy() + logInfo { + s"init merge row ${rows(rowIdxVector.get(i)).toSeq(dataTypes)}" + } + } else { + rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)).copy() + logInfo { + s"merge row ${rows(rowIdxVector.get(i)).toSeq(dataTypes)}" + } + } + } + else { + logInfo(s"wow merge i $i rowIdxVector:${rowIdxVector.get(i)}") } + } } logInfo("finish merge in scalar!!") @@ -180,7 +212,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { // evaluate expression and write to output root val outputWriter = ArrowWriter.create(outputRoot) for (i <- 0 until idxRoot.getRowCount) { - val row = rows(rowIdxVector.get(i)) + val row = rows(rowIdxVector.get(i)).copy() outputWriter.write(evaluator(row)) } outputWriter.finish() diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index 102300ea8..46e74364c 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.Row import scala.collection.JavaConverters._ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import java.lang.reflect.Field object UnsafeRowsWrapper extends Logging { @@ -92,20 +91,20 @@ object UnsafeRowsWrapper extends Logging { val idxArray = paramsRoot.getFieldVectors.asScala.head.asInstanceOf[IntVector] val serializer = new UnsafeRowSerializer(numFields).newInstance() val outputWriter = ArrowWriter.create(outputRoot) - logInfo(s"freeze unsaferows num: ${unsafeRows.length}") - logInfo(s"freeze idxArray $idxArray") +// logInfo(s"freeze unsaferows num: ${unsafeRows.length}") +// logInfo(s"freeze idxArray $idxArray") for (idx <- 0 until paramsRoot.getRowCount) { val internalRow = unsafeRows(idxArray.get(idx)) if (internalRow.numFields == 0) { outputWriter.write(toUnsafeRow(Row(Array.empty[Byte]), Array(BinaryType))) } else { - logInfo(s"freeze unsafe row : ${internalRow.toString}") +// logInfo(s"freeze unsafe row : ${internalRow.toString}") Utils.tryWithResource(new ByteArrayOutputStream()) { baos => val serializerStream = serializer.serializeStream(baos) 
serializerStream.writeValue(internalRow) serializerStream.close() val bytes = baos.toByteArray - logInfo(s"write bytes : ${java.util.Arrays.toString(bytes)}") +// logInfo(s"write bytes : ${java.util.Arrays.toString(bytes)}") outputWriter.write(toUnsafeRow(Row(bytes), Array(BinaryType))) } } @@ -150,18 +149,16 @@ object UnsafeRowsWrapper extends Logging { for (i <- 0 until paramsRoot.getRowCount) { val bytes = binaryVector.get(i) val offset = intVector.get(i) - logInfo(s"Reading bytes from offset: $offset, bytes length: ${bytes.length}") +// logInfo(s"Reading bytes from offset: $offset, bytes length: ${bytes.length}") if (bytes.length - offset > 0) { val internalRow: InternalRow = Utils.tryWithResource( new ByteArrayInputStream(bytes, offset, bytes.length - offset)) { bais => - val unsafeRow = - deserializer.deserializeStream(bais).readValue().asInstanceOf[UnsafeRow] - // get offset use reflect - val field: Field = classOf[ByteArrayInputStream].getDeclaredField("pos") - field.setAccessible(true) - val position = field.getInt(bais) - outputWriter.write(toUnsafeRow(Row(position), Array(IntegerType))) - logInfo(s"unfreeze row ${unsafeRow.toString}") + val deserializeStream = deserializer.deserializeStream(bais) + val unsafeRow = deserializeStream.readValue().asInstanceOf[UnsafeRow] + deserializeStream.close() + val size = unsafeRow.getSizeInBytes + 4 + outputWriter.write(toUnsafeRow(Row(offset + size), Array(IntegerType))) +// logInfo(s"unfreeze row ${unsafeRow.toString}") unsafeRow } internalRowsArray(i) = internalRow @@ -190,7 +187,7 @@ object UnsafeRowsWrapper extends Logging { unsafeRows.length } - def getNullObject(rowNum: Int): Array[InternalRow] = { + def getEmptyObject(rowNum: Int): Array[InternalRow] = { Array.fill(rowNum) { new UnsafeRow(0) } From 666a12e747fc9a10661f91f8dc7b340a864cee4e Mon Sep 17 00:00:00 2001 From: guoying06 Date: Wed, 12 Feb 2025 17:10:15 +0800 Subject: [PATCH 06/17] fix Arrow scope issue and serialization of row(none) issue --- .../src/agg/spark_hdaf_wrapper.rs | 7 -- .../sql/blaze/SparkUDAFWrapperContext.scala | 64 +++---------------- .../spark/sql/blaze/UnsafeRowsWrapper.scala | 16 ----- 3 files changed, 10 insertions(+), 77 deletions(-) diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs index 65efad23c..9b6a12c26 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs @@ -154,9 +154,7 @@ impl Agg for SparkUDAFWrapper { partial_arg_idx: IdxSelection<'_>, batch_schema: SchemaRef, ) -> Result<()> { - log::info!("start partial update"); let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); - log::info!("update before accs.num {}", accs.num_records()); let params = partial_args.to_vec(); let params_schema = self @@ -209,7 +207,6 @@ impl Agg for SparkUDAFWrapper { partial_arg_idx, ) .unwrap(); - log::info!("update after accs.num {}", accs.num_records()); Ok(()) } @@ -220,7 +217,6 @@ impl Agg for SparkUDAFWrapper { merging_accs: &mut AccColumnRef, merging_acc_idx: IdxSelection<'_>, ) -> Result<()> { - log::info!("start partial merge"); let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); let merging_accs = downcast_any!(merging_accs, mut AccUnsafeRowsColumn).unwrap(); @@ -260,7 +256,6 @@ impl Agg for SparkUDAFWrapper { } fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result { - log::info!("start final merge"); let accs = 
downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); final_merge_udaf( self.jcontext()?, @@ -343,8 +338,6 @@ impl AccColumn for AccUnsafeRowsColumn { } fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> { - log::info!("unfreeze array {:?}", array); - log::info!("unfreeze offsets {:?}", offsets); let binary_values = array.iter().map(|&data| data).collect(); let offsets_i32 = offsets .iter() diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 9e7cd9648..f6005308f 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.blaze.util.Using import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, JoinedRow, Nondeterministic, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, JoinedRow, MutableProjection, Nondeterministic, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} @@ -85,16 +85,10 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { ArrowUtils.toArrowSchema(schema) } - val dataTypes: Seq[DataType] = expr.aggBufferAttributes.map(_.dataType) - val dataName: Seq[String] = expr.aggBufferAttributes.map(_.name) - - val inputTypes: Seq[DataType] = javaParamsSchema.map(_.dataType) - def update( rows: Array[InternalRow], importIdxFFIArrayPtr: Long, importBatchFFIArrayPtr: Long): Array[InternalRow] = { - logInfo("start partial update in scalar!") Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(inputSchema, batchAllocator), @@ -110,38 +104,18 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] - logInfo(s"inputRows.num: ${inputRows.numRows()}") - logInfo(s"rows.num: ${rows.length}") - logInfo(s"Idx length ${idxRoot.getRowCount}") - logInfo(s"inputIdxVector $inputIdxVector") - logInfo(s"rowIdxVector $rowIdxVector") for (i <- 0 until idxRoot.getRowCount) { if ( inputIdxVector.get(i) < inputRows.numRows() ) { - if (rowIdxVector.get(i) < rows.length) { val row = rows(rowIdxVector.get(i)) val input = inputRows.getRow(inputIdxVector.get(i)) val joiner = new JoinedRow if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = updater(joiner(initialize(), paramsToUnsafe(input))).copy() + rows(rowIdxVector.get(i)) = updater(joiner(initialize(), paramsToUnsafe(input).copy())).copy() } else { - // logInfo(s"row: ${row.toSeq(dataTypes)}") - // logInfo(s"input: ${input.toSeq(inputTypes)}") - // logInfo(s"is row unsafe ${row.isInstanceOf[UnsafeRow]}") - rows(rowIdxVector.get(i)) = updater(joiner(row, paramsToUnsafe(input))).copy() + rows(rowIdxVector.get(i)) = updater(joiner(row, paramsToUnsafe(input).copy())).copy() } - // logInfo(s"temp row 0: 
${rows(0).toSeq(dataTypes)}") - } - else { - logInfo(s"wow $i rowIdx:${rowIdxVector.get(i)}") - } } - - else { - logInfo(s"wow update i $i inputIdxVector:${inputIdxVector.get(i)}") - } } - logInfo(s"update rows num: ${rows.length}, rows.fieldnum:${rows(0).numFields}") -// logInfo(s"row 0: ${rows(0).toString}") rows } } @@ -151,7 +125,6 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { rows: Array[InternalRow], mergeRows: Array[InternalRow], importIdxFFIArrayPtr: Long): Array[InternalRow] = { - logInfo("start merge in scalar!!") Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(indexSchema, batchAllocator), @@ -161,35 +134,18 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] - logInfo(s"rows.num: ${rows.length}, mergeRows.num ${mergeRows.length} , idx.len: ${idxRoot.getRowCount}") - logInfo(s"mergeIdxVector $mergeIdxVector") - logInfo(s"rowIdxVector $rowIdxVector") for (i <- 0 until idxRoot.getRowCount) { if (mergeIdxVector.get(i) < mergeRows.length) { - if (rowIdxVector.get(i) < rows.length) { - logInfo(s"i: $i, mergeIdxVector.get(i) ${mergeIdxVector.get(i)}, rowIdxVector.get(i) ${rowIdxVector.get(i)}") - val row = rows(rowIdxVector.get(i)) - val mergeRow = mergeRows(mergeIdxVector.get(i)) - val joiner = new JoinedRow - if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)).copy() - logInfo { - s"init merge row ${rows(rowIdxVector.get(i)).toSeq(dataTypes)}" - } - } else { - rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)).copy() - logInfo { - s"merge row ${rows(rowIdxVector.get(i)).toSeq(dataTypes)}" - } - } - } - else { - logInfo(s"wow merge i $i rowIdxVector:${rowIdxVector.get(i)}") + val row = rows(rowIdxVector.get(i)) + val mergeRow = mergeRows(mergeIdxVector.get(i)) + val joiner = new JoinedRow + if (row.numFields == 0) { + rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)).copy() + } else { + rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)).copy() } - } } - logInfo("finish merge in scalar!!") rows } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index 46e74364c..a892bb679 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -91,22 +91,14 @@ object UnsafeRowsWrapper extends Logging { val idxArray = paramsRoot.getFieldVectors.asScala.head.asInstanceOf[IntVector] val serializer = new UnsafeRowSerializer(numFields).newInstance() val outputWriter = ArrowWriter.create(outputRoot) -// logInfo(s"freeze unsaferows num: ${unsafeRows.length}") -// logInfo(s"freeze idxArray $idxArray") for (idx <- 0 until paramsRoot.getRowCount) { val internalRow = unsafeRows(idxArray.get(idx)) - if (internalRow.numFields == 0) { - outputWriter.write(toUnsafeRow(Row(Array.empty[Byte]), Array(BinaryType))) - } else { -// logInfo(s"freeze unsafe row : ${internalRow.toString}") Utils.tryWithResource(new ByteArrayOutputStream()) { baos => val serializerStream = serializer.serializeStream(baos) serializerStream.writeValue(internalRow) serializerStream.close() val bytes = baos.toByteArray -// logInfo(s"write bytes : 
${java.util.Arrays.toString(bytes)}") outputWriter.write(toUnsafeRow(Row(bytes), Array(BinaryType))) - } } } @@ -149,8 +141,6 @@ object UnsafeRowsWrapper extends Logging { for (i <- 0 until paramsRoot.getRowCount) { val bytes = binaryVector.get(i) val offset = intVector.get(i) -// logInfo(s"Reading bytes from offset: $offset, bytes length: ${bytes.length}") - if (bytes.length - offset > 0) { val internalRow: InternalRow = Utils.tryWithResource( new ByteArrayInputStream(bytes, offset, bytes.length - offset)) { bais => val deserializeStream = deserializer.deserializeStream(bais) @@ -158,15 +148,9 @@ object UnsafeRowsWrapper extends Logging { deserializeStream.close() val size = unsafeRow.getSizeInBytes + 4 outputWriter.write(toUnsafeRow(Row(offset + size), Array(IntegerType))) -// logInfo(s"unfreeze row ${unsafeRow.toString}") unsafeRow } internalRowsArray(i) = internalRow - } else { - internalRowsArray(i) = new UnsafeRow(0) - outputWriter.write(toUnsafeRow(Row(offset), Array(IntegerType))) - } - } outputWriter.finish() From 9b3d656bc028104cd93fd992da6895a5c89cb748 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Thu, 13 Feb 2025 17:11:46 +0800 Subject: [PATCH 07/17] fix udaf init --- .../blaze-jni-bridge/src/jni_bridge.rs | 22 ++++++---- .../src/agg/spark_hdaf_wrapper.rs | 19 +++++---- .../sql/blaze/UnsafeRowsWrapperUtils.java | 3 -- .../sql/blaze/SparkUDAFWrapperContext.scala | 41 ++++++++++++------- .../spark/sql/blaze/UnsafeRowsWrapper.scala | 6 --- 5 files changed, 53 insertions(+), 38 deletions(-) diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index 31fa2b858..d5d642aab 100644 --- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -1177,6 +1177,10 @@ impl<'a> SparkUDFWrapperContext<'a> { pub struct SparkUDAFWrapperContext<'a> { pub class: JClass<'a>, pub ctor: JMethodID, + pub method_initialize: JMethodID, + pub method_initialize_ret: ReturnType, + pub method_resize: JMethodID, + pub method_resize_ret: ReturnType, pub method_update: JMethodID, pub method_update_ret: ReturnType, pub method_merge: JMethodID, @@ -1192,6 +1196,16 @@ impl<'a> SparkUDAFWrapperContext<'a> { Ok(SparkUDAFWrapperContext { class, ctor: env.get_method_id(class, "", "(Ljava/nio/ByteBuffer;)V")?, + method_initialize: env.get_method_id( + class, + "initialize", + "(I)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, + method_initialize_ret: ReturnType::Object, + method_resize: env.get_method_id( + class, + "resize", + "([Lorg/apache/spark/sql/catalyst/InternalRow;I)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, + method_resize_ret: ReturnType::Object, method_update: env.get_method_id( class, "update", @@ -1246,8 +1260,6 @@ pub struct BlazeUnsafeRowsWrapperUtils<'a> { pub method_deserialize_ret: ReturnType, pub method_num: JStaticMethodID, pub method_num_ret: ReturnType, - pub method_create: JStaticMethodID, - pub method_create_ret: ReturnType, } impl<'a> BlazeUnsafeRowsWrapperUtils<'a> { pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils"; @@ -1274,12 +1286,6 @@ impl<'a> BlazeUnsafeRowsWrapperUtils<'a> { "([Lorg/apache/spark/sql/catalyst/InternalRow;)I", )?, method_num_ret: ReturnType::Primitive(Primitive::Int), - method_create: env.get_static_method_id( - class, - "getEmptyObject", - "(I)[Lorg/apache/spark/sql/catalyst/InternalRow;", - )?, - method_create_ret: ReturnType::Object, }) } } diff --git 
a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs index 9b6a12c26..adf9d1ac9 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs @@ -117,13 +117,16 @@ impl Agg for SparkUDAFWrapper { } fn create_acc_column(&self, num_rows: usize) -> AccColumnRef { - // num_rows - let rows = jni_call_static!( - BlazeUnsafeRowsWrapperUtils.create(num_rows as i32)-> JObject) - .unwrap(); + let jcontext = self.jcontext().unwrap(); + let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).initialize( + num_rows as i32, + )-> JObject).unwrap(); + + let jcontext = self.jcontext().unwrap(); let obj = jni_new_global_ref!(rows.as_obj()).unwrap(); Box::new(AccUnsafeRowsColumn { obj, + jcontext, num_fields: self.buffer_schema.fields.len(), }) } @@ -268,6 +271,7 @@ impl Agg for SparkUDAFWrapper { struct AccUnsafeRowsColumn { obj: GlobalRef, + jcontext: GlobalRef, num_fields: usize, } @@ -277,9 +281,10 @@ impl AccColumn for AccUnsafeRowsColumn { } fn resize(&mut self, len: usize) { - let rows = jni_call_static!( - BlazeUnsafeRowsWrapperUtils.create(len as i32)-> JObject) - .unwrap(); + let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()).resize( + self.obj.as_obj(), + len as i32, + )-> JObject).unwrap(); self.obj = jni_new_global_ref!(rows.as_obj()).unwrap(); } diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java index ac4ef3957..3820b166b 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java @@ -32,7 +32,4 @@ public static int getRowNum(InternalRow[] unsafeRows) { return UnsafeRowsWrapper$.MODULE$.getRowNum(unsafeRows); } - public static InternalRow[] getEmptyObject(int rowNum) { - return UnsafeRowsWrapper$.MODULE$.getEmptyObject(rowNum); - } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index f6005308f..2b12be6b6 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWrite import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { private val (expr, javaParamsSchema) = @@ -63,8 +64,6 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { private lazy val evaluator = UnsafeProjection.create(expr.evaluateExpression :: Nil, expr.aggBufferAttributes) - private def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() - private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() private val inputSchema = ArrowUtils.toArrowSchema(javaParamsSchema) @@ -85,6 +84,26 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { ArrowUtils.toArrowSchema(schema) } + val dataTypes: Seq[DataType] = expr.aggBufferAttributes.map(_.dataType) +// 
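// [editor's note] Why the copy() in the initialize(numRow) added below
// matters: the projection built from `expr.initialValues` reuses one backing
// row across calls, so copying gives each group an independent buffer rather
// than N references to the projection's scratch row. A sketch (3 slots,
// names illustrative):
val initialRow = initializer.apply(InternalRow.empty)
val aliased = Array.fill(3)(initialRow)            // three references, one row
val independent = Array.fill(3)(initialRow.copy()) // three separate buffers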
private def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() + def initialize(numRow: Int): Array[InternalRow] = { + val initialRow = initializer.apply(InternalRow.empty) + Array.fill(numRow) { + initialRow.copy() + } + } + + def resize(rows: Array[InternalRow], len: Int): Array[InternalRow] = { + val buffer = ArrayBuffer[InternalRow](rows: _*) + if (buffer.length < len) { + val initialRow = initializer.apply(InternalRow.empty) + buffer ++= Array.fill(len - buffer.length){initialRow.copy()} + } else { + buffer.trimEnd(buffer.length - len) + } + buffer.toArray + } + def update( rows: Array[InternalRow], importIdxFFIArrayPtr: Long, @@ -109,11 +128,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val row = rows(rowIdxVector.get(i)) val input = inputRows.getRow(inputIdxVector.get(i)) val joiner = new JoinedRow - if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = updater(joiner(initialize(), paramsToUnsafe(input).copy())).copy() - } else { - rows(rowIdxVector.get(i)) = updater(joiner(row, paramsToUnsafe(input).copy())).copy() - } + rows(rowIdxVector.get(i)) = updater(joiner(row, paramsToUnsafe(input).copy())).copy() } } rows @@ -134,16 +149,14 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] + for (i <- 0 until idxRoot.getRowCount) { - if (mergeIdxVector.get(i) < mergeRows.length) { + val idx = mergeIdxVector.get(i) + if (idx < mergeRows.length) { val row = rows(rowIdxVector.get(i)) val mergeRow = mergeRows(mergeIdxVector.get(i)) val joiner = new JoinedRow - if (row.numFields == 0) { - rows(rowIdxVector.get(i)) = merger(joiner(initialize(), mergeRow)).copy() - } else { - rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)).copy() - } + rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)).copy() } } rows @@ -168,7 +181,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { // evaluate expression and write to output root val outputWriter = ArrowWriter.create(outputRoot) for (i <- 0 until idxRoot.getRowCount) { - val row = rows(rowIdxVector.get(i)).copy() + val row = rows(rowIdxVector.get(i)) outputWriter.write(evaluator(row)) } outputWriter.finish() diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index a892bb679..830ce2718 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -171,10 +171,4 @@ object UnsafeRowsWrapper extends Logging { unsafeRows.length } - def getEmptyObject(rowNum: Int): Array[InternalRow] = { - Array.fill(rowNum) { - new UnsafeRow(0) - } - } - } From 3c766123c3f57d809f37fe88c2a333b80ec4710a Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Thu, 13 Feb 2025 20:32:20 +0800 Subject: [PATCH 08/17] optimize --- .../blaze-jni-bridge/src/jni_bridge.rs | 34 ++- .../src/agg/spark_hdaf_wrapper.rs | 228 +++++++----------- .../sql/blaze/UnsafeRowsWrapperUtils.java | 15 +- .../sql/blaze/SparkUDAFWrapperContext.scala | 66 +++-- .../spark/sql/blaze/UnsafeRowsWrapper.scala | 124 ++++------ 5 files changed, 180 insertions(+), 287 deletions(-) diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index d5d642aab..7132401f3 100644 --- 
a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -1199,29 +1199,33 @@ impl<'a> SparkUDAFWrapperContext<'a> { method_initialize: env.get_method_id( class, "initialize", - "(I)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, + "(I)Lscala/collection/mutable/ArrayBuffer;", + )?, method_initialize_ret: ReturnType::Object, method_resize: env.get_method_id( class, "resize", - "([Lorg/apache/spark/sql/catalyst/InternalRow;I)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, - method_resize_ret: ReturnType::Object, + "(Lscala/collection/mutable/ArrayBuffer;I)V", + )?, + method_resize_ret: ReturnType::Primitive(Primitive::Void), method_update: env.get_method_id( class, "update", - "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, - method_update_ret: ReturnType::Object, + "(Lscala/collection/mutable/ArrayBuffer;JJ)V", + )?, + method_update_ret: ReturnType::Primitive(Primitive::Void), method_merge: env.get_method_id( class, "merge", - "([Lorg/apache/spark/sql/catalyst/InternalRow;[Lorg/apache/spark/sql/catalyst/InternalRow;J)[Lorg/apache/spark/sql/catalyst/InternalRow;")?, - method_merge_ret: ReturnType::Object, + "(Lscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;J)V", + )?, + method_merge_ret: ReturnType::Primitive(Primitive::Void), method_eval: env.get_method_id( class, "eval", - "([Lorg/apache/spark/sql/catalyst/InternalRow;JJ)V")?, + "(Lscala/collection/mutable/ArrayBuffer;JJ)V", + )?, method_eval_ret: ReturnType::Primitive(Primitive::Void), - }) } } @@ -1258,8 +1262,6 @@ pub struct BlazeUnsafeRowsWrapperUtils<'a> { pub method_serialize_ret: ReturnType, pub method_deserialize: JStaticMethodID, pub method_deserialize_ret: ReturnType, - pub method_num: JStaticMethodID, - pub method_num_ret: ReturnType, } impl<'a> BlazeUnsafeRowsWrapperUtils<'a> { pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils"; @@ -1271,21 +1273,15 @@ impl<'a> BlazeUnsafeRowsWrapperUtils<'a> { method_serialize: env.get_static_method_id( class, "serialize", - "([Lorg/apache/spark/sql/catalyst/InternalRow;IJJ)V", + "(Lscala/collection/mutable/ArrayBuffer;IJJ)V", )?, method_serialize_ret: ReturnType::Primitive(Primitive::Void), method_deserialize: env.get_static_method_id( class, "deserialize", - "(IJJ)[Lorg/apache/spark/sql/catalyst/InternalRow;", + "(ILjava/nio/ByteBuffer;)Lscala/collection/mutable/ArrayBuffer;", )?, method_deserialize_ret: ReturnType::Object, - method_num: env.get_static_method_id( - class, - "getRowNum", - "([Lorg/apache/spark/sql/catalyst/InternalRow;)I", - )?, - method_num_ret: ReturnType::Primitive(Primitive::Int), }) } } diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs index adf9d1ac9..c19c2bc42 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs @@ -15,6 +15,7 @@ use std::{ any::Any, fmt::{Debug, Display, Formatter}, + io::Cursor, sync::Arc, }; @@ -27,14 +28,15 @@ use arrow::{ ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, record_batch::{RecordBatch, RecordBatchOptions}, }; +use arrow_schema::FieldRef; use blaze_jni_bridge::{ jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, }; -use datafusion::{ - common::{DataFusionError, Result}, - physical_expr::PhysicalExpr, +use datafusion::{common::Result, 
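// [editor's note] The JNI descriptor strings above must match the Scala
// methods' erased signatures exactly, or lookup fails at runtime with
// NoSuchMethodError; ArrayBuffer[InternalRow] erases to the raw class, hence
// "Lscala/collection/mutable/ArrayBuffer;". A compilable sketch of the shapes
// those descriptors resolve to (the trait name is illustrative):
import org.apache.spark.sql.catalyst.InternalRow
import scala.collection.mutable.ArrayBuffer

trait UdafJniShape {
  def initialize(numRow: Int): ArrayBuffer[InternalRow] // (I)Lscala/collection/mutable/ArrayBuffer;
  def resize(rows: ArrayBuffer[InternalRow], len: Int): Unit // (Lscala/collection/mutable/ArrayBuffer;I)V
  def update(rows: ArrayBuffer[InternalRow], idxPtr: Long, batchPtr: Long): Unit // (Lscala/collection/mutable/ArrayBuffer;JJ)V
}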
physical_expr::PhysicalExpr}; +use datafusion_ext_commons::{ + downcast_any, + io::{read_len, write_len}, }; -use datafusion_ext_commons::downcast_any; use jni::objects::{GlobalRef, JObject}; use once_cell::sync::OnceCell; @@ -120,7 +122,8 @@ impl Agg for SparkUDAFWrapper { let jcontext = self.jcontext().unwrap(); let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).initialize( num_rows as i32, - )-> JObject).unwrap(); + )-> JObject) + .unwrap(); let jcontext = self.jcontext().unwrap(); let obj = jni_new_global_ref!(rows.as_obj()).unwrap(); @@ -128,6 +131,7 @@ impl Agg for SparkUDAFWrapper { obj, jcontext, num_fields: self.buffer_schema.fields.len(), + num_rows, }) } @@ -188,28 +192,16 @@ impl Agg for SparkUDAFWrapper { partial_arg_idx_builder.append_value(partial_arg_idx as i32); } } - let acc_idx = acc_idx_builder - .finish() - .as_any() - .downcast_ref::() - .cloned() - .unwrap(); - - let partial_arg_idx = partial_arg_idx_builder - .finish() - .as_any() - .downcast_ref::() - .cloned() - .unwrap(); + let acc_idx = acc_idx_builder.finish(); + let partial_arg_idx = partial_arg_idx_builder.finish(); - accs.obj = partial_update_udaf( + partial_update_udaf( self.jcontext()?, params_batch, accs.obj.clone(), acc_idx, partial_arg_idx, - ) - .unwrap(); + )?; Ok(()) } @@ -232,29 +224,16 @@ impl Agg for SparkUDAFWrapper { merging_acc_idx_builder.append_value(merging_acc_idx as i32); } } - let acc_idx = acc_idx_builder - .finish() - .as_any() - .downcast_ref::() - .cloned() - .unwrap(); - - let merging_acc_idx = merging_acc_idx_builder - .finish() - .as_any() - .downcast_ref::() - .cloned() - .unwrap(); + let acc_idx = acc_idx_builder.finish(); + let merging_acc_idx = merging_acc_idx_builder.finish(); - accs.obj = partial_merge_udaf( + partial_merge_udaf( self.jcontext()?, accs.obj.clone(), merging_accs.obj.clone(), acc_idx, merging_acc_idx, - ) - .unwrap(); - + )?; Ok(()) } @@ -273,6 +252,7 @@ struct AccUnsafeRowsColumn { obj: GlobalRef, jcontext: GlobalRef, num_fields: usize, + num_rows: usize, } impl AccColumn for AccUnsafeRowsColumn { @@ -284,20 +264,15 @@ impl AccColumn for AccUnsafeRowsColumn { let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()).resize( self.obj.as_obj(), len as i32, - )-> JObject).unwrap(); - self.obj = jni_new_global_ref!(rows.as_obj()).unwrap(); + )-> ()) + .unwrap(); + self.num_rows = len; } fn shrink_to_fit(&mut self) {} fn num_records(&self) -> usize { - match jni_call_static!( - BlazeUnsafeRowsWrapperUtils.num(self.obj.as_obj()) - -> i32) - { - Ok(row_num) => row_num as usize, - Err(_) => 0, - } + self.num_rows } fn mem_used(&self) -> usize { @@ -306,8 +281,8 @@ impl AccColumn for AccUnsafeRowsColumn { fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec]) -> Result<()> { let field = Arc::new(Field::new("", DataType::Int32, false)); - let idx32 = idx.to_int32_array().into_data(); - let struct_array = StructArray::from(vec![(field, make_array(idx32))]); + let idx_array: ArrayRef = Arc::new(idx.to_int32_array()); + let struct_array = StructArray::from(vec![(field, idx_array)]); let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data()); let mut import_ffi_array = FFI_ArrowArray::empty(); jni_call_static!( @@ -325,79 +300,42 @@ impl AccColumn for AccUnsafeRowsColumn { make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? 
}); let result_struct = import_struct_array.as_struct(); - let binary_array = result_struct - .column(0) - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::Execution("Expected a BinaryArray".to_string()))?; - - for i in 0..binary_array.len() { - if binary_array.is_valid(i) { - let bytes = binary_array.value(i).to_vec(); - array[i].extend_from_slice(&bytes); - } else { - log::warn!("AccUnsafeRowsColumn::freeze_to_rows : binary_array null error") - } + let binary_array = downcast_any!(result_struct.column(0), BinaryArray)?; + let data = binary_array.value(0); + + // UnsafeRow is serialized with big-endian i32 length prefix + let mut cur = 0; + for i in 0..array.len() { + let bytes_len = i32::from_be_bytes(data[cur..][..4].try_into().unwrap()) as usize; + write_len(bytes_len, &mut array[i])?; + cur += 4; + + array[i].extend_from_slice(&data[cur..][..bytes_len]); + cur += bytes_len; } Ok(()) } fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> { - let binary_values = array.iter().map(|&data| data).collect(); - let offsets_i32 = offsets - .iter() - .map(|data| *data as i32) - .collect::>(); - let offsets_array = Int32Array::from(offsets_i32) - .as_any() - .downcast_ref::() - .cloned() - .unwrap(); - let binary_array = BinaryArray::from_vec(binary_values) - .as_any() - .downcast_ref::() - .cloned() - .unwrap(); - - let binary_field = Arc::new(Field::new("", DataType::Binary, false)); - let offsets_field = Arc::new(Field::new("", DataType::Int32, false)); - - let struct_array = StructArray::from(vec![ - (binary_field.clone(), make_array(binary_array.into_data())), - (offsets_field.clone(), make_array(offsets_array.into_data())), - ]); + let mut data = vec![]; + for (row_data, offset) in array.iter().zip(offsets) { + let mut cur = Cursor::new(&row_data[*offset..]); + let bytes_len = read_len(&mut cur)?; + data.extend_from_slice(&(bytes_len as i32).to_be_bytes()); + *offset += cur.position() as usize; + + data.extend_from_slice(&row_data[*offset..][..bytes_len]); + *offset += bytes_len; + } - let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data()); - let mut import_ffi_array = FFI_ArrowArray::empty(); + let data_buffer = jni_new_direct_byte_buffer!(data)?; let rows = jni_call_static!( BlazeUnsafeRowsWrapperUtils.deserialize( self.num_fields as i32, - &mut export_ffi_array as *mut FFI_ArrowArray as i64, - &mut import_ffi_array as *mut FFI_ArrowArray as i64,) + data_buffer.as_obj()) -> JObject)?; self.obj = jni_new_global_ref!(rows.as_obj())?; - - // update offsets - // import output from context - let field = Field::new("", DataType::Int32, false); - let schema = Schema::new(vec![field]); - let import_ffi_schema = FFI_ArrowSchema::try_from(schema)?; - let import_struct_array = - make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? 
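// [editor's sketch] The framing being parsed here: Spark's UnsafeRowSerializer
// writes each row as [4-byte big-endian length][row bytes] (DataOutputStream
// is big-endian; a negative length is its end-of-stream marker), which is why
// the native side reads i32::from_be_bytes before every payload. A JVM-side
// decoder over the same layout, with `data` an assumed serialized stream:
import java.io.{ByteArrayInputStream, DataInputStream}

def splitFrames(data: Array[Byte]): Seq[Array[Byte]] = {
  val in = new DataInputStream(new ByteArrayInputStream(data))
  val frames = Seq.newBuilder[Array[Byte]]
  var done = false
  while (!done && in.available() >= 4) {
    val len = in.readInt() // big-endian length prefix
    if (len < 0) {
      done = true // end-of-stream marker
    } else {
      val payload = new Array[Byte](len)
      in.readFully(payload)
      frames += payload
    }
  }
  frames.result()
}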
}); - let result_struct = import_struct_array.as_struct(); - - let int32array = result_struct - .column(0) - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::Execution("Expected a Int32Array".to_string()))?; - - assert_eq!(int32array.len(), array.len()); - - for i in 0..int32array.len() { - offsets[i] = int32array.value(i) as usize; - } - + self.num_rows = array.len(); Ok(()) } @@ -410,36 +348,45 @@ impl AccColumn for AccUnsafeRowsColumn { } } +fn int32_field() -> FieldRef { + static FIELD: OnceCell = OnceCell::new(); + FIELD + .get_or_init(|| Arc::new(Field::new("", DataType::Int32, false))) + .clone() +} + +fn binary_field() -> FieldRef { + static FIELD: OnceCell = OnceCell::new(); + FIELD + .get_or_init(|| Arc::new(Field::new("", DataType::Binary, false))) + .clone() +} + fn partial_update_udaf( jcontext: GlobalRef, params_batch: RecordBatch, accs: GlobalRef, acc_idx: Int32Array, partial_arg_idx: Int32Array, -) -> Result { - let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int32, false)); - let partial_arg_idx_field = Arc::new(Field::new("partial_arg_idx", DataType::Int32, false)); - - let struct_array = StructArray::from(vec![ - (acc_idx_field.clone(), make_array(acc_idx.into_data())), - ( - partial_arg_idx_field.clone(), - make_array(partial_arg_idx.into_data()), - ), +) -> Result<()> { + let acc_idx: ArrayRef = Arc::new(acc_idx); + let partial_arg_idx: ArrayRef = Arc::new(partial_arg_idx); + let idx_struct_array = StructArray::from(vec![ + (int32_field(), acc_idx), + (int32_field(), partial_arg_idx), ]); - let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); - - let struct_array = StructArray::from(params_batch.clone()); + let batch_struct_array = StructArray::from(params_batch); - let mut export_ffi_batch_array = FFI_ArrowArray::new(&struct_array.to_data()); + let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data()); + let mut export_ffi_batch_array = FFI_ArrowArray::new(&batch_struct_array.to_data()); let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).update( accs.as_obj(), &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, &mut export_ffi_batch_array as *mut FFI_ArrowArray as i64, - )-> JObject)?; + )-> ())?; - jni_new_global_ref!(rows.as_obj()) + Ok(()) } fn partial_merge_udaf( @@ -448,26 +395,22 @@ fn partial_merge_udaf( merging_accs: GlobalRef, acc_idx: Int32Array, merging_acc_idx: Int32Array, -) -> Result { - let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int32, false)); - let merging_acc_idx_field = Arc::new(Field::new("merging_acc_idx", DataType::Int32, false)); - - let struct_array = StructArray::from(vec![ - (acc_idx_field.clone(), make_array(acc_idx.into_data())), - ( - merging_acc_idx_field.clone(), - make_array(merging_acc_idx.into_data()), - ), +) -> Result<()> { + let acc_idx: ArrayRef = Arc::new(acc_idx); + let merging_acc_idx: ArrayRef = Arc::new(merging_acc_idx); + let export_ffi_idx_array = StructArray::from(vec![ + (int32_field(), acc_idx), + (int32_field(), merging_acc_idx), ]); - let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); + let mut export_ffi_idx_array = FFI_ArrowArray::new(&export_ffi_idx_array.to_data()); let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).merge( accs.as_obj(), merging_accs.as_obj(), &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, - )-> JObject)?; + )-> ())?; - jni_new_global_ref!(rows.as_obj()) + Ok(()) } fn final_merge_udaf( @@ -476,9 +419,8 @@ fn final_merge_udaf( acc_idx: 
IdxSelection<'_>, result_schema: SchemaRef, ) -> Result { - let acc_idx_field = Arc::new(Field::new("acc_idx", DataType::Int32, false)); - let acc_idx = acc_idx.to_int32_array().into_data(); - let struct_array = StructArray::from(vec![(acc_idx_field.clone(), make_array(acc_idx))]); + let acc_idx: ArrayRef = Arc::new(Int32Array::from(acc_idx.to_int32_array())); + let struct_array = StructArray::from(vec![(int32_field(), acc_idx)]); let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); let mut import_ffi_array = FFI_ArrowArray::empty(); let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).eval( diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java index 3820b166b..a0cb86047 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java @@ -15,21 +15,20 @@ */ package org.apache.spark.sql.blaze; +import java.nio.ByteBuffer; + import org.apache.spark.sql.catalyst.InternalRow; +import scala.collection.mutable.ArrayBuffer; + public class UnsafeRowsWrapperUtils { public static void serialize( - InternalRow[] unsafeRows, int numFields, long importFFIArrayPtr, long exportFFIArrayPtr) { + ArrayBuffer unsafeRows, int numFields, long importFFIArrayPtr, long exportFFIArrayPtr) { UnsafeRowsWrapper$.MODULE$.serialize(unsafeRows, numFields, importFFIArrayPtr, exportFFIArrayPtr); } - public static InternalRow[] deserialize(int numFields, long importFFIArrayPtr, long exportFFIArrayPtr) { - return UnsafeRowsWrapper$.MODULE$.deserialize(numFields, importFFIArrayPtr, exportFFIArrayPtr); + public static ArrayBuffer deserialize(int numFields, ByteBuffer dataBuffer) { + return UnsafeRowsWrapper$.MODULE$.deserialize(numFields, dataBuffer); } - - public static int getRowNum(InternalRow[] unsafeRows) { - return UnsafeRowsWrapper$.MODULE$.getRowNum(unsafeRows); - } - } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 2b12be6b6..08d5c227f 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -16,6 +16,7 @@ package org.apache.spark.sql.blaze import scala.collection.JavaConverters._ + import org.apache.arrow.c.{ArrowArray, Data} import org.apache.arrow.vector.{IntVector, VectorSchemaRoot} import org.apache.arrow.vector.dictionary.DictionaryProvider @@ -29,8 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} - import java.nio.ByteBuffer + import scala.collection.mutable.ArrayBuffer case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { @@ -85,29 +86,25 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } val dataTypes: Seq[DataType] = expr.aggBufferAttributes.map(_.dataType) -// private def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() - def initialize(numRow: Int): Array[InternalRow] = { - val initialRow = 
initializer.apply(InternalRow.empty) - Array.fill(numRow) { - initialRow.copy() - } + + def initialize(numRow: Int): ArrayBuffer[InternalRow] = { + val rows = ArrayBuffer[InternalRow]() + resize(rows, numRow) + rows } - def resize(rows: Array[InternalRow], len: Int): Array[InternalRow] = { - val buffer = ArrayBuffer[InternalRow](rows: _*) - if (buffer.length < len) { - val initialRow = initializer.apply(InternalRow.empty) - buffer ++= Array.fill(len - buffer.length){initialRow.copy()} + def resize(rows: ArrayBuffer[InternalRow], len: Int): Unit = { + if (rows.length < len) { + rows.append(Range(rows.length, len).map(_ => initializer.apply(InternalRow.empty)) :_*) } else { - buffer.trimEnd(buffer.length - len) + rows.trimEnd(rows.length - len) } - buffer.toArray } def update( - rows: Array[InternalRow], + rows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long, - importBatchFFIArrayPtr: Long): Array[InternalRow] = { + importBatchFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(inputSchema, batchAllocator), @@ -120,52 +117,45 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] + val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] for (i <- 0 until idxRoot.getRowCount) { - if ( inputIdxVector.get(i) < inputRows.numRows() ) { - val row = rows(rowIdxVector.get(i)) - val input = inputRows.getRow(inputIdxVector.get(i)) - val joiner = new JoinedRow - rows(rowIdxVector.get(i)) = updater(joiner(row, paramsToUnsafe(input).copy())).copy() - } + val rowIdx = rowIdxVector.get(i) + val row = rows(rowIdx) + val input = paramsToUnsafe(inputRows.getRow(inputIdxVector.get(i))) + rows(rowIdx) = updater(new JoinedRow(row, input)).copy() } - rows } } } def merge( - rows: Array[InternalRow], - mergeRows: Array[InternalRow], - importIdxFFIArrayPtr: Long): Array[InternalRow] = { + rows: ArrayBuffer[InternalRow], + mergeRows: ArrayBuffer[InternalRow], + importIdxFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(indexSchema, batchAllocator), ArrowArray.wrap(importIdxFFIArrayPtr)) { (idxRoot, idxArray) => Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] + val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] - for (i <- 0 until idxRoot.getRowCount) { - val idx = mergeIdxVector.get(i) - if (idx < mergeRows.length) { - val row = rows(rowIdxVector.get(i)) - val mergeRow = mergeRows(mergeIdxVector.get(i)) - val joiner = new JoinedRow - rows(rowIdxVector.get(i)) = merger(joiner(row, mergeRow)).copy() - } + val rowIdx = rowIdxVector.get(i) + val mergeIdx = mergeIdxVector.get(i) + val row = rows(rowIdx) + val mergeRow = mergeRows(mergeIdx) + rows(rowIdx) = merger(new JoinedRow(row, mergeRow)).copy() } - rows } } } def eval( - rows: Array[InternalRow], + rows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long, exportFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { 
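// [editor's usage note] resize() above grows by appending freshly initialized
// buffers and shrinks by trimming from the tail; surviving buffers keep their
// contents, matching the native AccColumn resize contract:
val rows = initialize(2) // [b0, b1]
resize(rows, 4)          // [b0, b1, new, new]
resize(rows, 1)          // [b0]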
batchAllocator => diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index 830ce2718..448b0cbc9 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -16,7 +16,7 @@ package org.apache.spark.sql.blaze import org.apache.arrow.c.{ArrowArray, Data} -import org.apache.arrow.vector.{VarBinaryVector, IntVector, VectorSchemaRoot} +import org.apache.arrow.vector.{IntVector, VarBinaryVector, VectorSchemaRoot} import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -28,9 +28,15 @@ import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StructField, StructType} import org.apache.spark.sql.Row - import scala.collection.JavaConverters._ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.OutputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.util.ByteBufferInputStream object UnsafeRowsWrapper extends Logging { @@ -40,7 +46,7 @@ object UnsafeRowsWrapper extends Logging { ArrowUtils.toArrowSchema(schema) } - private val byteSchema = { + private val dataSchema = { val schema = StructType(Seq(StructField("", BinaryType, nullable = false))) ArrowUtils.toArrowSchema(schema) } @@ -71,104 +77,64 @@ object UnsafeRowsWrapper extends Logging { } def serialize( - unsafeRows: Array[InternalRow], - numFields: Int, - importFFIArrayPtr: Long, - exportFFIArrayPtr: Long): Unit = { + rows: ArrayBuffer[InternalRow], + numFields: Int, + importFFIArrayPtr: Long, + exportFFIArrayPtr: Long): Unit = { + Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( - VectorSchemaRoot.create(byteSchema, batchAllocator), + VectorSchemaRoot.create(dataSchema, batchAllocator), VectorSchemaRoot.create(idxSchema, batchAllocator), - ArrowArray.wrap(importFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { - (outputRoot, paramsRoot, importArray, exportArray) => + ) { (exportDataRoot, importIdxRoot) => + + Using.resources( + ArrowArray.wrap(importFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { (importArray, exportArray) => + // import into params root Data.importIntoVectorSchemaRoot( batchAllocator, importArray, - paramsRoot, + importIdxRoot, dictionaryProvider) - val idxArray = paramsRoot.getFieldVectors.asScala.head.asInstanceOf[IntVector] + + // write serialized row into sequential raw bytes + val importIdxArray = importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] + val outputDataStream = new ByteArrayOutputStream() val serializer = new UnsafeRowSerializer(numFields).newInstance() - val outputWriter = ArrowWriter.create(outputRoot) - for (idx <- 0 until paramsRoot.getRowCount) { - val internalRow = unsafeRows(idxArray.get(idx)) - Utils.tryWithResource(new ByteArrayOutputStream()) { baos => - val serializerStream = serializer.serializeStream(baos) - serializerStream.writeValue(internalRow) - serializerStream.close() - val bytes = baos.toByteArray - outputWriter.write(toUnsafeRow(Row(bytes), 
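// [editor's sketch] Round-tripping rows through UnsafeRowSerializer, as the
// serialize()/deserialize() pair here does across the FFI boundary. The
// deserializing iterator reuses a single UnsafeRow, so each value must be
// copy()-ed before being retained - exactly what the deserialize() path below
// does. `rowsToWrite` is an assumed Seq[UnsafeRow] with one field:
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.UnsafeRowSerializer

val ser = new UnsafeRowSerializer(numFields = 1).newInstance()
val bos = new ByteArrayOutputStream()
val out = ser.serializeStream(bos)
rowsToWrite.foreach(out.writeValue(_))
out.close()

val in = ser.deserializeStream(new ByteArrayInputStream(bos.toByteArray))
val back = in.asKeyValueIterator.map(_._2.asInstanceOf[UnsafeRow].copy()).toArray
in.close()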
Array(BinaryType))) + Using(serializer.serializeStream(outputDataStream)) { ser => + for (idx <- 0 until importIdxRoot.getRowCount) { + val rowIdx = importIdxArray.get(idx) + val row = rows(rowIdx) + ser.writeValue(row) } } + // export serialized data as a single row batch using root allocator + val outputWriter = ArrowWriter.create(exportDataRoot) + outputWriter.write(InternalRow(outputDataStream.toByteArray)) outputWriter.finish() - - // export to output using root allocator Data.exportVectorSchemaRoot( ArrowUtils.rootAllocator, - outputRoot, + exportDataRoot, dictionaryProvider, exportArray) + } } } } - def deserialize( - numFields: Int, - importFFIArrayPtr: Long, - exportFFIArrayPtr: Long): Array[InternalRow] = { - - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - VectorSchemaRoot.create(deserializeSchema, batchAllocator), - VectorSchemaRoot.create(offsetSchema, batchAllocator), - ArrowArray.wrap(importFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { - (paramsRoot, outputRoot, importArray, exportArray) => - Data.importIntoVectorSchemaRoot( - batchAllocator, - importArray, - paramsRoot, - dictionaryProvider) - val fieldVectors = paramsRoot.getFieldVectors.asScala - val binaryVector = fieldVectors.head.asInstanceOf[VarBinaryVector]; - val intVector = fieldVectors(1).asInstanceOf[IntVector] - - val deserializer = new UnsafeRowSerializer(numFields).newInstance() - val internalRowsArray = new Array[InternalRow](paramsRoot.getRowCount) - val outputWriter = ArrowWriter.create(outputRoot) - for (i <- 0 until paramsRoot.getRowCount) { - val bytes = binaryVector.get(i) - val offset = intVector.get(i) - val internalRow: InternalRow = Utils.tryWithResource( - new ByteArrayInputStream(bytes, offset, bytes.length - offset)) { bais => - val deserializeStream = deserializer.deserializeStream(bais) - val unsafeRow = deserializeStream.readValue().asInstanceOf[UnsafeRow] - deserializeStream.close() - val size = unsafeRow.getSizeInBytes + 4 - outputWriter.write(toUnsafeRow(Row(offset + size), Array(IntegerType))) - unsafeRow - } - internalRowsArray(i) = internalRow - } - - outputWriter.finish() + def deserialize(numFields: Int, dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { + val deserializer = new UnsafeRowSerializer(numFields).newInstance() + val inputDataStream = new ByteBufferInputStream(dataBuffer) + val rows = new ArrayBuffer[InternalRow]() - // export to output using root allocator - Data.exportVectorSchemaRoot( - ArrowUtils.rootAllocator, - outputRoot, - dictionaryProvider, - exportArray) - - internalRowsArray + Using.resource(deserializer.deserializeStream(inputDataStream)) { deser => + for (row <- deser.asKeyValueIterator.map(_._2.asInstanceOf[UnsafeRow].copy())) { + rows.append(row) } } + rows } - - def getRowNum(unsafeRows: Array[InternalRow]): Int = { - unsafeRows.length - } - -} +} \ No newline at end of file From 6f79877ebab267a66bb39f7883a1eec97fed1615 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Mon, 17 Feb 2025 11:53:37 +0800 Subject: [PATCH 09/17] complete declarative agg --- native-engine/datafusion-ext-plans/src/agg/agg.rs | 2 +- .../datafusion-ext-plans/src/agg/bloom_filter.rs | 2 +- .../datafusion-ext-plans/src/agg/collect.rs | 2 +- .../datafusion-ext-plans/src/agg/count.rs | 2 +- .../datafusion-ext-plans/src/agg/first.rs | 2 +- .../src/agg/first_ignores_null.rs | 2 +- .../datafusion-ext-plans/src/agg/maxmin.rs | 2 +- native-engine/datafusion-ext-plans/src/agg/mod.rs | 2 +- ...spark_hdaf_wrapper.rs => 
spark_udaf_wrapper.rs} | 12 ++++++------ native-engine/datafusion-ext-plans/src/agg/sum.rs | 2 +- .../spark/sql/blaze/SparkUDAFWrapperContext.scala | 14 +++++++++----- 11 files changed, 24 insertions(+), 20 deletions(-) rename native-engine/datafusion-ext-plans/src/agg/{spark_hdaf_wrapper.rs => spark_udaf_wrapper.rs} (96%) diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index 4d2166acf..39d56c791 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -24,7 +24,7 @@ use datafusion_ext_exprs::cast::TryCastExpr; use crate::agg::{ acc::AccColumnRef, avg, bloom_filter, brickhouse, collect, first, first_ignores_null, maxmin, - spark_hdaf_wrapper::SparkUDAFWrapper, sum, AggFunction, + spark_udaf_wrapper::SparkUDAFWrapper, sum, AggFunction, }; pub trait Agg: Send + Sync + Debug { diff --git a/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs b/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs index 9064e611c..04fb7c5e1 100644 --- a/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs +++ b/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs @@ -114,7 +114,7 @@ impl Agg for AggBloomFilter { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, + _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccBloomFilterColumn).unwrap(); let bloom_filter = match acc_idx { diff --git a/native-engine/datafusion-ext-plans/src/agg/collect.rs b/native-engine/datafusion-ext-plans/src/agg/collect.rs index 07e90a3b0..c498a3e30 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect.rs @@ -114,7 +114,7 @@ impl Agg for AggGenericCollect { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, + _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut C).unwrap(); idx_for_zipped! 
{ diff --git a/native-engine/datafusion-ext-plans/src/agg/count.rs b/native-engine/datafusion-ext-plans/src/agg/count.rs index 19d8a4d90..80f6a30ca 100644 --- a/native-engine/datafusion-ext-plans/src/agg/count.rs +++ b/native-engine/datafusion-ext-plans/src/agg/count.rs @@ -92,7 +92,7 @@ impl Agg for AggCount { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, + _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccCountColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/first.rs b/native-engine/datafusion-ext-plans/src/agg/first.rs index 0190c618b..5c0496a52 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first.rs @@ -90,7 +90,7 @@ impl Agg for AggFirst { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, + _batch_schema: SchemaRef, ) -> Result<()> { let partial_arg = &partial_args[0]; let accs = downcast_any!(accs, mut AccFirstColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs index e211066ee..c6a5e6be0 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs @@ -86,7 +86,7 @@ impl Agg for AggFirstIgnoresNull { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, + _batch_schema: SchemaRef, ) -> Result<()> { let partial_arg = &partial_args[0]; let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs index e1571d708..20e2889c3 100644 --- a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs +++ b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs @@ -93,7 +93,7 @@ impl Agg for AggMaxMin

{ acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, + _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); let old_heap_mem_used = accs.items_heap_mem_used(acc_idx); diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs index ba4d5c0ca..5f4075349 100644 --- a/native-engine/datafusion-ext-plans/src/agg/mod.rs +++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs @@ -25,7 +25,7 @@ pub mod count; pub mod first; pub mod first_ignores_null; pub mod maxmin; -mod spark_hdaf_wrapper; +mod spark_udaf_wrapper; pub mod sum; use std::{fmt::Debug, sync::Arc}; diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs similarity index 96% rename from native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs rename to native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index c19c2bc42..b14cbbce8 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_hdaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -21,7 +21,7 @@ use std::{ use arrow::{ array::{ - as_struct_array, make_array, Array, ArrayAccessor, ArrayRef, AsArray, BinaryArray, Datum, + as_struct_array, make_array, Array, ArrayRef, AsArray, BinaryArray, Int32Array, Int32Builder, StructArray, }, datatypes::{DataType, Field, Schema, SchemaRef}, @@ -135,7 +135,7 @@ impl Agg for SparkUDAFWrapper { }) } - fn with_new_exprs(&self, exprs: Vec>) -> Result> { + fn with_new_exprs(&self, _exprs: Vec>) -> Result> { Ok(Arc::new(Self::try_new( self.serialized.clone(), self.buffer_schema.clone(), @@ -261,7 +261,7 @@ impl AccColumn for AccUnsafeRowsColumn { } fn resize(&mut self, len: usize) { - let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()).resize( + jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()).resize( self.obj.as_obj(), len as i32, )-> ()) @@ -380,7 +380,7 @@ fn partial_update_udaf( let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data()); let mut export_ffi_batch_array = FFI_ArrowArray::new(&batch_struct_array.to_data()); - let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).update( + jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).update( accs.as_obj(), &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, &mut export_ffi_batch_array as *mut FFI_ArrowArray as i64, @@ -404,7 +404,7 @@ fn partial_merge_udaf( ]); let mut export_ffi_idx_array = FFI_ArrowArray::new(&export_ffi_idx_array.to_data()); - let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).merge( + jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).merge( accs.as_obj(), merging_accs.as_obj(), &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, @@ -423,7 +423,7 @@ fn final_merge_udaf( let struct_array = StructArray::from(vec![(int32_field(), acc_idx)]); let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); let mut import_ffi_array = FFI_ArrowArray::empty(); - let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).eval( + jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).eval( accs.as_obj(), &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, &mut import_ffi_array as *mut FFI_ArrowArray as i64, diff --git a/native-engine/datafusion-ext-plans/src/agg/sum.rs b/native-engine/datafusion-ext-plans/src/agg/sum.rs index 49686f78d..5e073553f 100644 --- 
a/native-engine/datafusion-ext-plans/src/agg/sum.rs +++ b/native-engine/datafusion-ext-plans/src/agg/sum.rs @@ -91,7 +91,7 @@ impl Agg for AggSum { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, + _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 08d5c227f..009edd414 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -16,7 +16,6 @@ package org.apache.spark.sql.blaze import scala.collection.JavaConverters._ - import org.apache.arrow.c.{ArrowArray, Data} import org.apache.arrow.vector.{IntVector, VectorSchemaRoot} import org.apache.arrow.vector.dictionary.DictionaryProvider @@ -25,13 +24,13 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.blaze.util.Using import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, JoinedRow, MutableProjection, Nondeterministic, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} -import java.nio.ByteBuffer +import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructField, StructType} +import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { @@ -74,6 +73,10 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { toUnsafe } + private val outputSchema = { + val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable))) + ArrowUtils.toArrowSchema(schema) + } private val indexSchema = { val schema = StructType(Seq(StructField("", IntegerType), StructField("", IntegerType))) @@ -85,6 +88,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { ArrowUtils.toArrowSchema(schema) } + val dataTypes: Seq[DataType] = expr.aggBufferAttributes.map(_.dataType) def initialize(numRow: Int): ArrayBuffer[InternalRow] = { @@ -161,7 +165,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(evalIndexSchema, batchAllocator), - VectorSchemaRoot.create(inputSchema, batchAllocator), + VectorSchemaRoot.create(outputSchema, batchAllocator), ArrowArray.wrap(importIdxFFIArrayPtr), ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) From 32e57b8a8df73caff08be3625966c048cd4a2a95 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Tue, 18 Feb 2025 19:44:44 +0800 Subject: [PATCH 10/17] init TypedImperativeEvaluator udaf --- .../src/agg/spark_udaf_wrapper.rs | 6 +- 
.../sql/blaze/UnsafeRowsWrapperUtils.java | 2 - .../spark/sql/blaze/NativeConverters.scala | 58 ++++---- .../sql/blaze/SparkUDAFWrapperContext.scala | 137 ++++++++++++++---- 4 files changed, 138 insertions(+), 65 deletions(-) diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index b14cbbce8..04dab6c47 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -21,8 +21,8 @@ use std::{ use arrow::{ array::{ - as_struct_array, make_array, Array, ArrayRef, AsArray, BinaryArray, - Int32Array, Int32Builder, StructArray, + as_struct_array, make_array, Array, ArrayRef, AsArray, BinaryArray, Int32Array, + Int32Builder, StructArray, }, datatypes::{DataType, Field, Schema, SchemaRef}, ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, @@ -261,7 +261,7 @@ impl AccColumn for AccUnsafeRowsColumn { } fn resize(&mut self, len: usize) { - jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()).resize( + jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()).resize( self.obj.as_obj(), len as i32, )-> ()) diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java index a0cb86047..025b774fc 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java @@ -16,9 +16,7 @@ package org.apache.spark.sql.blaze; import java.nio.ByteBuffer; - import org.apache.spark.sql.catalyst.InternalRow; - import scala.collection.mutable.ArrayBuffer; public class UnsafeRowsWrapperUtils { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 3dd03f35f..4bf745377 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -27,18 +27,8 @@ import com.google.protobuf.ByteString import org.apache.spark.SparkEnv import org.blaze.{protobuf => pb} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Abs, Acos, Add, Alias, And, Asin, Atan, AttributeReference, BitwiseAnd, BitwiseOr, BoundReference, CaseWhen, Cast, Ceil, CheckOverflow, Coalesce, Concat, ConcatWs, Contains, Cos, CreateArray, CreateNamedStruct, Divide, EndsWith, EqualTo, Exp, Expression, Floor, GetArrayItem, GetMapValue, GetStructField, GreaterThan, GreaterThanOrEqual, If, In, InSet, IsNotNull, IsNull, Length, LessThan, LessThanOrEqual, Like, Literal, Log, Log10, Log2, Lower, MakeDecimal, Md5, Multiply, Murmur3Hash, Not, NullIf, OctetLength, Or, Remainder, Sha2, ShiftLeft, ShiftRight, Signum, Sin, Sqrt, StartsWith, StringRepeat, StringSpace, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, Tan, TruncDate, Unevaluable, UnscaledValue, Upper} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.aggregate.Average -import org.apache.spark.sql.catalyst.expressions.aggregate.CollectList -import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet -import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import 
org.apache.spark.sql.catalyst.expressions.aggregate.Max -import org.apache.spark.sql.catalyst.expressions.aggregate.Min -import org.apache.spark.sql.catalyst.expressions.aggregate.Sum -import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.aggregate.First +import org.apache.spark.sql.catalyst.expressions.{Abs, Acos, Add, Alias, And, Asin, Atan, Attribute, AttributeReference, BitwiseAnd, BitwiseOr, BoundReference, CaseWhen, Cast, Ceil, CheckOverflow, Coalesce, Concat, ConcatWs, Contains, Cos, CreateArray, CreateNamedStruct, DayOfMonth, Divide, EndsWith, EqualTo, Exp, Expression, Floor, GetArrayItem, GetJsonObject, GetMapValue, GetStructField, GreaterThan, GreaterThanOrEqual, If, In, InSet, IsNotNull, IsNull, LeafExpression, Length, LessThan, LessThanOrEqual, Like, Literal, Log, Log10, Log2, Lower, MakeDecimal, Md5, Month, Multiply, Murmur3Hash, Not, NullIf, OctetLength, Or, Remainder, Sha2, ShiftLeft, ShiftRight, Signum, Sin, Sqrt, StartsWith, StringRepeat, StringSpace, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, Tan, TruncDate, Unevaluable, UnscaledValue, Upper, XxHash64, Year} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, ImperativeAggregate, Max, Min, Sum, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode import org.apache.spark.sql.catalyst.plans.FullOuter @@ -50,12 +40,6 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.DayOfMonth -import org.apache.spark.sql.catalyst.expressions.GetJsonObject -import org.apache.spark.sql.catalyst.expressions.LeafExpression -import org.apache.spark.sql.catalyst.expressions.Month -import org.apache.spark.sql.catalyst.expressions.XxHash64 -import org.apache.spark.sql.catalyst.expressions.Year import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.execution.blaze.plan.Util import org.apache.spark.sql.execution.ScalarSubquery @@ -1165,21 +1149,31 @@ object NativeConverters extends Logging { defaultValue = true) => aggBuilder.setAggFunction(pb.AggFunction.BRICKHOUSE_COMBINE_UNIQUE) aggBuilder.addChildren(convertExpr(udaf.children.head)) - // other DeclarativeAggregate - case declarative - if classOf[DeclarativeAggregate].isAssignableFrom(e.aggregateFunction.getClass) => - def fallbackToError: Expression => pb.PhysicalExprNode = { e => - throw new NotImplementedError(s"unsupported declarative expression: (${e.getClass}) $e") - } + // other udaf aggFunction + case udaf + if classOf[DeclarativeAggregate].isAssignableFrom(e.aggregateFunction.getClass) + || classOf[TypedImperativeAggregate[_]].isAssignableFrom(e.aggregateFunction.getClass) => aggBuilder.setAggFunction(pb.AggFunction.DECLARATIVE) val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]() - val bound = declarative.mapChildren { p => - val convertedChild = convertExpr(p) - val nextBindIndex = convertedChildren.size + declarative.inputAggBufferAttributes.length - convertedChildren.getOrElseUpdate( - convertedChild, - BoundReference(nextBindIndex, 
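// [editor's note] On the index arithmetic here: for a DeclarativeAggregate
// the rebuilt update projection runs over JoinedRow(aggBuffer, input), so
// converted children must be bound past the buffer slots;
// inputAggBufferAttributes always has the same arity as aggBufferAttributes,
// which is why its length serves as the offset. For an imperative aggregate
// the buffer is opaque bytes that never join the input row, so children bind
// from 0. The rule, as a hypothetical helper:
def nextBindIndex(alreadyBound: Int, numBufferSlots: Int, declarative: Boolean): Int =
  if (declarative) numBufferSlots + alreadyBound else alreadyBound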
p.dataType, p.nullable)) + val bound = udaf match { + case declarativeAggregate: DeclarativeAggregate => + declarativeAggregate.mapChildren { p => + val convertedChild = convertExpr(p) + val nextBindIndex = + convertedChildren.size + declarativeAggregate.inputAggBufferAttributes.length + convertedChildren.getOrElseUpdate( + convertedChild, + BoundReference(nextBindIndex, p.dataType, p.nullable)) + } + case imperativeAggregate: ImperativeAggregate => + imperativeAggregate.mapChildren { p => + val convertedChild = convertExpr(p) + val nextBindIndex = convertedChildren.size + convertedChildren.getOrElseUpdate( + convertedChild, + BoundReference(nextBindIndex, p.dataType, p.nullable)) + } } val paramsSchema = StructType( @@ -1189,14 +1183,14 @@ object NativeConverters extends Logging { val serialized = serializeExpression( - bound.asInstanceOf[DeclarativeAggregate with Serializable], + bound.asInstanceOf[AggregateFunction with Serializable], paramsSchema) aggBuilder.setUdaf( pb.AggUdaf .newBuilder() .setSerialized(ByteString.copyFrom(serialized)) - .setAggBufferSchema(NativeConverters.convertSchema(declarative.aggBufferSchema)) + .setAggBufferSchema(NativeConverters.convertSchema(udaf.aggBufferSchema)) .setReturnType(convertDataType(bound.dataType)) .setReturnNullable(bound.nullable)) aggBuilder.addAllChildren(convertedChildren.keys.asJava) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 009edd414..5121b57ca 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -24,18 +24,19 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.blaze.util.Using import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, JoinedRow, MutableProjection, Nondeterministic, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, GenericInternalRow, JoinedRow, Nondeterministic, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { private val (expr, javaParamsSchema) = - NativeConverters.deserializeExpression[DeclarativeAggregate]({ + NativeConverters.deserializeExpression[AggregateFunction]({ val bytes = new Array[Byte](serialized.remaining()) serialized.get(bytes) bytes @@ -52,17 +53,14 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { case _ => } - private lazy val initializer = 
UnsafeProjection.create(expr.initialValues) - - private lazy val updater = - UnsafeProjection.create(expr.updateExpressions, expr.aggBufferAttributes ++ inputAttributes) - - private lazy val merger = UnsafeProjection.create( - expr.mergeExpressions, - expr.aggBufferAttributes ++ expr.inputAggBufferAttributes) - - private lazy val evaluator = - UnsafeProjection.create(expr.evaluateExpression :: Nil, expr.aggBufferAttributes) + private val aggEvaluator = expr match { + case declarative: DeclarativeAggregate => + logInfo(s"init DeclarativeEvaluator") + new DeclarativeEvaluator(declarative, inputAttributes) + case imperative: TypedImperativeAggregate[_] => + logInfo(s"init TypedImperativeEvaluator") + new TypedImperativeEvaluator(imperative, inputAttributes) + } private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() @@ -89,24 +87,22 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } - val dataTypes: Seq[DataType] = expr.aggBufferAttributes.map(_.dataType) - - def initialize(numRow: Int): ArrayBuffer[InternalRow] = { - val rows = ArrayBuffer[InternalRow]() + def initialize(numRow: Int): ArrayBuffer[UnsafeRow] = { + val rows = ArrayBuffer[UnsafeRow]() resize(rows, numRow) rows } - def resize(rows: ArrayBuffer[InternalRow], len: Int): Unit = { + def resize(rows: ArrayBuffer[UnsafeRow], len: Int): Unit = { if (rows.length < len) { - rows.append(Range(rows.length, len).map(_ => initializer.apply(InternalRow.empty)) :_*) + rows.append(Range(rows.length, len).map(_ => aggEvaluator.initialize()): _*) } else { rows.trimEnd(rows.length - len) } } def update( - rows: ArrayBuffer[InternalRow], + rows: ArrayBuffer[UnsafeRow], importIdxFFIArrayPtr: Long, importBatchFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => @@ -128,15 +124,15 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val rowIdx = rowIdxVector.get(i) val row = rows(rowIdx) val input = paramsToUnsafe(inputRows.getRow(inputIdxVector.get(i))) - rows(rowIdx) = updater(new JoinedRow(row, input)).copy() + rows(rowIdx) = aggEvaluator.update(row, input) } } } } def merge( - rows: ArrayBuffer[InternalRow], - mergeRows: ArrayBuffer[InternalRow], + rows: ArrayBuffer[UnsafeRow], + mergeRows: ArrayBuffer[UnsafeRow], importIdxFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( @@ -152,14 +148,14 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val mergeIdx = mergeIdxVector.get(i) val row = rows(rowIdx) val mergeRow = mergeRows(mergeIdx) - rows(rowIdx) = merger(new JoinedRow(row, mergeRow)).copy() + rows(rowIdx) = aggEvaluator.merge(row, mergeRow) } } } } def eval( - rows: ArrayBuffer[InternalRow], + rows: ArrayBuffer[UnsafeRow], importIdxFFIArrayPtr: Long, exportFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => @@ -176,7 +172,7 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { val outputWriter = ArrowWriter.create(outputRoot) for (i <- 0 until idxRoot.getRowCount) { val row = rows(rowIdxVector.get(i)) - outputWriter.write(evaluator(row)) + outputWriter.write(aggEvaluator.eval(row)) } outputWriter.finish() @@ -190,3 +186,88 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } } } + +trait AggregateEvaluator extends Logging { + def initialize(): UnsafeRow + def update(mutableAggBuffer: 
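+  // (contract sketch, inferred from the two implementations that follow:
+  //  update/merge fold the right-hand row into the left-hand buffer and return
+  //  the buffer to store back, so a projection-based evaluator can hand out a
+  //  fresh copy while an object-based one may mutate its argument in place)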
UnsafeRow, row: UnsafeRow): UnsafeRow + def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow + def eval(row: UnsafeRow): InternalRow +} + +class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attribute]) + extends AggregateEvaluator { + + private lazy val initializer = UnsafeProjection.create(agg.initialValues) + + private lazy val updater = + UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes) + + private lazy val merger = UnsafeProjection.create( + agg.mergeExpressions, + agg.aggBufferAttributes ++ agg.inputAggBufferAttributes) + + private lazy val evaluator = + UnsafeProjection.create(agg.evaluateExpression :: Nil, agg.aggBufferAttributes) + + new MapDictionaryProvider() + + override def initialize(): UnsafeRow = { + initializer.apply(InternalRow.empty) + } + + override def update(mutableAggBuffer: UnsafeRow, row: UnsafeRow): UnsafeRow = { + updater(new JoinedRow(mutableAggBuffer, row)).copy() + } + + override def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow = { + merger(new JoinedRow(row1, row2)).copy() + } + + override def eval(row: UnsafeRow): InternalRow = { + evaluator(row) + } +} + +class TypedImperativeEvaluator[T](agg: TypedImperativeAggregate[T], inputAttributes: Seq[Attribute]) + extends AggregateEvaluator { + + private val bufferSchema = agg.aggBufferAttributes.map(_.dataType) + + private def getBufferObject(bufferRow: UnsafeRow): T = { + agg.deserialize(bufferRow.getBytes) + } + + + override def initialize(): UnsafeRow = { +// val byteBuffer = agg.serialize(agg.createAggregationBuffer()) +// val unsafeRow = new UnsafeRow(bufferSchema.length) +// unsafeRow.pointTo(byteBuffer, byteBuffer.length) +// logInfo(s"bufferSchema: $bufferSchema") + val buffer = new SpecificInternalRow(bufferSchema) + agg.initialize(buffer) +// logInfo(s"buffer $buffer") + val writer = new UnsafeRowWriter(1) + writer.write(0, buffer.getBinary(0)) + val unsafeRow = writer.getRow +// logInfo(s"init unsaferow $unsafeRow") + unsafeRow + } + + override def update(buffer: UnsafeRow, row: UnsafeRow): UnsafeRow = { + val bufferObject = agg.update(getBufferObject(buffer), row) + val byteBuffer = agg.serialize(bufferObject).clone() + buffer.pointTo(byteBuffer, byteBuffer.length) + buffer + } + + override def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow = { + val bufferObject = agg.merge(getBufferObject(row1), getBufferObject(row2)) + val byteBuffer = agg.serialize(bufferObject).clone() + row1.pointTo(byteBuffer, byteBuffer.length) + row1 + } + + override def eval(row: UnsafeRow): InternalRow = { + InternalRow(agg.eval(getBufferObject(row))) + } +} From e3ba97bded90e722e23e5f0afd14c9cec4d6f9da Mon Sep 17 00:00:00 2001 From: guoying06 Date: Wed, 19 Feb 2025 06:53:34 +0000 Subject: [PATCH 11/17] Dev udaf new merge NativeConverters --- .../blaze-jni-bridge/src/jni_bridge.rs | 49 ++-- native-engine/blaze-serde/proto/blaze.proto | 5 +- native-engine/blaze-serde/src/from_proto.rs | 3 - .../datafusion-ext-plans/src/agg/agg.rs | 2 - .../src/agg/spark_udaf_wrapper.rs | 75 ++--- .../spark/sql/blaze/NativeConverters.scala | 9 +- .../sql/blaze/SparkUDAFWrapperContext.scala | 259 ++++++++++++------ .../sql/blaze/SparkUDFWrapperContext.scala | 2 +- .../sql/blaze/SparkUDTFWrapperContext.scala | 2 +- .../spark/sql/blaze/UnsafeRowsWrapper.scala | 4 +- 10 files changed, 248 insertions(+), 162 deletions(-) diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index 7132401f3..698a1ce85 100644 --- 
a/native-engine/blaze-jni-bridge/src/jni_bridge.rs
+++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs
@@ -415,7 +415,6 @@ pub struct JavaClasses<'a> {
     pub cBlazeConf: BlazeConf<'a>,
     pub cBlazeRssPartitionWriterBase: BlazeRssPartitionWriterBase<'a>,
     pub cBlazeCallNativeWrapper: BlazeCallNativeWrapper<'a>,
-    pub cBlazeUnsafeRowsWrapperUtils: BlazeUnsafeRowsWrapperUtils<'a>,
     pub cBlazeOnHeapSpillManager: BlazeOnHeapSpillManager<'a>,
     pub cBlazeNativeParquetSinkUtils: BlazeNativeParquetSinkUtils<'a>,
     pub cBlazeBlockObject: BlazeBlockObject<'a>,
@@ -476,7 +475,6 @@ impl JavaClasses<'static> {
             cSparkUDAFWrapperContext: SparkUDAFWrapperContext::new(env)?,
             cSparkUDTFWrapperContext: SparkUDTFWrapperContext::new(env)?,
             cBlazeConf: BlazeConf::new(env)?,
-            cBlazeUnsafeRowsWrapperUtils: BlazeUnsafeRowsWrapperUtils::new(env)?,
             cBlazeRssPartitionWriterBase: BlazeRssPartitionWriterBase::new(env)?,
             cBlazeCallNativeWrapper: BlazeCallNativeWrapper::new(env)?,
             cBlazeOnHeapSpillManager: BlazeOnHeapSpillManager::new(env)?,
@@ -1187,6 +1185,10 @@ pub struct SparkUDAFWrapperContext<'a> {
     pub method_merge_ret: ReturnType,
     pub method_eval: JMethodID,
     pub method_eval_ret: ReturnType,
+    pub method_serializeRows: JMethodID,
+    pub method_serializeRows_ret: ReturnType,
+    pub method_deserializeRows: JMethodID,
+    pub method_deserializeRows_ret: ReturnType,
 }
 impl<'a> SparkUDAFWrapperContext<'a> {
     pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/SparkUDAFWrapperContext";
@@ -1226,6 +1228,18 @@ impl<'a> SparkUDAFWrapperContext<'a> {
                 "(Lscala/collection/mutable/ArrayBuffer;JJ)V",
             )?,
             method_eval_ret: ReturnType::Primitive(Primitive::Void),
+            method_serializeRows: env.get_method_id(
+                class,
+                "serializeRows",
+                "(Lscala/collection/mutable/ArrayBuffer;JJ)V",
+            )?,
+            method_serializeRows_ret: ReturnType::Primitive(Primitive::Void),
+            method_deserializeRows: env.get_method_id(
+                class,
+                "deserializeRows",
+                "(Ljava/nio/ByteBuffer;)Lscala/collection/mutable/ArrayBuffer;",
+            )?,
+            method_deserializeRows_ret: ReturnType::Object,
         })
     }
 }
@@ -1255,37 +1269,6 @@ impl<'a> SparkUDTFWrapperContext<'a> {
     }
 }

-#[allow(non_snake_case)]
-pub struct BlazeUnsafeRowsWrapperUtils<'a> {
-    pub class: JClass<'a>,
-    pub method_serialize: JStaticMethodID,
-    pub method_serialize_ret: ReturnType,
-    pub method_deserialize: JStaticMethodID,
-    pub method_deserialize_ret: ReturnType,
-}
-impl<'a> BlazeUnsafeRowsWrapperUtils<'a> {
-    pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils";
-
-    pub fn new(env: &JNIEnv<'a>) -> JniResult<BlazeUnsafeRowsWrapperUtils<'a>> {
-        let class = get_global_jclass(env, Self::SIG_TYPE)?;
-        Ok(BlazeUnsafeRowsWrapperUtils {
-            class,
-            method_serialize: env.get_static_method_id(
-                class,
-                "serialize",
-                "(Lscala/collection/mutable/ArrayBuffer;IJJ)V",
-            )?,
-            method_serialize_ret: ReturnType::Primitive(Primitive::Void),
-            method_deserialize: env.get_static_method_id(
-                class,
-                "deserialize",
-                "(ILjava/nio/ByteBuffer;)Lscala/collection/mutable/ArrayBuffer;",
-            )?,
-            method_deserialize_ret: ReturnType::Object,
-        })
-    }
-}
-
 #[allow(non_snake_case)]
 pub struct BlazeCallNativeWrapper<'a> {
     pub class: JClass<'a>,
diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto
index 4d9efb802..2ef736c0d 100644
--- a/native-engine/blaze-serde/proto/blaze.proto
+++ b/native-engine/blaze-serde/proto/blaze.proto
@@ -152,9 +152,8 @@ message PhysicalAggExprNode {
 message AggUdaf {
   bytes serialized = 1;
-  Schema agg_buffer_schema = 2;
-  ArrowType return_type = 3;
-  bool return_nullable = 4;
+  ArrowType return_type = 2;
+  bool return_nullable = 3;
 }
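+  // the agg buffer schema no longer travels in the plan: it is presumably
+  // recoverable on the JVM side from the deserialized AggregateFunction
+  // itself, so `serialized` plus the return type is all the native side needs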
 message PhysicalIsNull {

diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs
index dcc6d133d..7c3b4b55e 100644
--- a/native-engine/blaze-serde/src/from_proto.rs
+++ b/native-engine/blaze-serde/src/from_proto.rs
@@ -444,11 +444,8 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                         AggFunction::Declarative => {
                             let udaf = agg_node.udaf.as_ref().unwrap();
                             let serialized = udaf.serialized.clone();
-                            let agg_buffer_schema =
-                                Arc::new(convert_required!(udaf.agg_buffer_schema)?);
                             create_declarative_agg(
                                 serialized,
-                                agg_buffer_schema,
                                 convert_required!(udaf.return_type)?,
                                 agg_children_exprs,
                             )?
diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs
index 39d56c791..f11f0e6a5 100644
--- a/native-engine/datafusion-ext-plans/src/agg/agg.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs
@@ -312,13 +312,11 @@ pub fn create_agg(
 pub fn create_declarative_agg(
     serialized: Vec<u8>,
-    buffer_schema: SchemaRef,
     return_type: DataType,
     children: Vec<Arc<dyn PhysicalExpr>>,
 ) -> Result<Arc<dyn Agg>> {
     Ok(Arc::new(SparkUDAFWrapper::try_new(
         serialized,
-        buffer_schema,
         return_type,
         children,
     )?))
diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs
index 04dab6c47..e9bcf8eac 100644
--- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs
@@ -29,9 +29,7 @@ use arrow::{
     record_batch::{RecordBatch, RecordBatchOptions},
 };
 use arrow_schema::FieldRef;
-use blaze_jni_bridge::{
-    jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object,
-};
+use blaze_jni_bridge::{jni_call, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object};
 use datafusion::{common::Result, physical_expr::PhysicalExpr};
 use datafusion_ext_commons::{
     downcast_any,
@@ -51,7 +49,6 @@ use crate::{
 pub struct SparkUDAFWrapper {
     serialized: Vec<u8>,
-    pub buffer_schema: SchemaRef,
     pub return_type: DataType,
     child: Vec<Arc<dyn PhysicalExpr>>,
     import_schema: SchemaRef,
@@ -62,13 +59,11 @@ impl SparkUDAFWrapper {
     pub fn try_new(
         serialized: Vec<u8>,
-        buffer_schema: SchemaRef,
         return_type: DataType,
         child: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Self> {
         Ok(Self {
             serialized,
-            buffer_schema,
             return_type: return_type.clone(),
             child,
             import_schema: Arc::new(Schema::new(vec![Field::new("", return_type, true)])),
@@ -130,7 +125,6 @@ impl Agg for SparkUDAFWrapper {
         Box::new(AccUnsafeRowsColumn {
             obj,
             jcontext,
-            num_fields: self.buffer_schema.fields.len(),
             num_rows,
         })
     }

     fn with_new_exprs(&self, _exprs: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn Agg>> {
         Ok(Arc::new(Self::try_new(
             self.serialized.clone(),
-            self.buffer_schema.clone(),
             self.return_type.clone(),
             self.child.clone(),
         )?))
@@ -180,7 +173,7 @@
         let params_batch = RecordBatch::try_new_with_options(
             params_schema.clone(),
             params.clone(),
-            &RecordBatchOptions::new().with_row_count(Some(params[0].len())),
+            &RecordBatchOptions::new().with_row_count(Some(partial_arg_idx.len())),
         )?;

         let max_len = std::cmp::max(acc_idx.len(), partial_arg_idx.len());
@@ -251,7 +244,6 @@
 struct AccUnsafeRowsColumn {
     obj: GlobalRef,
     jcontext: GlobalRef,
-    num_fields: usize,
     num_rows: usize,
 }

@@ -280,22 +272,19 @@ impl AccColumn for AccUnsafeRowsColumn {
     }

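+    // (mechanism sketch, as I read this patch: the accumulator column lives on
+    //  the JVM as an ArrayBuffer of agg-buffer rows; freezing exports the
+    //  selected indices through the Arrow FFI, asks the context to serialize
+    //  those rows, and receives the bytes back as a single-row Binary column)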
     fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec<u8>]) -> Result<()> {
-        let field = Arc::new(Field::new("", DataType::Int32, false));
         let idx_array: ArrayRef = Arc::new(idx.to_int32_array());
-        let struct_array = StructArray::from(vec![(field, idx_array)]);
+        let struct_array =
+            StructArray::from(RecordBatch::try_new(index_schema(), vec![idx_array])?);
         let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data());
         let mut import_ffi_array = FFI_ArrowArray::empty();
-        jni_call_static!(
-            BlazeUnsafeRowsWrapperUtils.serialize(
+        jni_call!(
+            SparkUDAFWrapperContext(self.jcontext.as_obj()).serializeRows(
                 self.obj.as_obj(),
-                self.num_fields as i32,
                 &mut export_ffi_array as *mut FFI_ArrowArray as i64,
                 &mut import_ffi_array as *mut FFI_ArrowArray as i64,)
             -> ())?;

         // import output from context
-        let field = Field::new("", DataType::Binary, false);
-        let schema = Schema::new(vec![field]);
-        let import_ffi_schema = FFI_ArrowSchema::try_from(schema)?;
+        let import_ffi_schema = FFI_ArrowSchema::try_from(serialized_row_schema().as_ref())?;
         let import_struct_array =
             make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? });
         let result_struct = import_struct_array.as_struct();
@@ -329,11 +318,8 @@
         }

         let data_buffer = jni_new_direct_byte_buffer!(data)?;
-        let rows = jni_call_static!(
-            BlazeUnsafeRowsWrapperUtils.deserialize(
-                self.num_fields as i32,
-                data_buffer.as_obj())
-            -> JObject)?;
+        let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj())
+            .deserializeRows(data_buffer.as_obj()) -> JObject)?;
         self.obj = jni_new_global_ref!(rows.as_obj())?;
         self.num_rows = array.len();
         Ok(())
     }
@@ -362,6 +348,27 @@ fn binary_field() -> FieldRef {
         .clone()
 }

+fn index_schema() -> SchemaRef {
+    static SCHEMA: OnceCell<SchemaRef> = OnceCell::new();
+    SCHEMA
+        .get_or_init(|| Arc::new(Schema::new(vec![int32_field()])))
+        .clone()
+}
+
+fn index_tuple_schema() -> SchemaRef {
+    static SCHEMA: OnceCell<SchemaRef> = OnceCell::new();
+    SCHEMA
+        .get_or_init(|| Arc::new(Schema::new(vec![int32_field(), int32_field()])))
+        .clone()
+}
+
+fn serialized_row_schema() -> SchemaRef {
+    static SCHEMA: OnceCell<SchemaRef> = OnceCell::new();
+    SCHEMA
+        .get_or_init(|| Arc::new(Schema::new(vec![binary_field()])))
+        .clone()
+}
+
 fn partial_update_udaf(
     jcontext: GlobalRef,
     params_batch: RecordBatch,
@@ -371,10 +378,10 @@
 ) -> Result<()> {
     let acc_idx: ArrayRef = Arc::new(acc_idx);
     let partial_arg_idx: ArrayRef = Arc::new(partial_arg_idx);
-    let idx_struct_array = StructArray::from(vec![
-        (int32_field(), acc_idx),
-        (int32_field(), partial_arg_idx),
-    ]);
+    let idx_struct_array = StructArray::from(RecordBatch::try_new(
+        index_tuple_schema(),
+        vec![acc_idx, partial_arg_idx],
+    )?);
     let batch_struct_array = StructArray::from(params_batch);
     let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data());
@@ -398,11 +405,11 @@
 ) -> Result<()> {
     let acc_idx: ArrayRef = Arc::new(acc_idx);
     let merging_acc_idx: ArrayRef = Arc::new(merging_acc_idx);
-    let export_ffi_idx_array = StructArray::from(vec![
-        (int32_field(), acc_idx),
-        (int32_field(), merging_acc_idx),
-    ]);
-    let mut export_ffi_idx_array = FFI_ArrowArray::new(&export_ffi_idx_array.to_data());
+    let idx_struct_array = StructArray::from(RecordBatch::try_new(
+        index_tuple_schema(),
+        vec![acc_idx, merging_acc_idx],
+    )?);
+    let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data());

     jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).merge(
         accs.as_obj(),
@@ -420,8 
+427,8 @@ fn final_merge_udaf( result_schema: SchemaRef, ) -> Result { let acc_idx: ArrayRef = Arc::new(Int32Array::from(acc_idx.to_int32_array())); - let struct_array = StructArray::from(vec![(int32_field(), acc_idx)]); - let mut export_ffi_idx_array = FFI_ArrowArray::new(&struct_array.to_data()); + let idx_struct_array = StructArray::from(RecordBatch::try_new(index_schema(), vec![acc_idx])?); + let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data()); let mut import_ffi_array = FFI_ArrowArray::empty(); jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).eval( accs.as_obj(), diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 4bf745377..7b3d933c0 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -1190,7 +1190,6 @@ object NativeConverters extends Logging { pb.AggUdaf .newBuilder() .setSerialized(ByteString.copyFrom(serialized)) - .setAggBufferSchema(NativeConverters.convertSchema(udaf.aggBufferSchema)) .setReturnType(convertDataType(bound.dataType)) .setReturnNullable(bound.nullable)) aggBuilder.addAllChildren(convertedChildren.keys.asJava) @@ -1300,13 +1299,13 @@ object NativeConverters extends Logging { } } - def deserializeExpression[E <: Expression]( - serialized: Array[Byte]): (E with Serializable, StructType) = { + def deserializeExpression[E <: Expression, S <: Serializable]( + serialized: Array[Byte]): (E with Serializable, S) = { Utils.tryWithResource(new ByteArrayInputStream(serialized)) { bis => Utils.tryWithResource(new ObjectInputStream(bis)) { ois => val expr = ois.readObject().asInstanceOf[E with Serializable] - val paramsSchema = ois.readObject().asInstanceOf[StructType] - (expr, paramsSchema) + val payload = ois.readObject().asInstanceOf[S with Serializable] + (expr, payload) } } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 5121b57ca..c86a3d24c 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -15,28 +15,41 @@ */ package org.apache.spark.sql.blaze +import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import org.apache.arrow.c.{ArrowArray, Data} -import org.apache.arrow.vector.{IntVector, VectorSchemaRoot} +import scala.collection.mutable.ArrayBuffer +import org.apache.arrow.c.ArrowArray +import org.apache.arrow.c.Data +import org.apache.arrow.vector.IntVector +import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.blaze.util.Using import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, ImperativeAggregate, TypedImperativeAggregate} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, GenericInternalRow, JoinedRow, Nondeterministic, 
SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.catalyst.expressions.Nondeterministic +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper -import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType, StructField, StructType} -import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils +import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter +import org.apache.spark.sql.execution.UnsafeRowSerializer +import org.apache.spark.sql.types.{BinaryType, IntegerType, ObjectType, StructField, StructType} +import org.apache.spark.util.ByteBufferInputStream case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { - private val (expr, javaParamsSchema) = - NativeConverters.deserializeExpression[AggregateFunction]({ + import org.apache.spark.sql.blaze.SparkUDAFWrapperContext._ + + private val (expr, List(javaParamsSchema, javaBufferSchema)) = + NativeConverters.deserializeExpression[AggregateFunction, List[StructType]]({ val bytes = new Array[Byte](serialized.remaining()) serialized.get(bytes) bytes @@ -46,6 +59,11 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { AttributeReference(field.name, field.dataType, field.nullable)() } + private val outputSchema = { + val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable))) + ArrowUtils.toArrowSchema(schema) + } + // initialize all nondeterministic children exprs expr.foreach { case nondeterministic: Nondeterministic => @@ -71,29 +89,13 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { toUnsafe } - private val outputSchema = { - val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable))) - ArrowUtils.toArrowSchema(schema) - } - - private val indexSchema = { - val schema = StructType(Seq(StructField("", IntegerType), StructField("", IntegerType))) - ArrowUtils.toArrowSchema(schema) - } - - private val evalIndexSchema = { - val schema = StructType(Seq(StructField("", IntegerType))) - ArrowUtils.toArrowSchema(schema) - } - - - def initialize(numRow: Int): ArrayBuffer[UnsafeRow] = { - val rows = ArrayBuffer[UnsafeRow]() + def initialize(numRow: Int): ArrayBuffer[InternalRow] = { + val rows = ArrayBuffer[InternalRow]() resize(rows, numRow) rows } - def resize(rows: ArrayBuffer[UnsafeRow], len: Int): Unit = { + def resize(rows: ArrayBuffer[InternalRow], len: Int): Unit = { if (rows.length < len) { rows.append(Range(rows.length, len).map(_ => aggEvaluator.initialize()): _*) } else { @@ -102,13 +104,13 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } def update( - rows: ArrayBuffer[UnsafeRow], + rows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long, 
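+      // (FFI protocol sketch: the native side ships two Arrow arrays, an index
+      //  batch of (accumulator row, input row) pairs and the input batch
+      //  itself, and each pair drives one aggEvaluator.update call below)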
importBatchFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( VectorSchemaRoot.create(inputSchema, batchAllocator), - VectorSchemaRoot.create(indexSchema, batchAllocator), + VectorSchemaRoot.create(indexTupleSchema, batchAllocator), ArrowArray.wrap(importBatchFFIArrayPtr), ArrowArray.wrap(importIdxFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => // import into params root @@ -131,12 +133,12 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } def merge( - rows: ArrayBuffer[UnsafeRow], - mergeRows: ArrayBuffer[UnsafeRow], + rows: ArrayBuffer[InternalRow], + mergeRows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( - VectorSchemaRoot.create(indexSchema, batchAllocator), + VectorSchemaRoot.create(indexTupleSchema, batchAllocator), ArrowArray.wrap(importIdxFFIArrayPtr)) { (idxRoot, idxArray) => Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala @@ -155,12 +157,12 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } def eval( - rows: ArrayBuffer[UnsafeRow], + rows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long, exportFFIArrayPtr: Long): Unit = { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( - VectorSchemaRoot.create(evalIndexSchema, batchAllocator), + VectorSchemaRoot.create(indexSchema, batchAllocator), VectorSchemaRoot.create(outputSchema, batchAllocator), ArrowArray.wrap(importIdxFFIArrayPtr), ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => @@ -185,89 +187,190 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } } } + + def serializeRows( + rows: ArrayBuffer[InternalRow], + importFFIArrayPtr: Long, + exportFFIArrayPtr: Long): Unit = { + + Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => + Using.resources( + VectorSchemaRoot.create(serializedRowSchema, batchAllocator), + VectorSchemaRoot.create(indexSchema, batchAllocator), + ) { (exportDataRoot, importIdxRoot) => + + Using.resources( + ArrowArray.wrap(importFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { (importArray, exportArray) => + + // import into params root + Data.importIntoVectorSchemaRoot( + batchAllocator, + importArray, + importIdxRoot, + dictionaryProvider) + + // write serialized row into sequential raw bytes + val importIdxArray = importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] + val rowsIter = (0 until importIdxRoot.getRowCount).map(i => rows(importIdxArray.get(i))) + val serializedBytes = aggEvaluator.serializeRows(rowsIter) + + // export serialized data as a single row batch using root allocator + val outputWriter = ArrowWriter.create(exportDataRoot) + outputWriter.write(InternalRow(serializedBytes)) + outputWriter.finish() + Data.exportVectorSchemaRoot( + ArrowUtils.rootAllocator, + exportDataRoot, + dictionaryProvider, + exportArray) + } + } + } + } + + def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { + aggEvaluator.deserializeRows(dataBuffer) + } +} + +object SparkUDAFWrapperContext { + private val indexTupleSchema = { + val schema = StructType(Seq(StructField("", IntegerType), StructField("", IntegerType))) + ArrowUtils.toArrowSchema(schema) + } + + private val indexSchema 
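+  // (these Arrow schemas presumably mirror index_schema / index_tuple_schema /
+  //  serialized_row_schema on the native side; the two ends have to be kept in
+  //  sync by hand for the FFI imports above to succeed at runtime)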
= { + val schema = StructType(Seq(StructField("", IntegerType, nullable = false))) + ArrowUtils.toArrowSchema(schema) + } + + private val serializedRowSchema = { + val schema = StructType(Seq(StructField("", BinaryType, nullable = false))) + ArrowUtils.toArrowSchema(schema) + } } trait AggregateEvaluator extends Logging { - def initialize(): UnsafeRow - def update(mutableAggBuffer: UnsafeRow, row: UnsafeRow): UnsafeRow - def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow - def eval(row: UnsafeRow): InternalRow + def initialize(): InternalRow + def update(mutableAggBuffer: InternalRow, row: InternalRow): InternalRow + def merge(row1: InternalRow, row2: InternalRow): InternalRow + def eval(row: InternalRow): InternalRow + def serializeRows(rows: Seq[InternalRow]): Array[Byte] + def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] } class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attribute]) extends AggregateEvaluator { - private lazy val initializer = UnsafeProjection.create(agg.initialValues) + private val initializer = UnsafeProjection.create(agg.initialValues) - private lazy val updater = + private val updater = UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes) - private lazy val merger = UnsafeProjection.create( + private val merger = UnsafeProjection.create( agg.mergeExpressions, agg.aggBufferAttributes ++ agg.inputAggBufferAttributes) - private lazy val evaluator = + private val evaluator = UnsafeProjection.create(agg.evaluateExpression :: Nil, agg.aggBufferAttributes) - new MapDictionaryProvider() - override def initialize(): UnsafeRow = { + override def initialize(): InternalRow = { initializer.apply(InternalRow.empty) } - override def update(mutableAggBuffer: UnsafeRow, row: UnsafeRow): UnsafeRow = { + override def update(mutableAggBuffer: InternalRow, row: InternalRow): InternalRow = { updater(new JoinedRow(mutableAggBuffer, row)).copy() } - override def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow = { + override def merge(row1: InternalRow, row2: InternalRow): InternalRow = { merger(new JoinedRow(row1, row2)).copy() } - override def eval(row: UnsafeRow): InternalRow = { + override def eval(row: InternalRow): InternalRow = { evaluator(row) } + + override def serializeRows(rows: Seq[InternalRow]): Array[Byte] = { + val numFields = agg.aggBufferSchema.length + val outputDataStream = new ByteArrayOutputStream() + val serializer = new UnsafeRowSerializer(numFields).newInstance() + + Using(serializer.serializeStream(outputDataStream)) { ser => + for (row <- rows) { + ser.writeValue(row) + } + } + outputDataStream.toByteArray + } + + override def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { + val numFields = agg.aggBufferSchema.length + val deserializer = new UnsafeRowSerializer(numFields).newInstance() + val inputDataStream = new ByteBufferInputStream(dataBuffer) + val rows = new ArrayBuffer[InternalRow]() + + Using.resource(deserializer.deserializeStream(inputDataStream)) { deser => + for (row <- deser.asKeyValueIterator.map(_._2.asInstanceOf[UnsafeRow].copy())) { + rows.append(row) + } + } + rows + } } class TypedImperativeEvaluator[T](agg: TypedImperativeAggregate[T], inputAttributes: Seq[Attribute]) extends AggregateEvaluator { private val bufferSchema = agg.aggBufferAttributes.map(_.dataType) + private val anyObjectType = ObjectType(classOf[AnyRef]) - private def getBufferObject(bufferRow: UnsafeRow): T = { - agg.deserialize(bufferRow.getBytes) + private def 
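+  // (buffer representation sketch: the aggregation state object stays boxed in
+  //  slot 0 of a one-field InternalRow, so reading it back is a plain
+  //  get(0, ObjectType) cast rather than a per-call deserialization)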
getBufferObject(buffer: InternalRow): T = { + buffer.get(0, anyObjectType).asInstanceOf[T] } - - - override def initialize(): UnsafeRow = { -// val byteBuffer = agg.serialize(agg.createAggregationBuffer()) -// val unsafeRow = new UnsafeRow(bufferSchema.length) -// unsafeRow.pointTo(byteBuffer, byteBuffer.length) -// logInfo(s"bufferSchema: $bufferSchema") - val buffer = new SpecificInternalRow(bufferSchema) - agg.initialize(buffer) -// logInfo(s"buffer $buffer") - val writer = new UnsafeRowWriter(1) - writer.write(0, buffer.getBinary(0)) - val unsafeRow = writer.getRow -// logInfo(s"init unsaferow $unsafeRow") - unsafeRow + override def initialize(): InternalRow = { + val row = InternalRow(bufferSchema) + agg.initialize(row) + row } - override def update(buffer: UnsafeRow, row: UnsafeRow): UnsafeRow = { - val bufferObject = agg.update(getBufferObject(buffer), row) - val byteBuffer = agg.serialize(bufferObject).clone() - buffer.pointTo(byteBuffer, byteBuffer.length) + override def update(buffer: InternalRow, row: InternalRow): InternalRow = { + agg.update(buffer, row) buffer } - override def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow = { - val bufferObject = agg.merge(getBufferObject(row1), getBufferObject(row2)) - val byteBuffer = agg.serialize(bufferObject).clone() - row1.pointTo(byteBuffer, byteBuffer.length) + override def merge(row1: InternalRow, row2: InternalRow): InternalRow = { + val Object1 = getBufferObject(row1) + val Object2 = getBufferObject(row2) + row1.update(0, agg.merge(Object1, Object2)) row1 } - override def eval(row: UnsafeRow): InternalRow = { - InternalRow(agg.eval(getBufferObject(row))) + override def eval(row: InternalRow): InternalRow = { + InternalRow(agg.eval(row)) + } + + override def serializeRows(rows: Seq[InternalRow]): Array[Byte] = { + val outputStream = new ByteArrayOutputStream() + val dataOut = new DataOutputStream(outputStream) + for (row <- rows) { + val byteBuffer = agg.serialize(row.get(0, anyObjectType).asInstanceOf[T]) + dataOut.writeInt(byteBuffer.length) + outputStream.write(byteBuffer) + } + outputStream.toByteArray + } + + override def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { + val rows = ArrayBuffer[InternalRow]() + while (dataBuffer.hasRemaining) { + val length = dataBuffer.getInt() + val byteBuffer = new Array[Byte](length) + dataBuffer.get(byteBuffer) + val row = InternalRow(agg.deserialize(byteBuffer)) + rows.append(row) + } + rows } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala index 339386f0e..60ecfd103 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType case class SparkUDFWrapperContext(serialized: ByteBuffer) extends Logging { - private val (expr, javaParamsSchema) = NativeConverters.deserializeExpression[Expression]({ + private val (expr, javaParamsSchema) = NativeConverters.deserializeExpression[Expression, StructType]({ val bytes = new Array[Byte](serialized.remaining()) serialized.get(bytes) bytes diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala index 
14f5032c7..dc387ddef 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType case class SparkUDTFWrapperContext(serialized: ByteBuffer) extends Logging { - private val (expr, javaParamsSchema) = NativeConverters.deserializeExpression[Generator]({ + private val (expr, javaParamsSchema) = NativeConverters.deserializeExpression[Generator, StructType]({ val bytes = new Array[Byte](serialized.remaining()) serialized.get(bytes) bytes diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala index 448b0cbc9..79fc2d37a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala @@ -46,7 +46,7 @@ object UnsafeRowsWrapper extends Logging { ArrowUtils.toArrowSchema(schema) } - private val dataSchema = { + private val serializedRowSchema = { val schema = StructType(Seq(StructField("", BinaryType, nullable = false))) ArrowUtils.toArrowSchema(schema) } @@ -84,7 +84,7 @@ object UnsafeRowsWrapper extends Logging { Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( - VectorSchemaRoot.create(dataSchema, batchAllocator), + VectorSchemaRoot.create(serializedRowSchema, batchAllocator), VectorSchemaRoot.create(idxSchema, batchAllocator), ) { (exportDataRoot, importIdxRoot) => From 07e0107f13b188789eac0175e865498bd5599896 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Thu, 20 Feb 2025 10:29:58 +0800 Subject: [PATCH 12/17] add spill --- .../blaze-jni-bridge/src/jni_bridge.rs | 8 +++ .../src/agg/spark_udaf_wrapper.rs | 51 +++++++++++++++++-- .../sql/blaze/SparkUDAFWrapperContext.scala | 18 +++++++ 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index 698a1ce85..e8c857602 100644 --- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -1189,6 +1189,8 @@ pub struct SparkUDAFWrapperContext<'a> { pub method_serializeRows_ret: ReturnType, pub method_deserializeRows: JMethodID, pub method_deserializeRows_ret: ReturnType, + pub method_memUsed: JMethodID, + pub method_memUsed_ret: ReturnType, } impl<'a> SparkUDAFWrapperContext<'a> { pub const SIG_TYPE: &'static str = "org/apache/spark/sql/blaze/SparkUDAFWrapperContext"; @@ -1240,6 +1242,12 @@ impl<'a> SparkUDAFWrapperContext<'a> { "(Ljava/nio/ByteBuffer;)Lscala/collection/mutable/ArrayBuffer;", )?, method_deserializeRows_ret: ReturnType::Object, + method_memUsed: env.get_method_id( + class, + "memUsed", + "(Lscala/collection/mutable/ArrayBuffer;)I", + )?, + method_memUsed_ret: ReturnType::Primitive(Primitive::Int), }) } } diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index e9bcf8eac..63967f852 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -18,6 +18,7 @@ use std::{ io::Cursor, sync::Arc, }; +use std::io::Write; use arrow::{ array::{ @@ -37,6 +38,7 @@ use 
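+// (why this commit adds memUsed, as I understand it: spill decisions are made
+// by the native memory tracker, but the accumulator rows live on the JVM, so
+// the context reports an estimated byte size back across JNI)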
datafusion_ext_commons::{ }; use jni::objects::{GlobalRef, JObject}; use once_cell::sync::OnceCell; +use datafusion_ext_commons::io::{read_bytes_into_vec, read_bytes_slice}; use crate::{ agg::{ @@ -268,7 +270,10 @@ impl AccColumn for AccUnsafeRowsColumn { } fn mem_used(&self) -> usize { - 0 + jni_call!( + SparkUDAFWrapperContext(self.jcontext.as_obj()).memUsed( + self.obj.as_obj()) + -> i32).unwrap() as usize } fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec]) -> Result<()> { @@ -326,12 +331,52 @@ impl AccColumn for AccUnsafeRowsColumn { } fn spill(&self, idx: IdxSelection<'_>, buf: &mut SpillCompressedWriter) -> Result<()> { - unimplemented!() + log::info!("start spill!"); + let idx_array: ArrayRef = Arc::new(idx.to_int32_array()); + let struct_array = + StructArray::from(RecordBatch::try_new(index_schema(), vec![idx_array])?); + let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data()); + let mut import_ffi_array = FFI_ArrowArray::empty(); + jni_call!( + SparkUDAFWrapperContext(self.jcontext.as_obj()).serializeRows( + self.obj.as_obj(), + &mut export_ffi_array as *mut FFI_ArrowArray as i64, + &mut import_ffi_array as *mut FFI_ArrowArray as i64,) + -> ())?; + // import output from context + let import_ffi_schema = FFI_ArrowSchema::try_from(serialized_row_schema().as_ref())?; + let import_struct_array = + make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? }); + let result_struct = import_struct_array.as_struct(); + + let binary_array = downcast_any!(result_struct.column(0), BinaryArray)?; + let data = binary_array.value(0); + buf.write(data)?; + log::info!("end spill!"); + Ok(()) } fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> { - unimplemented!() + log::info!("start unspill!"); + let mut data = vec![]; + let mut data_len = 0; + for i in 0.. 
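+            // (framing note: each record is a 4-byte big-endian length followed
+            //  by its payload, matching DataOutputStream.writeInt on the JVM
+            //  side; this first cut scans an empty local buffer for the
+            //  lengths, which the "fix spill" commit below rewrites to read
+            //  from the spill reader instead)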
num_rows { + let bytes_len = i32::from_be_bytes(data[data_len..][..4].try_into().unwrap()) as usize; + data_len += bytes_len + 4; + } + let mut data = vec![]; + read_bytes_into_vec(r, &mut data, data_len)?; + + let data_buffer = jni_new_direct_byte_buffer!(data)?; + let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()) + .deserializeRows(data_buffer.as_obj()) -> JObject)?; + self.obj = jni_new_global_ref!(rows.as_obj())?; + self.num_rows = num_rows; + + log::info!("start unspill!"); + Ok(()) } + } fn int32_field() -> FieldRef { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index c86a3d24c..8d3f21549 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -232,6 +232,10 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { aggEvaluator.deserializeRows(dataBuffer) } + + def memUsed(rows: ArrayBuffer[InternalRow]): Int = { + aggEvaluator.memUsed(rows) + } } object SparkUDAFWrapperContext { @@ -258,6 +262,8 @@ trait AggregateEvaluator extends Logging { def eval(row: InternalRow): InternalRow def serializeRows(rows: Seq[InternalRow]): Array[Byte] def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] + + def memUsed(rows: Seq[InternalRow]): Int } class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attribute]) @@ -318,6 +324,14 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri } rows } + + override def memUsed(rows: Seq[InternalRow]): Int = { + var mem = 0 + for (row <- rows) { + mem = mem + row.asInstanceOf[UnsafeRow].getSizeInBytes + } + mem + } } class TypedImperativeEvaluator[T](agg: TypedImperativeAggregate[T], inputAttributes: Seq[Attribute]) @@ -373,4 +387,8 @@ class TypedImperativeEvaluator[T](agg: TypedImperativeAggregate[T], inputAttribu } rows } + + override def memUsed(rows: Seq[InternalRow]): Int = { + rows.length * 192 + } } From a3b4377d501cb3b94f4dba86395fddb2a9718652 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Thu, 20 Feb 2025 11:53:21 +0800 Subject: [PATCH 13/17] fix conflict --- .../src/agg/spark_udaf_wrapper.rs | 2 +- .../sql/blaze/UnsafeRowsWrapperUtils.java | 32 ---- .../sql/blaze/SparkUDAFWrapperContext.scala | 155 +++++++++--------- .../spark/sql/blaze/UnsafeRowsWrapper.scala | 140 ---------------- .../blaze/columnar/ColumnarHelper.scala | 13 ++ 5 files changed, 89 insertions(+), 253 deletions(-) delete mode 100644 spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java delete mode 100644 spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index 63967f852..0882a9987 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -38,7 +38,7 @@ use datafusion_ext_commons::{ }; use jni::objects::{GlobalRef, JObject}; use once_cell::sync::OnceCell; -use datafusion_ext_commons::io::{read_bytes_into_vec, read_bytes_slice}; +use datafusion_ext_commons::io::read_bytes_into_vec; use crate::{ agg::{ diff --git 
a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java
deleted file mode 100644
index 025b774fc..000000000
--- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/UnsafeRowsWrapperUtils.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Copyright 2022 The Blaze Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.blaze;
-
-import java.nio.ByteBuffer;
-import org.apache.spark.sql.catalyst.InternalRow;
-import scala.collection.mutable.ArrayBuffer;
-
-public class UnsafeRowsWrapperUtils {
-
-  public static void serialize(
-      ArrayBuffer<InternalRow> unsafeRows, int numFields, long importFFIArrayPtr, long exportFFIArrayPtr) {
-    UnsafeRowsWrapper$.MODULE$.serialize(unsafeRows, numFields, importFFIArrayPtr, exportFFIArrayPtr);
-  }
-
-  public static ArrayBuffer<InternalRow> deserialize(int numFields, ByteBuffer dataBuffer) {
-    return UnsafeRowsWrapper$.MODULE$.deserialize(numFields, dataBuffer);
-  }
-}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
index 8d3f21549..a9a3b8239 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
@@ -38,10 +38,11 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper
 import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils
 import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter
 import org.apache.spark.sql.execution.UnsafeRowSerializer
+import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR
+import org.apache.spark.sql.execution.blaze.columnar.ColumnarHelper
 import org.apache.spark.sql.types.{BinaryType, IntegerType, ObjectType, StructField, StructType}
 import org.apache.spark.util.ByteBufferInputStream

@@ -107,17 +108,16 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging {
       rows: ArrayBuffer[InternalRow],
       importIdxFFIArrayPtr: Long,
       importBatchFFIArrayPtr: Long): Unit = {
-    Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator =>
     Using.resources(
-      VectorSchemaRoot.create(inputSchema, batchAllocator),
-      VectorSchemaRoot.create(indexTupleSchema, batchAllocator),
+      VectorSchemaRoot.create(inputSchema, ROOT_ALLOCATOR),
+      VectorSchemaRoot.create(indexTupleSchema, ROOT_ALLOCATOR),
       ArrowArray.wrap(importBatchFFIArrayPtr),
       ArrowArray.wrap(importIdxFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) =>
       // import into params root
-
Data.importIntoVectorSchemaRoot(batchAllocator, inputArray, inputRoot, dictionaryProvider) - val inputRows = ColumnarHelper.rootAsBatch(inputRoot) + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, inputArray, inputRoot, dictionaryProvider) + val inputRows = ColumnarHelper.rootRowsArray(inputRoot) - Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) val fieldVectors = idxRoot.getFieldVectors.asScala val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] @@ -125,10 +125,9 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { for (i <- 0 until idxRoot.getRowCount) { val rowIdx = rowIdxVector.get(i) val row = rows(rowIdx) - val input = paramsToUnsafe(inputRows.getRow(inputIdxVector.get(i))) + val input = paramsToUnsafe(inputRows(inputIdxVector.get(i))) rows(rowIdx) = aggEvaluator.update(row, input) } - } } } @@ -136,97 +135,93 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { rows: ArrayBuffer[InternalRow], mergeRows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long): Unit = { - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - VectorSchemaRoot.create(indexTupleSchema, batchAllocator), - ArrowArray.wrap(importIdxFFIArrayPtr)) { (idxRoot, idxArray) => - Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) - val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] - val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] - - for (i <- 0 until idxRoot.getRowCount) { - val rowIdx = rowIdxVector.get(i) - val mergeIdx = mergeIdxVector.get(i) - val row = rows(rowIdx) - val mergeRow = mergeRows(mergeIdx) - rows(rowIdx) = aggEvaluator.merge(row, mergeRow) - } + Using.resources( + VectorSchemaRoot.create(indexTupleSchema, ROOT_ALLOCATOR), + ArrowArray.wrap(importIdxFFIArrayPtr)) { (idxRoot, idxArray) => + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) + val fieldVectors = idxRoot.getFieldVectors.asScala + val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] + val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] + + for (i <- 0 until idxRoot.getRowCount) { + val rowIdx = rowIdxVector.get(i) + val mergeIdx = mergeIdxVector.get(i) + val row = rows(rowIdx) + val mergeRow = mergeRows(mergeIdx) + rows(rowIdx) = aggEvaluator.merge(row, mergeRow) } } } + def eval( rows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long, exportFFIArrayPtr: Long): Unit = { - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - VectorSchemaRoot.create(indexSchema, batchAllocator), - VectorSchemaRoot.create(outputSchema, batchAllocator), - ArrowArray.wrap(importIdxFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => - Data.importIntoVectorSchemaRoot(batchAllocator, idxArray, idxRoot, dictionaryProvider) - val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] - - // evaluate expression and write to output root - val outputWriter = ArrowWriter.create(outputRoot) - for (i <- 0 until idxRoot.getRowCount) { - val row = rows(rowIdxVector.get(i)) - outputWriter.write(aggEvaluator.eval(row)) - } - outputWriter.finish() - - // export to 
output using root allocator - Data.exportVectorSchemaRoot( - ArrowUtils.rootAllocator, - outputRoot, - dictionaryProvider, - exportArray) + Using.resources( + VectorSchemaRoot.create(indexSchema, ROOT_ALLOCATOR), + VectorSchemaRoot.create(outputSchema, ROOT_ALLOCATOR), + ArrowArray.wrap(importIdxFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) + val fieldVectors = idxRoot.getFieldVectors.asScala + val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] + + // evaluate expression and write to output root + val outputWriter = ArrowWriter.create(outputRoot) + for (i <- 0 until idxRoot.getRowCount) { + val row = rows(rowIdxVector.get(i)) + outputWriter.write(aggEvaluator.eval(row)) } + outputWriter.finish() + + // export to output using root allocator + Data.exportVectorSchemaRoot( + ROOT_ALLOCATOR, + outputRoot, + dictionaryProvider, + exportArray) } } + def serializeRows( rows: ArrayBuffer[InternalRow], importFFIArrayPtr: Long, exportFFIArrayPtr: Long): Unit = { + Using.resources( + VectorSchemaRoot.create(serializedRowSchema, ROOT_ALLOCATOR), + VectorSchemaRoot.create(indexSchema, ROOT_ALLOCATOR), + ) { (exportDataRoot, importIdxRoot) => - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => Using.resources( - VectorSchemaRoot.create(serializedRowSchema, batchAllocator), - VectorSchemaRoot.create(indexSchema, batchAllocator), - ) { (exportDataRoot, importIdxRoot) => - - Using.resources( - ArrowArray.wrap(importFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { (importArray, exportArray) => - - // import into params root - Data.importIntoVectorSchemaRoot( - batchAllocator, - importArray, - importIdxRoot, - dictionaryProvider) - - // write serialized row into sequential raw bytes - val importIdxArray = importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] - val rowsIter = (0 until importIdxRoot.getRowCount).map(i => rows(importIdxArray.get(i))) - val serializedBytes = aggEvaluator.serializeRows(rowsIter) - - // export serialized data as a single row batch using root allocator - val outputWriter = ArrowWriter.create(exportDataRoot) - outputWriter.write(InternalRow(serializedBytes)) - outputWriter.finish() - Data.exportVectorSchemaRoot( - ArrowUtils.rootAllocator, - exportDataRoot, - dictionaryProvider, - exportArray) - } + ArrowArray.wrap(importFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { (importArray, exportArray) => + + // import into params root + Data.importIntoVectorSchemaRoot( + ROOT_ALLOCATOR, + importArray, + importIdxRoot, + dictionaryProvider) + + // write serialized row into sequential raw bytes + val importIdxArray = importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] + val rowsIter = (0 until importIdxRoot.getRowCount).map(i => rows(importIdxArray.get(i))) + val serializedBytes = aggEvaluator.serializeRows(rowsIter) + + // export serialized data as a single row batch using root allocator + val outputWriter = ArrowWriter.create(exportDataRoot) + outputWriter.write(InternalRow(serializedBytes)) + outputWriter.finish() + Data.exportVectorSchemaRoot( + ROOT_ALLOCATOR, + exportDataRoot, + dictionaryProvider, + exportArray) } } + } def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala 
deleted file mode 100644 index 79fc2d37a..000000000 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/UnsafeRowsWrapper.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.blaze - -import org.apache.arrow.c.{ArrowArray, Data} -import org.apache.arrow.vector.{IntVector, VarBinaryVector, VectorSchemaRoot} -import org.apache.arrow.vector.dictionary.DictionaryProvider -import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.util.Utils -import org.apache.spark.internal.Logging -import org.apache.spark.sql.blaze.util.Using -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.UnsafeRowSerializer -import org.apache.spark.sql.execution.blaze.arrowio.util.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StructField, StructType} -import org.apache.spark.sql.Row -import scala.collection.JavaConverters._ -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import java.io.OutputStream -import java.nio.ByteBuffer -import java.nio.ByteOrder - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.util.ByteBufferInputStream - -object UnsafeRowsWrapper extends Logging { - - private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() - private val idxSchema = { - val schema = StructType(Seq(StructField("", IntegerType, nullable = false))) - ArrowUtils.toArrowSchema(schema) - } - - private val serializedRowSchema = { - val schema = StructType(Seq(StructField("", BinaryType, nullable = false))) - ArrowUtils.toArrowSchema(schema) - } - - private val deserializeSchema = { - val schema = StructType( - Seq( - StructField("", BinaryType, nullable = false), - StructField("", IntegerType, nullable = false))) - ArrowUtils.toArrowSchema(schema) - } - - private val offsetSchema = { - val schema = StructType(Seq(StructField("", IntegerType, nullable = false))) - ArrowUtils.toArrowSchema(schema) - } - - private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { - val converter = unsafeRowConverter(schema) - converter(row) - } - - private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = { - val converter = UnsafeProjection.create(schema) - (row: Row) => { - converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]) - } - } - - def serialize( - rows: ArrayBuffer[InternalRow], - numFields: Int, - importFFIArrayPtr: Long, - exportFFIArrayPtr: Long): Unit = { - - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - VectorSchemaRoot.create(serializedRowSchema, batchAllocator), - VectorSchemaRoot.create(idxSchema, batchAllocator), - ) { (exportDataRoot, importIdxRoot) => - - Using.resources( - 
ArrowArray.wrap(importFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { (importArray, exportArray) => - - // import into params root - Data.importIntoVectorSchemaRoot( - batchAllocator, - importArray, - importIdxRoot, - dictionaryProvider) - - // write serialized row into sequential raw bytes - val importIdxArray = importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] - val outputDataStream = new ByteArrayOutputStream() - val serializer = new UnsafeRowSerializer(numFields).newInstance() - Using(serializer.serializeStream(outputDataStream)) { ser => - for (idx <- 0 until importIdxRoot.getRowCount) { - val rowIdx = importIdxArray.get(idx) - val row = rows(rowIdx) - ser.writeValue(row) - } - } - - // export serialized data as a single row batch using root allocator - val outputWriter = ArrowWriter.create(exportDataRoot) - outputWriter.write(InternalRow(outputDataStream.toByteArray)) - outputWriter.finish() - Data.exportVectorSchemaRoot( - ArrowUtils.rootAllocator, - exportDataRoot, - dictionaryProvider, - exportArray) - } - } - } - } - - def deserialize(numFields: Int, dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { - val deserializer = new UnsafeRowSerializer(numFields).newInstance() - val inputDataStream = new ByteBufferInputStream(dataBuffer) - val rows = new ArrayBuffer[InternalRow]() - - Using.resource(deserializer.deserializeStream(inputDataStream)) { deser => - for (row <- deser.asKeyValueIterator.map(_._2.asInstanceOf[UnsafeRow].copy())) { - rows.append(row) - } - } - rows - } -} \ No newline at end of file diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala index 640cadfba..e032bad72 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala @@ -32,4 +32,17 @@ object ColumnarHelper { row.asInstanceOf[InternalRow] } } + + def rootRowsArray(root: VectorSchemaRoot): Array[InternalRow] = { + val vectors = root.getFieldVectors.asScala.toArray + val numRows = root.getRowCount + val row = new BlazeColumnarBatchRow( + vectors.map(new BlazeArrowColumnVector(_).asInstanceOf[BlazeColumnVector]) + ) + (0 until numRows).map { rowId => + row.rowId = rowId + row.asInstanceOf[InternalRow] + }.toArray + } + } From dec7895add11b2a6db00659774b7b76404f3c5da Mon Sep 17 00:00:00 2001 From: guoying06 Date: Thu, 20 Feb 2025 16:56:39 +0800 Subject: [PATCH 14/17] fix spill --- .../src/agg/spark_udaf_wrapper.rs | 25 ++-- .../spark/sql/blaze/NativeConverters.scala | 3 +- .../sql/blaze/SparkUDAFWrapperContext.scala | 117 ++++++++---------- .../blaze/columnar/ColumnarHelper.scala | 3 +- 4 files changed, 63 insertions(+), 85 deletions(-) diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index 0882a9987..222875f84 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -15,10 +15,9 @@ use std::{ any::Any, fmt::{Debug, Display, Formatter}, - io::Cursor, + io::{Cursor, Write}, sync::Arc, }; -use std::io::Write; use arrow::{ array::{ @@ -34,11 +33,11 @@ use blaze_jni_bridge::{jni_call, jni_new_direct_byte_buffer, jni_new_global_ref, use datafusion::{common::Result, physical_expr::PhysicalExpr}; use 
datafusion_ext_commons::{ downcast_any, - io::{read_len, write_len}, + io::{read_bytes_into_vec, read_len, write_len}, }; use jni::objects::{GlobalRef, JObject}; use once_cell::sync::OnceCell; -use datafusion_ext_commons::io::read_bytes_into_vec; +use datafusion_ext_commons::io::read_bytes_slice; use crate::{ agg::{ @@ -331,7 +330,6 @@ impl AccColumn for AccUnsafeRowsColumn { } fn spill(&self, idx: IdxSelection<'_>, buf: &mut SpillCompressedWriter) -> Result<()> { - log::info!("start spill!"); let idx_array: ArrayRef = Arc::new(idx.to_int32_array()); let struct_array = StructArray::from(RecordBatch::try_new(index_schema(), vec![idx_array])?); @@ -352,31 +350,24 @@ impl AccColumn for AccUnsafeRowsColumn { let binary_array = downcast_any!(result_struct.column(0), BinaryArray)?; let data = binary_array.value(0); buf.write(data)?; - log::info!("end spill!"); Ok(()) } fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> { - log::info!("start unspill!"); let mut data = vec![]; - let mut data_len = 0; - for i in 0.. num_rows { - let bytes_len = i32::from_be_bytes(data[data_len..][..4].try_into().unwrap()) as usize; - data_len += bytes_len + 4; + for i in 0..num_rows { + let bytes_len = read_bytes_slice(r, 4)?; + let length = i32::from_be_bytes(bytes_len.as_ref().try_into().unwrap()); + data.extend_from_slice(bytes_len.as_ref()); + data.extend_from_slice( read_bytes_slice(r, length as usize)?.as_ref()); } - let mut data = vec![]; - read_bytes_into_vec(r, &mut data, data_len)?; - let data_buffer = jni_new_direct_byte_buffer!(data)?; let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()) .deserializeRows(data_buffer.as_obj()) -> JObject)?; self.obj = jni_new_global_ref!(rows.as_obj())?; self.num_rows = num_rows; - - log::info!("start unspill!"); Ok(()) } - } fn int32_field() -> FieldRef { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 7b3d933c0..b679971f8 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -1152,7 +1152,8 @@ object NativeConverters extends Logging { // other udaf aggFunction case udaf if classOf[DeclarativeAggregate].isAssignableFrom(e.aggregateFunction.getClass) - || classOf[TypedImperativeAggregate[_]].isAssignableFrom(e.aggregateFunction.getClass) => + || classOf[TypedImperativeAggregate[_]].isAssignableFrom( + e.aggregateFunction.getClass) => aggBuilder.setAggFunction(pb.AggFunction.DECLARATIVE) val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]() diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index a9a3b8239..36229e45c 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -74,11 +74,9 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { private val aggEvaluator = expr match { case declarative: DeclarativeAggregate => - logInfo(s"init DeclarativeEvaluator") new DeclarativeEvaluator(declarative, inputAttributes) case imperative: TypedImperativeAggregate[_] => - logInfo(s"init TypedImperativeEvaluator") - new 
TypedImperativeEvaluator(imperative, inputAttributes) + new TypedImperativeEvaluator(imperative) } private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() @@ -108,26 +106,26 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { rows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long, importBatchFFIArrayPtr: Long): Unit = { - Using.resources( - VectorSchemaRoot.create(inputSchema, ROOT_ALLOCATOR), - VectorSchemaRoot.create(indexTupleSchema, ROOT_ALLOCATOR), - ArrowArray.wrap(importBatchFFIArrayPtr), - ArrowArray.wrap(importIdxFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => - // import into params root - Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, inputArray, inputRoot, dictionaryProvider) - val inputRows = ColumnarHelper.rootRowsArray(inputRoot) - - Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) - val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] - val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] - - for (i <- 0 until idxRoot.getRowCount) { - val rowIdx = rowIdxVector.get(i) - val row = rows(rowIdx) - val input = paramsToUnsafe(inputRows(inputIdxVector.get(i))) - rows(rowIdx) = aggEvaluator.update(row, input) - } + Using.resources( + VectorSchemaRoot.create(inputSchema, ROOT_ALLOCATOR), + VectorSchemaRoot.create(indexTupleSchema, ROOT_ALLOCATOR), + ArrowArray.wrap(importBatchFFIArrayPtr), + ArrowArray.wrap(importIdxFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => + // import into params root + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, inputArray, inputRoot, dictionaryProvider) + val inputRows = ColumnarHelper.rootRowsArray(inputRoot) + + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) + val fieldVectors = idxRoot.getFieldVectors.asScala + val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] + val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] + + for (i <- 0 until idxRoot.getRowCount) { + val rowIdx = rowIdxVector.get(i) + val row = rows(rowIdx) + val input = paramsToUnsafe(inputRows(inputIdxVector.get(i))) + rows(rowIdx) = aggEvaluator.update(row, input) + } } } @@ -153,7 +151,6 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } } - def eval( rows: ArrayBuffer[InternalRow], importIdxFFIArrayPtr: Long, @@ -176,49 +173,40 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { outputWriter.finish() // export to output using root allocator - Data.exportVectorSchemaRoot( - ROOT_ALLOCATOR, - outputRoot, - dictionaryProvider, - exportArray) + Data.exportVectorSchemaRoot(ROOT_ALLOCATOR, outputRoot, dictionaryProvider, exportArray) } } - def serializeRows( - rows: ArrayBuffer[InternalRow], - importFFIArrayPtr: Long, - exportFFIArrayPtr: Long): Unit = { + rows: ArrayBuffer[InternalRow], + importFFIArrayPtr: Long, + exportFFIArrayPtr: Long): Unit = { Using.resources( VectorSchemaRoot.create(serializedRowSchema, ROOT_ALLOCATOR), - VectorSchemaRoot.create(indexSchema, ROOT_ALLOCATOR), - ) { (exportDataRoot, importIdxRoot) => - - Using.resources( - ArrowArray.wrap(importFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { (importArray, exportArray) => - - // import into params root - Data.importIntoVectorSchemaRoot( - ROOT_ALLOCATOR, - importArray, - importIdxRoot, - dictionaryProvider) - - // write serialized row into sequential raw bytes - val importIdxArray = 
importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] - val rowsIter = (0 until importIdxRoot.getRowCount).map(i => rows(importIdxArray.get(i))) - val serializedBytes = aggEvaluator.serializeRows(rowsIter) - - // export serialized data as a single row batch using root allocator - val outputWriter = ArrowWriter.create(exportDataRoot) - outputWriter.write(InternalRow(serializedBytes)) - outputWriter.finish() - Data.exportVectorSchemaRoot( - ROOT_ALLOCATOR, - exportDataRoot, - dictionaryProvider, - exportArray) + VectorSchemaRoot.create(indexSchema, ROOT_ALLOCATOR)) { (exportDataRoot, importIdxRoot) => + Using.resources(ArrowArray.wrap(importFFIArrayPtr), ArrowArray.wrap(exportFFIArrayPtr)) { + (importArray, exportArray) => + // import into params root + Data.importIntoVectorSchemaRoot( + ROOT_ALLOCATOR, + importArray, + importIdxRoot, + dictionaryProvider) + + // write serialized row into sequential raw bytes + val importIdxArray = importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] + val rowsIter = (0 until importIdxRoot.getRowCount).map(i => rows(importIdxArray.get(i))) + val serializedBytes = aggEvaluator.serializeRows(rowsIter) + + // export serialized data as a single row batch using root allocator + val outputWriter = ArrowWriter.create(exportDataRoot) + outputWriter.write(InternalRow(serializedBytes)) + outputWriter.finish() + Data.exportVectorSchemaRoot( + ROOT_ALLOCATOR, + exportDataRoot, + dictionaryProvider, + exportArray) } } @@ -257,7 +245,6 @@ trait AggregateEvaluator extends Logging { def eval(row: InternalRow): InternalRow def serializeRows(rows: Seq[InternalRow]): Array[Byte] def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] - def memUsed(rows: Seq[InternalRow]): Int } @@ -276,7 +263,6 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri private val evaluator = UnsafeProjection.create(agg.evaluateExpression :: Nil, agg.aggBufferAttributes) - override def initialize(): InternalRow = { initializer.apply(InternalRow.empty) } @@ -329,10 +315,11 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri } } -class TypedImperativeEvaluator[T](agg: TypedImperativeAggregate[T], inputAttributes: Seq[Attribute]) +class TypedImperativeEvaluator[T]( + agg: TypedImperativeAggregate[T]) extends AggregateEvaluator { - private val bufferSchema = agg.aggBufferAttributes.map(_.dataType) + private val bufferSchema = agg.aggBufferAttributes.map(_.dataType) private val anyObjectType = ObjectType(classOf[AnyRef]) private def getBufferObject(buffer: InternalRow): T = { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala index e032bad72..c5b9f53d0 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala @@ -37,8 +37,7 @@ object ColumnarHelper { val vectors = root.getFieldVectors.asScala.toArray val numRows = root.getRowCount val row = new BlazeColumnarBatchRow( - vectors.map(new BlazeArrowColumnVector(_).asInstanceOf[BlazeColumnVector]) - ) + vectors.map(new BlazeArrowColumnVector(_).asInstanceOf[BlazeColumnVector])) (0 until numRows).map { rowId => row.rowId = rowId row.asInstanceOf[InternalRow] From 9d881684bed90d64ca0bf6b374dc7c37fda8e1f7 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Tue, 25 
Feb 2025 15:03:42 +0800 Subject: [PATCH 15/17] optimize --- .../blaze-jni-bridge/src/jni_bridge.rs | 50 ++- native-engine/blaze-serde/proto/blaze.proto | 10 +- native-engine/blaze-serde/src/from_proto.rs | 26 +- native-engine/blaze-serde/src/lib.rs | 2 +- .../datafusion-ext-plans/src/agg/agg.rs | 25 +- .../datafusion-ext-plans/src/agg/mod.rs | 2 +- .../src/agg/spark_udaf_wrapper.rs | 279 ++++----------- .../src/ipc_reader_exec.rs | 6 +- .../spark/sql/blaze/NativeConverters.scala | 9 +- .../sql/blaze/SparkUDAFWrapperContext.scala | 325 ++++++++---------- .../blaze/columnar/ColumnarHelper.scala | 20 +- 11 files changed, 294 insertions(+), 460 deletions(-) diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index e8c857602..a87e752a2 100644 --- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -228,7 +228,7 @@ macro_rules! jni_get_byte_array_region { $crate::jni_bridge::THREAD_JNIENV.with(|env| { $crate::jni_map_error_with_env!( env, - env.get_byte_array_region($value, $start as i32, unsafe { + env.get_byte_array_region($value.cast(), $start as i32, unsafe { std::mem::transmute::<_, &mut [i8]>($buf) }) ) @@ -236,6 +236,36 @@ }}; } +#[macro_export] +macro_rules! jni_get_byte_array_len { + ($value:expr) => {{ + $crate::jni_bridge::THREAD_JNIENV.with(|env| { + $crate::jni_map_error_with_env!( + env, + env.get_array_length($value.cast()).map(|s| s as usize) + ) + }) + }}; +} + +#[macro_export] +macro_rules! jni_new_prim_array { + ($ty:ident, $value:expr) => {{ + $crate::jni_bridge::THREAD_JNIENV.with(|env| { + $crate::jni_map_error_with_env!( + env, + paste::paste! {env.[<new_ $ty _array>]($value.len() as i32)} + .and_then(|array| { + let value = unsafe { std::mem::transmute($value) }; + paste::paste! {env.[<set_ $ty _array_region>](array, 0, value)} + .map(|_| array) + }) + .map(|s| $crate::jni_bridge::LocalRef(unsafe { JObject::from_raw(s.into()) })) + ) + }) + }}; +} + #[macro_export] macro_rules! jni_call { ($clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) 
-> JObject) => {{ @@ -1203,49 +1233,49 @@ impl<'a> SparkUDAFWrapperContext<'a> { method_initialize: env.get_method_id( class, "initialize", - "(I)Lscala/collection/mutable/ArrayBuffer;", + "(I)Lorg/apache/spark/sql/blaze/BufferRowsColumn;", )?, method_initialize_ret: ReturnType::Object, method_resize: env.get_method_id( class, "resize", - "(Lscala/collection/mutable/ArrayBuffer;I)V", + "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;I)V", )?, method_resize_ret: ReturnType::Primitive(Primitive::Void), method_update: env.get_method_id( class, "update", - "(Lscala/collection/mutable/ArrayBuffer;JJ)V", + "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;J[J)V", )?, method_update_ret: ReturnType::Primitive(Primitive::Void), method_merge: env.get_method_id( class, "merge", - "(Lscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;J)V", + "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;Lorg/apache/spark/sql/blaze/BufferRowsColumn;[J)V", )?, method_merge_ret: ReturnType::Primitive(Primitive::Void), method_eval: env.get_method_id( class, "eval", - "(Lscala/collection/mutable/ArrayBuffer;JJ)V", + "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[IJ)V", )?, method_eval_ret: ReturnType::Primitive(Primitive::Void), method_serializeRows: env.get_method_id( class, "serializeRows", - "(Lscala/collection/mutable/ArrayBuffer;JJ)V", + "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;[I)[B", )?, - method_serializeRows_ret: ReturnType::Primitive(Primitive::Void), + method_serializeRows_ret: ReturnType::Array, method_deserializeRows: env.get_method_id( class, "deserializeRows", - "(Ljava/nio/ByteBuffer;)Lscala/collection/mutable/ArrayBuffer;", + "(Ljava/nio/ByteBuffer;)Lorg/apache/spark/sql/blaze/BufferRowsColumn;", )?, method_deserializeRows_ret: ReturnType::Object, method_memUsed: env.get_method_id( class, "memUsed", - "(Lscala/collection/mutable/ArrayBuffer;)I", + "(Lorg/apache/spark/sql/blaze/BufferRowsColumn;)I", )?, method_memUsed_ret: ReturnType::Primitive(Primitive::Int), }) diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto index 2ef736c0d..60ed70b68 100644 --- a/native-engine/blaze-serde/proto/blaze.proto +++ b/native-engine/blaze-serde/proto/blaze.proto @@ -141,7 +141,7 @@ enum AggFunction { BLOOM_FILTER = 9; BRICKHOUSE_COLLECT = 1000; BRICKHOUSE_COMBINE_UNIQUE = 1001; - DECLARATIVE = 1002; + UDAF = 1002; } message PhysicalAggExprNode { @@ -606,10 +606,10 @@ message FetchLimit { message PhysicalRepartition { oneof RepartitionType { - PhysicalSingleRepartition single_repartition = 1; - PhysicalHashRepartition hash_repartition = 2; - PhysicalRoundRobinRepartition round_robin_repartition = 3; - PhysicalRangeRepartition range_repartition = 4; + PhysicalSingleRepartition single_repartition = 1; + PhysicalHashRepartition hash_repartition = 2; + PhysicalRoundRobinRepartition round_robin_repartition = 3; + PhysicalRangeRepartition range_repartition = 4; } } diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs index 7c3b4b55e..c21c00265 100644 --- a/native-engine/blaze-serde/src/from_proto.rs +++ b/native-engine/blaze-serde/src/from_proto.rs @@ -60,7 +60,7 @@ use datafusion_ext_exprs::{ }; use datafusion_ext_plans::{ agg::{ - agg::{create_agg, create_declarative_agg}, + agg::{create_agg, create_udaf_agg}, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr, }, agg_exec::AggExec, @@ -441,10 +441,10 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode { .map(|expr| try_parse_physical_expr(expr, 
&input_schema)) .collect::<Result<Vec<_>, _>>()?; let agg = match AggFunction::from(agg_function) { - AggFunction::Declarative => { + AggFunction::Udaf => { let udaf = agg_node.udaf.as_ref().unwrap(); let serialized = udaf.serialized.clone(); - create_declarative_agg( + create_udaf_agg( serialized, convert_required!(udaf.return_type)?, agg_children_exprs, @@ -571,8 +571,8 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode { protobuf::AggFunction::BrickhouseCombineUnique => { WindowFunction::Agg(AggFunction::BrickhouseCombineUnique) } - protobuf::AggFunction::Declarative => { - WindowFunction::Agg(AggFunction::Declarative) + protobuf::AggFunction::Udaf => { + WindowFunction::Agg(AggFunction::Udaf) } }, }; @@ -849,16 +849,16 @@ fn try_parse_physical_expr( // cast list values to expr type e if downcast_any!(e, Literal).is_ok() && e.data_type(input_schema)? != dt => - { - match TryCastExpr::new(e, dt.clone()).evaluate( - &RecordBatch::new_empty(input_schema.clone()), - )? { - ColumnarValue::Scalar(scalar) => { - Arc::new(Literal::new(scalar)) + { + match TryCastExpr::new(e, dt.clone()).evaluate( + &RecordBatch::new_empty(input_schema.clone()), + )? { + ColumnarValue::Scalar(scalar) => { + Arc::new(Literal::new(scalar)) + } + ColumnarValue::Array(_) => unreachable!(), } - ColumnarValue::Array(_) => unreachable!(), } - } other => other, } }) diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs index b64167a42..ed193ebb1 100644 --- a/native-engine/blaze-serde/src/lib.rs +++ b/native-engine/blaze-serde/src/lib.rs @@ -138,7 +138,7 @@ impl From<protobuf::AggFunction> for AggFunction { protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter, protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect, protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique, - protobuf::AggFunction::Declarative => AggFunction::Declarative, + protobuf::AggFunction::Udaf => AggFunction::Udaf, } } } diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index f11f0e6a5..f653dc132 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -15,7 +15,7 @@ use std::{any::Any, fmt::Debug, sync::Arc}; use arrow::{ - array::{Array, ArrayRef, AsArray, Int32Array, Int32Builder, RecordBatch}, + array::{ArrayRef, AsArray, RecordBatch}, datatypes::{DataType, Int64Type, Schema, SchemaRef}, }; use datafusion::{common::Result, physical_expr::PhysicalExpr}; @@ -78,36 +78,31 @@ impl IdxSelection<'_> { } } - pub fn to_int32_array(&self) -> Int32Array { - let mut builder = Int32Builder::with_capacity(self.len()); + pub fn to_int32_vec(&self) -> Vec<i32> { + let mut vec = Vec::with_capacity(self.len()); match self { IdxSelection::Single(idx) => { - builder.append_value(*idx as i32); + vec.push(*idx as i32); } IdxSelection::Indices(indices) => { for &idx in *indices { - builder.append_value(idx as i32); + vec.push(idx as i32); } } IdxSelection::IndicesU32(indices_u32) => { for &idx in *indices_u32 { - builder.append_value(idx as i32); + vec.push(idx as i32); } } IdxSelection::Range(start, end) => { for idx in *start..*end { - builder.append_value(idx as i32); + vec.push(idx as i32); } } } - let primitive_array = builder.finish(); - primitive_array - .as_any() - .downcast_ref::<Int32Array>() - .cloned() - .unwrap() + vec } } @@ -304,13 +299,13 @@ pub fn create_agg( arg_list_inner_type, )?) 
} - AggFunction::Declarative => { + AggFunction::Udaf => { unreachable!("UDAF should be handled in create_declarative_agg") }) } -pub fn create_declarative_agg( +pub fn create_udaf_agg( serialized: Vec<u8>, return_type: DataType, children: Vec<Arc<dyn PhysicalExpr>>, diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs index 5f4075349..62f48a3b8 100644 --- a/native-engine/datafusion-ext-plans/src/agg/mod.rs +++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs @@ -75,7 +75,7 @@ pub enum AggFunction { BloomFilter, BrickhouseCollect, BrickhouseCombineUnique, - Declarative, + Udaf, } #[derive(Debug, Clone)] diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index 222875f84..d2eb46aeb 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -20,16 +20,15 @@ }; use arrow::{ - array::{ - as_struct_array, make_array, Array, ArrayRef, AsArray, BinaryArray, Int32Array, - Int32Builder, StructArray, - }, + array::{as_struct_array, make_array, Array, ArrayRef, StructArray}, datatypes::{DataType, Field, Schema, SchemaRef}, ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use arrow_schema::FieldRef; -use blaze_jni_bridge::{jni_call, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object}; +use blaze_jni_bridge::{ + jni_call, jni_get_byte_array_len, jni_get_byte_array_region, jni_new_direct_byte_buffer, + jni_new_global_ref, jni_new_object, jni_new_prim_array, +}; use datafusion::{common::Result, physical_expr::PhysicalExpr}; use datafusion_ext_commons::{ downcast_any, @@ -37,7 +36,6 @@ }; use jni::objects::{GlobalRef, JObject}; use once_cell::sync::OnceCell; -use datafusion_ext_commons::io::read_bytes_slice; use crate::{ agg::{ @@ -119,7 +117,7 @@ impl Agg for SparkUDAFWrapper { let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).initialize( num_rows as i32, )-> JObject) - .unwrap(); + .unwrap(); let jcontext = self.jcontext().unwrap(); let obj = jni_new_global_ref!(rows.as_obj()).unwrap(); @@ -138,15 +136,6 @@ impl Agg for SparkUDAFWrapper { )?)) } - // todo: implemented prepare_partial_args - // fn prepare_partial_args(&self, partial_inputs: &[ArrayRef]) -> - // Result<Vec<ArrayRef>> { // cast arg1 to target data type - // Ok(vec![datafusion_ext_commons::arrow::cast::cast( - // &partial_inputs[0], - // &self.return_type, - // )?]) - // } - fn partial_update( &self, accs: &mut AccColumnRef, @@ -176,27 +165,24 @@ params.clone(), &RecordBatchOptions::new().with_row_count(Some(partial_arg_idx.len())), )?; + let batch_struct_array = StructArray::from(params_batch); + let mut export_ffi_batch_array = FFI_ArrowArray::new(&batch_struct_array.to_data()); + // create zipped indices let max_len = std::cmp::max(acc_idx.len(), partial_arg_idx.len()); - let mut acc_idx_builder = Int32Builder::with_capacity(max_len); - let mut partial_arg_idx_builder = Int32Builder::with_capacity(max_len); + let mut zipped_indices = Vec::with_capacity(max_len); idx_for_zipped! 
{ - ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { - acc_idx_builder.append_value(acc_idx as i32); - partial_arg_idx_builder.append_value(partial_arg_idx as i32); + ((acc_idx, updating_acc_idx) in (acc_idx, partial_arg_idx)) => { + zipped_indices.push((acc_idx as i64) << 32 | updating_acc_idx as i64); } } - let acc_idx = acc_idx_builder.finish(); - let partial_arg_idx = partial_arg_idx_builder.finish(); - - partial_update_udaf( - self.jcontext()?, - params_batch, - accs.obj.clone(), - acc_idx, - partial_arg_idx, - )?; - Ok(()) + let zipped_indices_array = jni_new_prim_array!(long, &zipped_indices[..])?; + + jni_call!(SparkUDAFWrapperContext(self.jcontext()?.as_obj()).update( + accs.obj.as_obj(), + &mut export_ffi_batch_array as *mut FFI_ArrowArray as i64, + zipped_indices_array.as_obj(), + )-> ()) } fn partial_merge( @@ -209,36 +195,42 @@ let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); let merging_accs = downcast_any!(merging_accs, mut AccUnsafeRowsColumn).unwrap(); + // create zipped indices let max_len = std::cmp::max(acc_idx.len(), merging_acc_idx.len()); - let mut acc_idx_builder = Int32Builder::with_capacity(max_len); - let mut merging_acc_idx_builder = Int32Builder::with_capacity(max_len); + let mut zipped_indices = Vec::with_capacity(max_len); idx_for_zipped! { - ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { - acc_idx_builder.append_value(acc_idx as i32); - merging_acc_idx_builder.append_value(merging_acc_idx as i32); + ((acc_idx, updating_acc_idx) in (acc_idx, merging_acc_idx)) => { + zipped_indices.push((acc_idx as i64) << 32 | updating_acc_idx as i64); } } - let acc_idx = acc_idx_builder.finish(); - let merging_acc_idx = merging_acc_idx_builder.finish(); - - partial_merge_udaf( - self.jcontext()?, - accs.obj.clone(), - merging_accs.obj.clone(), - acc_idx, - merging_acc_idx, - )?; - Ok(()) + let zipped_indices_array = jni_new_prim_array!(long, &zipped_indices[..])?; + + jni_call!(SparkUDAFWrapperContext(self.jcontext()?.as_obj()).merge( + accs.obj.as_obj(), + merging_accs.obj.as_obj(), + zipped_indices_array.as_obj(), + )-> ()) } fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> { let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); - final_merge_udaf( - self.jcontext()?, - accs.obj.clone(), - acc_idx, - self.import_schema.clone(), - ) + let acc_indices = acc_idx.to_int32_vec(); + + let acc_idx_array = jni_new_prim_array!(int, &acc_indices[..])?; + let mut import_ffi_array = FFI_ArrowArray::empty(); + + jni_call!(SparkUDAFWrapperContext(self.jcontext()?.as_obj()).eval( + accs.obj.as_obj(), + acc_idx_array.as_obj(), + &mut import_ffi_array as *mut FFI_ArrowArray as i64, + )-> ())?; + + // import output from context + let import_ffi_schema = FFI_ArrowSchema::try_from(self.import_schema.as_ref())?; + let import_struct_array = + make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? 
}); + let import_array = as_struct_array(&import_struct_array).column(0).clone(); + Ok(import_array) } } @@ -258,7 +250,7 @@ impl AccColumn for AccUnsafeRowsColumn { self.obj.as_obj(), len as i32, )-> ()) - .unwrap(); + .unwrap(); self.num_rows = len; } @@ -272,31 +264,23 @@ impl AccColumn for AccUnsafeRowsColumn { jni_call!( SparkUDAFWrapperContext(self.jcontext.as_obj()).memUsed( self.obj.as_obj()) - -> i32).unwrap() as usize + -> i32) + .unwrap() as usize } fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec<u8>]) -> Result<()> { - let idx_array: ArrayRef = Arc::new(idx.to_int32_array()); - let struct_array = - StructArray::from(RecordBatch::try_new(index_schema(), vec![idx_array])?); - let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data()); - let mut import_ffi_array = FFI_ArrowArray::empty(); - jni_call!( + let idx_array = jni_new_prim_array!(int, &idx.to_int32_vec()[..])?; + let serialized = jni_call!( SparkUDAFWrapperContext(self.jcontext.as_obj()).serializeRows( self.obj.as_obj(), - &mut export_ffi_array as *mut FFI_ArrowArray as i64, - &mut import_ffi_array as *mut FFI_ArrowArray as i64,) - -> ())?; - // import output from context - let import_ffi_schema = FFI_ArrowSchema::try_from(serialized_row_schema().as_ref())?; - let import_struct_array = - make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? }); - let result_struct = import_struct_array.as_struct(); - - let binary_array = downcast_any!(result_struct.column(0), BinaryArray)?; - let data = binary_array.value(0); + idx_array.as_obj(), + ) -> JObject)?; + let serialized_len = jni_get_byte_array_len!(serialized.as_obj())?; + let mut serialized_bytes = vec![0; serialized_len]; + jni_get_byte_array_region!(serialized.as_obj(), 0, &mut serialized_bytes[..])?; // UnsafeRow is serialized with big-endian i32 length prefix + let data = &serialized_bytes; let mut cur = 0; for i in 0..array.len() { let bytes_len = i32::from_be_bytes(data[cur..][..4].try_into().unwrap()) as usize; @@ -330,37 +314,26 @@ impl AccColumn for AccUnsafeRowsColumn { } fn spill(&self, idx: IdxSelection<'_>, buf: &mut SpillCompressedWriter) -> Result<()> { - let idx_array: ArrayRef = Arc::new(idx.to_int32_array()); - let struct_array = - StructArray::from(RecordBatch::try_new(index_schema(), vec![idx_array])?); - let mut export_ffi_array = FFI_ArrowArray::new(&struct_array.to_data()); - let mut import_ffi_array = FFI_ArrowArray::empty(); - jni_call!( + let idx_array = jni_new_prim_array!(int, &idx.to_int32_vec()[..])?; + let serialized = jni_call!( SparkUDAFWrapperContext(self.jcontext.as_obj()).serializeRows( self.obj.as_obj(), - &mut export_ffi_array as *mut FFI_ArrowArray as i64, - &mut import_ffi_array as *mut FFI_ArrowArray as i64,) - -> ())?; - // import output from context - let import_ffi_schema = FFI_ArrowSchema::try_from(serialized_row_schema().as_ref())?; - let import_struct_array = - make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? 
}); - let result_struct = import_struct_array.as_struct(); - - let binary_array = downcast_any!(result_struct.column(0), BinaryArray)?; - let data = binary_array.value(0); - buf.write(data)?; + idx_array.as_obj(), + ) -> JObject)?; + let serialized_len = jni_get_byte_array_len!(serialized.as_obj())?; + let mut serialized_bytes = vec![0; serialized_len]; + jni_get_byte_array_region!(serialized.as_obj(), 0, &mut serialized_bytes[..])?; + + write_len(serialized_bytes.len(), buf)?; + buf.write(&serialized_bytes)?; Ok(()) } fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> { + let data_size = read_len(r)?; let mut data = vec![]; - for i in 0..num_rows { - let bytes_len = read_bytes_slice(r, 4)?; - let length = i32::from_be_bytes(bytes_len.as_ref().try_into().unwrap()); - data.extend_from_slice(bytes_len.as_ref()); - data.extend_from_slice( read_bytes_slice(r, length as usize)?.as_ref()); - } + read_bytes_into_vec(r, &mut data, data_size)?; + let data_buffer = jni_new_direct_byte_buffer!(data)?; let rows = jni_call!(SparkUDAFWrapperContext(self.jcontext.as_obj()) .deserializeRows(data_buffer.as_obj()) -> JObject)?; @@ -369,113 +342,3 @@ impl AccColumn for AccUnsafeRowsColumn { Ok(()) } } - -fn int32_field() -> FieldRef { - static FIELD: OnceCell<FieldRef> = OnceCell::new(); - FIELD - .get_or_init(|| Arc::new(Field::new("", DataType::Int32, false))) - .clone() -} - -fn binary_field() -> FieldRef { - static FIELD: OnceCell<FieldRef> = OnceCell::new(); - FIELD - .get_or_init(|| Arc::new(Field::new("", DataType::Binary, false))) - .clone() -} - -fn index_schema() -> SchemaRef { - static SCHEMA: OnceCell<SchemaRef> = OnceCell::new(); - SCHEMA - .get_or_init(|| Arc::new(Schema::new(vec![int32_field()]))) - .clone() -} - -fn index_tuple_schema() -> SchemaRef { - static SCHEMA: OnceCell<SchemaRef> = OnceCell::new(); - SCHEMA - .get_or_init(|| Arc::new(Schema::new(vec![int32_field(), int32_field()]))) - .clone() -} - -fn serialized_row_schema() -> SchemaRef { - static SCHEMA: OnceCell<SchemaRef> = OnceCell::new(); - SCHEMA - .get_or_init(|| Arc::new(Schema::new(vec![binary_field()]))) - .clone() -} - -fn partial_update_udaf( - jcontext: GlobalRef, - params_batch: RecordBatch, - accs: GlobalRef, - acc_idx: Int32Array, - partial_arg_idx: Int32Array, -) -> Result<()> { - let acc_idx: ArrayRef = Arc::new(acc_idx); - let partial_arg_idx: ArrayRef = Arc::new(partial_arg_idx); - let idx_struct_array = StructArray::from(RecordBatch::try_new( - index_tuple_schema(), - vec![acc_idx, partial_arg_idx], - )?); - let batch_struct_array = StructArray::from(params_batch); - - let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data()); - let mut export_ffi_batch_array = FFI_ArrowArray::new(&batch_struct_array.to_data()); - - jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).update( - accs.as_obj(), - &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, - &mut export_ffi_batch_array as *mut FFI_ArrowArray as i64, - )-> ())?; - - Ok(()) -} - -fn partial_merge_udaf( - jcontext: GlobalRef, - accs: GlobalRef, - merging_accs: GlobalRef, - acc_idx: Int32Array, - merging_acc_idx: Int32Array, -) -> Result<()> { - let acc_idx: ArrayRef = Arc::new(acc_idx); - let merging_acc_idx: ArrayRef = Arc::new(merging_acc_idx); - let idx_struct_array = StructArray::from(RecordBatch::try_new( - index_tuple_schema(), - vec![acc_idx, merging_acc_idx], - )?); - let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data()); - - jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).merge( - accs.as_obj(), 
merging_accs.as_obj(), - &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, - )-> ())?; - - Ok(()) -} - -fn final_merge_udaf( - jcontext: GlobalRef, - accs: GlobalRef, - acc_idx: IdxSelection<'_>, - result_schema: SchemaRef, -) -> Result<ArrayRef> { - let acc_idx: ArrayRef = Arc::new(Int32Array::from(acc_idx.to_int32_array())); - let idx_struct_array = StructArray::from(RecordBatch::try_new(index_schema(), vec![acc_idx])?); - let mut export_ffi_idx_array = FFI_ArrowArray::new(&idx_struct_array.to_data()); - let mut import_ffi_array = FFI_ArrowArray::empty(); - jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).eval( - accs.as_obj(), - &mut export_ffi_idx_array as *mut FFI_ArrowArray as i64, - &mut import_ffi_array as *mut FFI_ArrowArray as i64, - )-> ())?; - - // import output from context - let import_ffi_schema = FFI_ArrowSchema::try_from(result_schema.as_ref())?; - let import_struct_array = - make_array(unsafe { from_ffi(import_ffi_array, &import_ffi_schema)? }); - let import_array = as_struct_array(&import_struct_array).column(0).clone(); - Ok(import_array) -} diff --git a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs index fd86c0be3..cb82e0235 100644 --- a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs @@ -408,11 +408,7 @@ impl HeapByteBufferReader { fn read_impl(&mut self, buf: &mut [u8]) -> Result<usize> { let read_len = buf.len().min(self.remaining); - jni_get_byte_array_region!( - self.byte_array.as_obj().cast(), - self.pos, - &mut buf[..read_len] - )?; + jni_get_byte_array_region!(self.byte_array.as_obj(), self.pos, &mut buf[..read_len])?; self.pos += read_len; self.remaining -= read_len; Ok(read_len) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index b679971f8..0d5e8d5fa 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -1151,10 +1151,11 @@ object NativeConverters extends Logging { aggBuilder.addChildren(convertExpr(udaf.children.head)) // other udaf aggFunction case udaf - if classOf[DeclarativeAggregate].isAssignableFrom(e.aggregateFunction.getClass) - || classOf[TypedImperativeAggregate[_]].isAssignableFrom( - e.aggregateFunction.getClass) => - aggBuilder.setAggFunction(pb.AggFunction.DECLARATIVE) + if classOf[DeclarativeAggregate].isAssignableFrom(e.aggregateFunction.getClass) + || (udaf.getClass.getName != "org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate" && + classOf[TypedImperativeAggregate[_]].isAssignableFrom( + e.aggregateFunction.getClass)) => + aggBuilder.setAggFunction(pb.AggFunction.UDAF) val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]() val bound = udaf match { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index 36229e45c..d6578bf46 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -15,13 +15,14 @@ */ package org.apache.spark.sql.blaze -import java.io.{ByteArrayOutputStream, DataOutputStream} +import 
java.io.ByteArrayOutputStream +import java.io.DataOutputStream import java.nio.ByteBuffer -import scala.collection.JavaConverters._ + import scala.collection.mutable.ArrayBuffer + import org.apache.arrow.c.ArrowArray import org.apache.arrow.c.Data -import org.apache.arrow.vector.IntVector import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider @@ -34,23 +35,22 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.expressions.Nondeterministic import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR import org.apache.spark.sql.execution.blaze.columnar.ColumnarHelper -import org.apache.spark.sql.types.{BinaryType, IntegerType, ObjectType, StructField, StructType} +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ByteBufferInputStream -case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { - import org.apache.spark.sql.blaze.SparkUDAFWrapperContext._ - - private val (expr, List(javaParamsSchema, javaBufferSchema)) = - NativeConverters.deserializeExpression[AggregateFunction, List[StructType]]({ +case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging { + private val (expr, javaParamsSchema) = + NativeConverters.deserializeExpression[AggregateFunction, StructType]({ val bytes = new Array[Byte](serialized.remaining()) serialized.get(bytes) bytes @@ -72,103 +72,70 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { case _ => } - private val aggEvaluator = expr match { - case declarative: DeclarativeAggregate => - new DeclarativeEvaluator(declarative, inputAttributes) - case imperative: TypedImperativeAggregate[_] => - new TypedImperativeEvaluator(imperative) + private val aggEvaluator = { + val evaluator = expr match { + case declarative: DeclarativeAggregate => + new DeclarativeEvaluator(declarative, inputAttributes) + case imperative: TypedImperativeAggregate[B] => + new TypedImperativeEvaluator(imperative) + } + evaluator.asInstanceOf[AggregateEvaluator[B]] } private val dictionaryProvider: DictionaryProvider = new MapDictionaryProvider() private val inputSchema = ArrowUtils.toArrowSchema(javaParamsSchema) - private val paramsToUnsafe = { - val toUnsafe = UnsafeProjection.create(javaParamsSchema) - toUnsafe.initialize(Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)) - toUnsafe - } - def initialize(numRow: Int): ArrayBuffer[InternalRow] = { - val rows = ArrayBuffer[InternalRow]() - resize(rows, numRow) + def initialize(numRow: Int): BufferRowsColumn[B] = { + val rows = aggEvaluator.createEmptyColumn() + rows.resize(numRow, aggEvaluator.initialize()) rows } - def 
resize(rows: ArrayBuffer[InternalRow], len: Int): Unit = { - if (rows.length < len) { - rows.append(Range(rows.length, len).map(_ => aggEvaluator.initialize()): _*) - } else { - rows.trimEnd(rows.length - len) - } + def resize(rows: BufferRowsColumn[B], len: Int): Unit = { + rows.resize(len, aggEvaluator.initialize()) } def update( - rows: ArrayBuffer[InternalRow], - importIdxFFIArrayPtr: Long, - importBatchFFIArrayPtr: Long): Unit = { + rows: BufferRowsColumn[B], + importBatchFFIArrayPtr: Long, + zippedIndices: Array[Long]): Unit = { Using.resources( VectorSchemaRoot.create(inputSchema, ROOT_ALLOCATOR), - VectorSchemaRoot.create(indexTupleSchema, ROOT_ALLOCATOR), - ArrowArray.wrap(importBatchFFIArrayPtr), - ArrowArray.wrap(importIdxFFIArrayPtr)) { (inputRoot, idxRoot, inputArray, idxArray) => + ArrowArray.wrap(importBatchFFIArrayPtr)) { (inputRoot, inputArray) => // import into params root Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, inputArray, inputRoot, dictionaryProvider) - val inputRows = ColumnarHelper.rootRowsArray(inputRoot) - - Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) - val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] - val inputIdxVector = fieldVectors(1).asInstanceOf[IntVector] - - for (i <- 0 until idxRoot.getRowCount) { - val rowIdx = rowIdxVector.get(i) - val row = rows(rowIdx) - val input = paramsToUnsafe(inputRows(inputIdxVector.get(i))) - rows(rowIdx) = aggEvaluator.update(row, input) + val inputRow = ColumnarHelper.rootRowReuseable(inputRoot) + + for (zippedIdx <- zippedIndices) { + val rowIdx = ((zippedIdx >> 32) & 0xffffffff).toInt + val updatingRowIdx = ((zippedIdx >> 0) & 0xffffffff).toInt + inputRow.rowId = updatingRowIdx + rows.update(rowIdx, row => aggEvaluator.update(row, inputRow)) } } } def merge( - rows: ArrayBuffer[InternalRow], - mergeRows: ArrayBuffer[InternalRow], - importIdxFFIArrayPtr: Long): Unit = { - Using.resources( - VectorSchemaRoot.create(indexTupleSchema, ROOT_ALLOCATOR), - ArrowArray.wrap(importIdxFFIArrayPtr)) { (idxRoot, idxArray) => - Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) - val fieldVectors = idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors(0).asInstanceOf[IntVector] - val mergeIdxVector = fieldVectors(1).asInstanceOf[IntVector] - - for (i <- 0 until idxRoot.getRowCount) { - val rowIdx = rowIdxVector.get(i) - val mergeIdx = mergeIdxVector.get(i) - val row = rows(rowIdx) - val mergeRow = mergeRows(mergeIdx) - rows(rowIdx) = aggEvaluator.merge(row, mergeRow) - } + rows: BufferRowsColumn[B], + mergeRows: BufferRowsColumn[B], + zippedIndices: Array[Long]): Unit = { + + for (zippedIdx <- zippedIndices) { + val rowIdx = ((zippedIdx >> 32) & 0xffffffff).toInt + val mergingRowIdx = ((zippedIdx >> 0) & 0xffffffff).toInt + rows.update(rowIdx, aggEvaluator.merge(_, mergeRows.row(mergingRowIdx))) } } - def eval( - rows: ArrayBuffer[InternalRow], - importIdxFFIArrayPtr: Long, - exportFFIArrayPtr: Long): Unit = { + def eval(rows: BufferRowsColumn[B], indices: Array[Int], exportFFIArrayPtr: Long): Unit = { Using.resources( - VectorSchemaRoot.create(indexSchema, ROOT_ALLOCATOR), VectorSchemaRoot.create(outputSchema, ROOT_ALLOCATOR), - ArrowArray.wrap(importIdxFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { (idxRoot, outputRoot, idxArray, exportArray) => - Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, idxArray, idxRoot, dictionaryProvider) - val fieldVectors = 
idxRoot.getFieldVectors.asScala - val rowIdxVector = fieldVectors.head.asInstanceOf[IntVector] - + ArrowArray.wrap(exportFFIArrayPtr)) { (outputRoot, exportArray) => // evaluate expression and write to output root val outputWriter = ArrowWriter.create(outputRoot) - for (i <- 0 until idxRoot.getRowCount) { - val row = rows(rowIdxVector.get(i)) - outputWriter.write(aggEvaluator.eval(row)) + for (i <- indices) { + outputWriter.write(aggEvaluator.eval(rows.row(i))) } outputWriter.finish() @@ -177,81 +144,74 @@ case class SparkUDAFWrapperContext(serialized: ByteBuffer) extends Logging { } } - def serializeRows( - rows: ArrayBuffer[InternalRow], - importFFIArrayPtr: Long, - exportFFIArrayPtr: Long): Unit = { - Using.resources( - VectorSchemaRoot.create(serializedRowSchema, ROOT_ALLOCATOR), - VectorSchemaRoot.create(indexSchema, ROOT_ALLOCATOR)) { (exportDataRoot, importIdxRoot) => - Using.resources(ArrowArray.wrap(importFFIArrayPtr), ArrowArray.wrap(exportFFIArrayPtr)) { - (importArray, exportArray) => - // import into params root - Data.importIntoVectorSchemaRoot( - ROOT_ALLOCATOR, - importArray, - importIdxRoot, - dictionaryProvider) - - // write serialized row into sequential raw bytes - val importIdxArray = importIdxRoot.getFieldVectors.get(0).asInstanceOf[IntVector] - val rowsIter = (0 until importIdxRoot.getRowCount).map(i => rows(importIdxArray.get(i))) - val serializedBytes = aggEvaluator.serializeRows(rowsIter) - - // export serialized data as a single row batch using root allocator - val outputWriter = ArrowWriter.create(exportDataRoot) - outputWriter.write(InternalRow(serializedBytes)) - outputWriter.finish() - Data.exportVectorSchemaRoot( - ROOT_ALLOCATOR, - exportDataRoot, - dictionaryProvider, - exportArray) - } - } - + def serializeRows(rows: BufferRowsColumn[B], indices: Array[Int]): Array[Byte] = { + aggEvaluator.serializeRows(indices.iterator.map(i => rows.row(i))) } - def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { - aggEvaluator.deserializeRows(dataBuffer) + def deserializeRows(dataBuffer: ByteBuffer): BufferRowsColumn[B] = { + val rows = aggEvaluator.createEmptyColumn() + rows.append(aggEvaluator.deserializeRows(dataBuffer): _*) + rows } - def memUsed(rows: ArrayBuffer[InternalRow]): Int = { - aggEvaluator.memUsed(rows) + def memUsed(rows: BufferRowsColumn[B]): Int = { + rows.memUsed } } -object SparkUDAFWrapperContext { - private val indexTupleSchema = { - val schema = StructType(Seq(StructField("", IntegerType), StructField("", IntegerType))) - ArrowUtils.toArrowSchema(schema) +abstract class BufferRowsColumn[B] { + protected var rows: ArrayBuffer[B] = ArrayBuffer[B]() + protected var rowsMemUsed: Int = 0 + + def length: Int = rows.length + def memUsed: Int = rowsMemUsed + + def resize(len: Int, initializer: => B): Unit = { + if (rows.length < len) { + for (_ <- rows.length until len) { + val newRow = initializer + rowsMemUsed += getRowMemUsage(newRow) + rows.append(newRow) + } + } else { + for (i <- len until rows.length) { + rowsMemUsed -= getRowMemUsage(rows(i)) + } + rows.trimEnd(rows.length - len) + } } - private val indexSchema = { - val schema = StructType(Seq(StructField("", IntegerType, nullable = false))) - ArrowUtils.toArrowSchema(schema) + def row(i: Int): B = rows(i) + + def append(appendedRows: B*): Unit = { + rowsMemUsed += appendedRows.map(getRowMemUsage).sum + rows.append(appendedRows: _*) } - private val serializedRowSchema = { - val schema = StructType(Seq(StructField("", BinaryType, nullable = false))) - 
ArrowUtils.toArrowSchema(schema) + def update(i: Int, updater: B => B): Unit = { + rowsMemUsed -= getRowMemUsage(rows(i)) + rows(i) = updater(rows(i)) + rowsMemUsed += getRowMemUsage(rows(i)) } + + def getRowMemUsage(row: B): Int } -trait AggregateEvaluator extends Logging { - def initialize(): InternalRow - def update(mutableAggBuffer: InternalRow, row: InternalRow): InternalRow - def merge(row1: InternalRow, row2: InternalRow): InternalRow - def eval(row: InternalRow): InternalRow - def serializeRows(rows: Seq[InternalRow]): Array[Byte] - def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] - def memUsed(rows: Seq[InternalRow]): Int +trait AggregateEvaluator[B] extends Logging { + def createEmptyColumn(): BufferRowsColumn[B] + def initialize(): B + def update(mutableAggBuffer: B, row: InternalRow): B + def merge(row1: B, row2: B): B + def eval(row: B): InternalRow + def serializeRows(rows: Iterator[B]): Array[Byte] + def deserializeRows(dataBuffer: ByteBuffer): Seq[B] } class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attribute]) - extends AggregateEvaluator { + extends AggregateEvaluator[UnsafeRow] { private val initializer = UnsafeProjection.create(agg.initialValues) + private val initializedRow = initializer(InternalRow.empty) private val updater = UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes) @@ -263,23 +223,33 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri private val evaluator = UnsafeProjection.create(agg.evaluateExpression :: Nil, agg.aggBufferAttributes) - override def initialize(): InternalRow = { - initializer.apply(InternalRow.empty) + private val joiner = new JoinedRow + + override def createEmptyColumn(): BufferRowsColumn[UnsafeRow] = { + new BufferRowsColumn[UnsafeRow]() { + override def getRowMemUsage(row: UnsafeRow): Int = { + row.getSizeInBytes + } + } + } + + override def initialize(): UnsafeRow = { + initializedRow.copy() } - override def update(mutableAggBuffer: InternalRow, row: InternalRow): InternalRow = { - updater(new JoinedRow(mutableAggBuffer, row)).copy() + override def update(mutableAggBuffer: UnsafeRow, row: InternalRow): UnsafeRow = { + updater(joiner(mutableAggBuffer, row)).copy() } - override def merge(row1: InternalRow, row2: InternalRow): InternalRow = { - merger(new JoinedRow(row1, row2)).copy() + override def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow = { + merger(joiner(row1, row2)) } - override def eval(row: InternalRow): InternalRow = { + override def eval(row: UnsafeRow): UnsafeRow = { evaluator(row) } - override def serializeRows(rows: Seq[InternalRow]): Array[Byte] = { + override def serializeRows(rows: Iterator[UnsafeRow]): Array[Byte] = { val numFields = agg.aggBufferSchema.length val outputDataStream = new ByteArrayOutputStream() val serializer = new UnsafeRowSerializer(numFields).newInstance() @@ -292,11 +262,11 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri outputDataStream.toByteArray } - override def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { + override def deserializeRows(dataBuffer: ByteBuffer): Seq[UnsafeRow] = { val numFields = agg.aggBufferSchema.length val deserializer = new UnsafeRowSerializer(numFields).newInstance() val inputDataStream = new ByteBufferInputStream(dataBuffer) - val rows = new ArrayBuffer[InternalRow]() + val rows = new ArrayBuffer[UnsafeRow]() Using.resource(deserializer.deserializeStream(inputDataStream)) { deser => 
for (row <- deser.asKeyValueIterator.map(_._2.asInstanceOf[UnsafeRow].copy())) { @@ -305,72 +275,61 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri } rows } - - override def memUsed(rows: Seq[InternalRow]): Int = { - var mem = 0 - for (row <- rows) { - mem = mem + row.asInstanceOf[UnsafeRow].getSizeInBytes - } - mem - } } -class TypedImperativeEvaluator[T]( - agg: TypedImperativeAggregate[T]) - extends AggregateEvaluator { +class TypedImperativeEvaluator[B](agg: TypedImperativeAggregate[B]) + extends AggregateEvaluator[B] { + private val evalRow = InternalRow(0) - private val bufferSchema = agg.aggBufferAttributes.map(_.dataType) - private val anyObjectType = ObjectType(classOf[AnyRef]) + override def createEmptyColumn(): BufferRowsColumn[B] = { + new BufferRowsColumn[B]() { + override def getRowMemUsage(row: B): Int = { + 64 // estimated size of object + } - private def getBufferObject(buffer: InternalRow): T = { - buffer.get(0, anyObjectType).asInstanceOf[T] + override def update(i: Int, updater: B => B): Unit = { + rows(i) = updater(rows(i)) + } + } } - override def initialize(): InternalRow = { - val row = InternalRow(bufferSchema) - agg.initialize(row) - row + + override def initialize(): B = { + agg.createAggregationBuffer() } - override def update(buffer: InternalRow, row: InternalRow): InternalRow = { + override def update(buffer: B, row: InternalRow): B = { agg.update(buffer, row) - buffer } - override def merge(row1: InternalRow, row2: InternalRow): InternalRow = { - val Object1 = getBufferObject(row1) - val Object2 = getBufferObject(row2) - row1.update(0, agg.merge(Object1, Object2)) - row1 + override def merge(row1: B, row2: B): B = { + agg.merge(row1, row2) } - override def eval(row: InternalRow): InternalRow = { - InternalRow(agg.eval(row)) + override def eval(row: B): InternalRow = { + evalRow.update(0, agg.eval(row)) + evalRow } - override def serializeRows(rows: Seq[InternalRow]): Array[Byte] = { + override def serializeRows(rows: Iterator[B]): Array[Byte] = { val outputStream = new ByteArrayOutputStream() val dataOut = new DataOutputStream(outputStream) for (row <- rows) { - val byteBuffer = agg.serialize(row.get(0, anyObjectType).asInstanceOf[T]) + val byteBuffer = agg.serialize(row) dataOut.writeInt(byteBuffer.length) - outputStream.write(byteBuffer) + dataOut.write(byteBuffer) } outputStream.toByteArray } - override def deserializeRows(dataBuffer: ByteBuffer): ArrayBuffer[InternalRow] = { - val rows = ArrayBuffer[InternalRow]() + override def deserializeRows(dataBuffer: ByteBuffer): Seq[B] = { + val rows = ArrayBuffer[B]() while (dataBuffer.hasRemaining) { val length = dataBuffer.getInt() val byteBuffer = new Array[Byte](length) dataBuffer.get(byteBuffer) - val row = InternalRow(agg.deserialize(byteBuffer)) + val row = agg.deserialize(byteBuffer) rows.append(row) } rows } - - override def memUsed(rows: Seq[InternalRow]): Int = { - rows.length * 192 - } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala index c5b9f53d0..fea07de42 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/columnar/ColumnarHelper.scala @@ -18,30 +18,20 @@ package org.apache.spark.sql.execution.blaze.columnar import scala.collection.JavaConverters._ import 
org.apache.arrow.vector.VectorSchemaRoot -import org.apache.spark.sql.catalyst.InternalRow object ColumnarHelper { - def rootRowsIter(root: VectorSchemaRoot): Iterator[InternalRow] = { - val vectors = root.getFieldVectors.asScala.toArray + def rootRowsIter(root: VectorSchemaRoot): Iterator[BlazeColumnarBatchRow] = { + val row = rootRowReuseable(root) val numRows = root.getRowCount - val row = new BlazeColumnarBatchRow( - vectors.map(new BlazeArrowColumnVector(_).asInstanceOf[BlazeColumnVector])) - Range(0, numRows).iterator.map { rowId => row.rowId = rowId - row.asInstanceOf[InternalRow] + row } } - def rootRowsArray(root: VectorSchemaRoot): Array[InternalRow] = { + def rootRowReuseable(root: VectorSchemaRoot): BlazeColumnarBatchRow = { val vectors = root.getFieldVectors.asScala.toArray - val numRows = root.getRowCount - val row = new BlazeColumnarBatchRow( + new BlazeColumnarBatchRow( vectors.map(new BlazeArrowColumnVector(_).asInstanceOf[BlazeColumnVector])) - (0 until numRows).map { rowId => - row.rowId = rowId - row.asInstanceOf[InternalRow] - }.toArray } - } From 3705484dee83b245c47e2589b2e1d7bb2dd46cd3 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Tue, 25 Feb 2025 16:19:50 +0800 Subject: [PATCH 16/17] update NativeConverters and format --- native-engine/blaze-serde/src/from_proto.rs | 16 +++++++-------- .../src/agg/spark_udaf_wrapper.rs | 6 +++--- .../spark/sql/blaze/NativeConverters.scala | 20 +++++++------------ 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs index c21c00265..1ed50ab9e 100644 --- a/native-engine/blaze-serde/src/from_proto.rs +++ b/native-engine/blaze-serde/src/from_proto.rs @@ -849,16 +849,16 @@ fn try_parse_physical_expr( // cast list values to expr type e if downcast_any!(e, Literal).is_ok() && e.data_type(input_schema)? != dt => - { - match TryCastExpr::new(e, dt.clone()).evaluate( - &RecordBatch::new_empty(input_schema.clone()), - )? { - ColumnarValue::Scalar(scalar) => { - Arc::new(Literal::new(scalar)) - } - ColumnarValue::Array(_) => unreachable!(), + { + match TryCastExpr::new(e, dt.clone()).evaluate( + &RecordBatch::new_empty(input_schema.clone()), + )? 
{ + ColumnarValue::Scalar(scalar) => { + Arc::new(Literal::new(scalar)) } + ColumnarValue::Array(_) => unreachable!(), } + } other => other, } }) diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index d2eb46aeb..77c988d0f 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -117,7 +117,7 @@ impl Agg for SparkUDAFWrapper { let rows = jni_call!(SparkUDAFWrapperContext(jcontext.as_obj()).initialize( num_rows as i32, )-> JObject) - .unwrap(); + .unwrap(); let jcontext = self.jcontext().unwrap(); let obj = jni_new_global_ref!(rows.as_obj()).unwrap(); @@ -250,7 +250,7 @@ impl AccColumn for AccUnsafeRowsColumn { self.obj.as_obj(), len as i32, )-> ()) - .unwrap(); + .unwrap(); self.num_rows = len; } @@ -265,7 +265,7 @@ impl AccColumn for AccUnsafeRowsColumn { SparkUDAFWrapperContext(self.jcontext.as_obj()).memUsed( self.obj.as_obj()) -> i32) - .unwrap() as usize + .unwrap() as usize } fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec<u8>]) -> Result<()> { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 0d5e8d5fa..d161e4787 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -1149,12 +1149,13 @@ object NativeConverters extends Logging { defaultValue = true) => aggBuilder.setAggFunction(pb.AggFunction.BRICKHOUSE_COMBINE_UNIQUE) aggBuilder.addChildren(convertExpr(udaf.children.head)) - // other udaf aggFunction - case udaf - if classOf[DeclarativeAggregate].isAssignableFrom(e.aggregateFunction.getClass) - || (udaf.getClass.getName != "org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate" && - classOf[TypedImperativeAggregate[_]].isAssignableFrom( - e.aggregateFunction.getClass)) => + + case udaf => + Shims.get.convertMoreAggregateExpr(e) match { + case Some(converted) => return converted + case _ => + } + // other udaf aggFunction aggBuilder.setAggFunction(pb.AggFunction.UDAF) val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]() @@ -1195,13 +1196,6 @@ object NativeConverters extends Logging { .setReturnType(convertDataType(bound.dataType)) .setReturnNullable(bound.nullable)) aggBuilder.addAllChildren(convertedChildren.keys.asJava) - - case _ => - Shims.get.convertMoreAggregateExpr(e) match { - case Some(converted) => return converted - case _ => - } - throw new NotImplementedError(s"unsupported aggregate expression: (${e.getClass}) $e") } pb.PhysicalExprNode .newBuilder() From c5a1ea83b2a018bed59f4132f1386614a1c80fd6 Mon Sep 17 00:00:00 2001 From: guoying06 Date: Tue, 25 Feb 2025 17:47:00 +0800 Subject: [PATCH 17/17] format name and schema --- native-engine/blaze-serde/proto/blaze.proto | 5 +-- native-engine/blaze-serde/src/from_proto.rs | 2 ++ .../datafusion-ext-plans/src/agg/agg.rs | 5 +-- .../datafusion-ext-plans/src/agg/agg_ctx.rs | 8 +---- .../datafusion-ext-plans/src/agg/avg.rs | 19 +++------- .../src/agg/bloom_filter.rs | 2 -- .../src/agg/brickhouse/collect.rs | 3 -- .../src/agg/brickhouse/combine_unique.rs | 3 -- .../datafusion-ext-plans/src/agg/collect.rs | 1 - .../datafusion-ext-plans/src/agg/count.rs | 1 - .../datafusion-ext-plans/src/agg/first.rs | 1 - 
.../src/agg/first_ignores_null.rs | 1 - .../datafusion-ext-plans/src/agg/maxmin.rs | 1 - .../src/agg/spark_udaf_wrapper.rs | 35 +++++++------------ .../datafusion-ext-plans/src/agg/sum.rs | 1 - .../src/window/processors/agg_processor.rs | 1 - .../spark/sql/blaze/NativeConverters.scala | 3 +- .../sql/blaze/SparkUDAFWrapperContext.scala | 16 ++++----- .../sql/blaze/SparkUDFWrapperContext.scala | 11 +++--- .../sql/blaze/SparkUDTFWrapperContext.scala | 11 +++--- 20 files changed, 47 insertions(+), 83 deletions(-) diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto index 60ed70b68..76b2a1bdd 100644 --- a/native-engine/blaze-serde/proto/blaze.proto +++ b/native-engine/blaze-serde/proto/blaze.proto @@ -152,8 +152,9 @@ message PhysicalAggExprNode { message AggUdaf { bytes serialized = 1; - ArrowType return_type = 2; - bool return_nullable = 3; + Schema input_schema = 2; + ArrowType return_type = 3; + bool return_nullable = 4; } message PhysicalIsNull { diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs index 1ed50ab9e..db1c55749 100644 --- a/native-engine/blaze-serde/src/from_proto.rs +++ b/native-engine/blaze-serde/src/from_proto.rs @@ -444,8 +444,10 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode { AggFunction::Udaf => { let udaf = agg_node.udaf.as_ref().unwrap(); let serialized = udaf.serialized.clone(); + let input_schema = Arc::new(convert_required!(udaf.input_schema)?); create_udaf_agg( serialized, + input_schema, convert_required!(udaf.return_type)?, agg_children_exprs, )? diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index f653dc132..1910e16b1 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -46,7 +46,6 @@ pub trait Agg: Send + Sync + Debug { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, ) -> Result<()>; fn partial_merge( @@ -300,18 +299,20 @@ pub fn create_agg( )?) 
} AggFunction::Udaf => { - unreachable!("UDAF should be handled in create_declarative_agg") + unreachable!("UDAF should be handled in create_udaf_agg") } }) } pub fn create_udaf_agg( serialized: Vec<u8>, + input_schema: SchemaRef, return_type: DataType, children: Vec<PhysicalExprRef>, ) -> Result<Arc<dyn Agg>> { Ok(Arc::new(SparkUDAFWrapper::try_new( serialized, + input_schema, return_type, children, )?)) } diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs b/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs index 8d39dd1db..4aa7f6fc0 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs @@ -333,13 +333,7 @@ impl AggContext { if self.need_partial_update { for (agg_idx, agg) in &self.need_partial_update_aggs { let acc_col = &mut acc_table.cols_mut()[*agg_idx]; - agg.partial_update( - acc_col, - acc_idx, - &input_arrays[*agg_idx], - input_idx, - batch_schema.clone(), - )?; + agg.partial_update(acc_col, acc_idx, &input_arrays[*agg_idx], input_idx)?; } } Ok(()) diff --git a/native-engine/datafusion-ext-plans/src/agg/avg.rs b/native-engine/datafusion-ext-plans/src/agg/avg.rs index 58fe571f8..7e2496cbe 100644 --- a/native-engine/datafusion-ext-plans/src/agg/avg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/avg.rs @@ -110,23 +110,12 @@ impl Agg for AggAvg { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccAvgColumn).unwrap(); - self.agg_sum.partial_update( - &mut accs.sum, - acc_idx, - partial_args, - partial_arg_idx, - batch_schema.clone(), - )?; - self.agg_count.partial_update( - &mut accs.count, - acc_idx, - partial_args, - partial_arg_idx, - batch_schema.clone(), - )?; + self.agg_sum + .partial_update(&mut accs.sum, acc_idx, partial_args, partial_arg_idx)?; + self.agg_count + .partial_update(&mut accs.count, acc_idx, partial_args, partial_arg_idx)?; Ok(()) } diff --git a/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs b/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs index 04fb7c5e1..f4bd34580 100644 --- a/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs +++ b/native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs @@ -23,7 +23,6 @@ use arrow::{ array::{ArrayRef, AsArray, BinaryBuilder}, datatypes::{DataType, Int64Type}, }; -use arrow_schema::SchemaRef; use byteorder::{ReadBytesExt, WriteBytesExt}; use datafusion::{common::Result, physical_expr::PhysicalExpr}; use datafusion_ext_commons::{ @@ -114,7 +113,6 @@ impl Agg for AggBloomFilter { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccBloomFilterColumn).unwrap(); let bloom_filter = match acc_idx { diff --git a/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs b/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs index 5216df746..c74cf5dd7 100644 --- a/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs +++ b/native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs @@ -22,7 +22,6 @@ use arrow::{ array::{Array, ArrayRef, AsArray}, datatypes::DataType, }; -use arrow_schema::SchemaRef; use datafusion::{common::Result, physical_expr::PhysicalExpr}; use crate::{ @@ -87,7 +86,6 @@ impl Agg for AggCollect { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, ) -> Result<()> { let list = 
partial_args[0].as_list::<i32>(); @@ -101,7 +99,6 @@ impl Agg for AggCollect { IdxSelection::Single(acc_idx), &[values], IdxSelection::Range(0, values_len), - batch_schema.clone(), )?; } } diff --git a/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs b/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs index 900f2b43c..1b8b8246f 100644 --- a/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs +++ b/native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs @@ -22,7 +22,6 @@ use arrow::{ array::{Array, ArrayRef, AsArray}, datatypes::DataType, }; -use arrow_schema::SchemaRef; use datafusion::{common::Result, physical_expr::PhysicalExpr}; use crate::{ @@ -87,7 +86,6 @@ impl Agg for AggCombineUnique { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, ) -> Result<()> { let list = partial_args[0].as_list::<i32>(); @@ -101,7 +99,6 @@ impl Agg for AggCombineUnique { IdxSelection::Single(acc_idx), &[values], IdxSelection::Range(0, values_len), - batch_schema.clone(), )?; } } diff --git a/native-engine/datafusion-ext-plans/src/agg/collect.rs b/native-engine/datafusion-ext-plans/src/agg/collect.rs index c498a3e30..de2dca821 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect.rs @@ -114,7 +114,6 @@ impl Agg for AggGenericCollect { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut C).unwrap(); idx_for_zipped! { diff --git a/native-engine/datafusion-ext-plans/src/agg/count.rs b/native-engine/datafusion-ext-plans/src/agg/count.rs index 80f6a30ca..e90ff2fcd 100644 --- a/native-engine/datafusion-ext-plans/src/agg/count.rs +++ b/native-engine/datafusion-ext-plans/src/agg/count.rs @@ -92,7 +92,6 @@ impl Agg for AggCount { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccCountColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/first.rs b/native-engine/datafusion-ext-plans/src/agg/first.rs index 5c0496a52..a7083b53b 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first.rs @@ -90,7 +90,6 @@ impl Agg for AggFirst { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - _batch_schema: SchemaRef, ) -> Result<()> { let partial_arg = &partial_args[0]; let accs = downcast_any!(accs, mut AccFirstColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs index c6a5e6be0..ca898a5b1 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs @@ -86,7 +86,6 @@ impl Agg for AggFirstIgnoresNull { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - _batch_schema: SchemaRef, ) -> Result<()> { let partial_arg = &partial_args[0]; let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs index 20e2889c3..24800b1c4 100644 --- a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs +++ 
b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs @@ -93,7 +93,6 @@ impl Agg for AggMaxMin { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); let old_heap_mem_used = accs.items_heap_mem_used(acc_idx); diff --git a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs index 77c988d0f..54222bd49 100644 --- a/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs +++ b/native-engine/datafusion-ext-plans/src/agg/spark_udaf_wrapper.rs @@ -51,13 +51,14 @@ pub struct SparkUDAFWrapper { pub return_type: DataType, child: Vec<PhysicalExprRef>, import_schema: SchemaRef, - params_schema: OnceCell<SchemaRef>, + params_schema: SchemaRef, jcontext: OnceCell<GlobalRef>, } impl SparkUDAFWrapper { pub fn try_new( serialized: Vec<u8>, + input_schema: SchemaRef, return_type: DataType, child: Vec<PhysicalExprRef>, ) -> Result<Self> { Ok(Self { serialized, return_type: return_type.clone(), child, import_schema: Arc::new(Schema::new(vec![Field::new("", return_type, true)])), - params_schema: OnceCell::new(), + params_schema: input_schema, jcontext: OnceCell::new(), }) } @@ -121,7 +122,7 @@ impl Agg for SparkUDAFWrapper { let jcontext = self.jcontext().unwrap(); let obj = jni_new_global_ref!(rows.as_obj()).unwrap(); - Box::new(AccUnsafeRowsColumn { + Box::new(AccUDAFBufferRowsColumn { obj, jcontext, num_rows, }) @@ -131,6 +132,7 @@ fn with_new_exprs(&self, _exprs: Vec<PhysicalExprRef>) -> Result<Arc<dyn Agg>> { Ok(Arc::new(Self::try_new( self.serialized.clone(), + self.params_schema.clone(), self.return_type.clone(), self.child.clone(), )?)) } @@ -142,24 +144,11 @@ acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - batch_schema: SchemaRef, ) -> Result<()> { - let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); + let accs = downcast_any!(accs, mut AccUDAFBufferRowsColumn).unwrap(); let params = partial_args.to_vec(); - let params_schema = self - .params_schema - .get_or_try_init(|| -> Result<SchemaRef> { - let mut param_fields = Vec::with_capacity(self.child.len()); - for child in &self.child { - param_fields.push(Field::new( - "", - child.data_type(batch_schema.as_ref())?, - child.nullable(batch_schema.as_ref())?, - )); - } - Ok(Arc::new(Schema::new(param_fields))) - })?; + let params_schema = self.params_schema.clone(); let params_batch = RecordBatch::try_new_with_options( params_schema.clone(), params.clone(), @@ -192,8 +181,8 @@ merging_accs: &mut AccColumnRef, merging_acc_idx: IdxSelection<'_>, ) -> Result<()> { - let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); - let merging_accs = downcast_any!(merging_accs, mut AccUnsafeRowsColumn).unwrap(); + let accs = downcast_any!(accs, mut AccUDAFBufferRowsColumn).unwrap(); + let merging_accs = downcast_any!(merging_accs, mut AccUDAFBufferRowsColumn).unwrap(); // create zipped indices let max_len = std::cmp::max(acc_idx.len(), merging_acc_idx.len()); @@ -213,7 +202,7 @@ } fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> { - let accs = downcast_any!(accs, mut AccUnsafeRowsColumn).unwrap(); + let accs = downcast_any!(accs, mut AccUDAFBufferRowsColumn).unwrap(); let acc_indices = acc_idx.to_int32_vec(); let acc_idx_array = jni_new_prim_array!(int, &acc_indices[..])?; @@ -234,13 +223,13 @@ } } -struct AccUnsafeRowsColumn { +struct AccUDAFBufferRowsColumn 
{ obj: GlobalRef, jcontext: GlobalRef, num_rows: usize, } -impl AccColumn for AccUnsafeRowsColumn { +impl AccColumn for AccUDAFBufferRowsColumn { fn as_any_mut(&mut self) -> &mut dyn Any { self } diff --git a/native-engine/datafusion-ext-plans/src/agg/sum.rs b/native-engine/datafusion-ext-plans/src/agg/sum.rs index 5e073553f..f0b0f4f32 100644 --- a/native-engine/datafusion-ext-plans/src/agg/sum.rs +++ b/native-engine/datafusion-ext-plans/src/agg/sum.rs @@ -91,7 +91,6 @@ impl Agg for AggSum { acc_idx: IdxSelection<'_>, partial_args: &[ArrayRef], partial_arg_idx: IdxSelection<'_>, - _batch_schema: SchemaRef, ) -> Result<()> { let accs = downcast_any!(accs, mut AccGenericColumn).unwrap(); diff --git a/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs index af2f75943..89d2b56ba 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs @@ -78,7 +78,6 @@ impl WindowFunctionProcessor for AggProcessor { IdxSelection::Single(0), &children_cols, IdxSelection::Single(row_idx), - batch.schema(), )?; output.push( self.agg diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index d161e4787..192659601 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -28,7 +28,7 @@ import org.apache.spark.SparkEnv import org.blaze.{protobuf => pb} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Abs, Acos, Add, Alias, And, Asin, Atan, Attribute, AttributeReference, BitwiseAnd, BitwiseOr, BoundReference, CaseWhen, Cast, Ceil, CheckOverflow, Coalesce, Concat, ConcatWs, Contains, Cos, CreateArray, CreateNamedStruct, DayOfMonth, Divide, EndsWith, EqualTo, Exp, Expression, Floor, GetArrayItem, GetJsonObject, GetMapValue, GetStructField, GreaterThan, GreaterThanOrEqual, If, In, InSet, IsNotNull, IsNull, LeafExpression, Length, LessThan, LessThanOrEqual, Like, Literal, Log, Log10, Log2, Lower, MakeDecimal, Md5, Month, Multiply, Murmur3Hash, Not, NullIf, OctetLength, Or, Remainder, Sha2, ShiftLeft, ShiftRight, Signum, Sin, Sqrt, StartsWith, StringRepeat, StringSpace, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, Tan, TruncDate, Unevaluable, UnscaledValue, Upper, XxHash64, Year} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, ImperativeAggregate, Max, Min, Sum, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, ImperativeAggregate, Max, Min, Sum} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode import org.apache.spark.sql.catalyst.plans.FullOuter @@ -1193,6 +1193,7 @@ object NativeConverters extends Logging { pb.AggUdaf .newBuilder() .setSerialized(ByteString.copyFrom(serialized)) + .setInputSchema(NativeConverters.convertSchema(paramsSchema)) .setReturnType(convertDataType(bound.dataType)) .setReturnNullable(bound.nullable)) 
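      // With this patch the AggUdaf message carries the UDAF's parameter
      // schema (paramsSchema) alongside the serialized function, so the
      // native SparkUDAFWrapper receives params_schema up front instead of
      // deriving it lazily from each batch's schema (the OnceCell path
      // removed in spark_udaf_wrapper.rs above).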
aggBuilder.addAllChildren(convertedChildren.keys.asJava) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala index d6578bf46..76e056ba9 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala @@ -97,9 +97,9 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging { } def update( - rows: BufferRowsColumn[B], - importBatchFFIArrayPtr: Long, - zippedIndices: Array[Long]): Unit = { + rows: BufferRowsColumn[B], + importBatchFFIArrayPtr: Long, + zippedIndices: Array[Long]): Unit = { Using.resources( VectorSchemaRoot.create(inputSchema, ROOT_ALLOCATOR), ArrowArray.wrap(importBatchFFIArrayPtr)) { (inputRoot, inputArray) => @@ -117,9 +117,9 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging { } def merge( - rows: BufferRowsColumn[B], - mergeRows: BufferRowsColumn[B], - zippedIndices: Array[Long]): Unit = { + rows: BufferRowsColumn[B], + mergeRows: BufferRowsColumn[B], + zippedIndices: Array[Long]): Unit = { for (zippedIdx <- zippedIndices) { val rowIdx = ((zippedIdx >> 32) & 0xffffffff).toInt @@ -208,7 +208,7 @@ trait AggregateEvaluator[B] extends Logging { } class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attribute]) - extends AggregateEvaluator[UnsafeRow] { + extends AggregateEvaluator[UnsafeRow] { private val initializer = UnsafeProjection.create(agg.initialValues) private val initializedRow = initializer(InternalRow.empty) @@ -278,7 +278,7 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri } class TypedImperativeEvaluator[B](agg: TypedImperativeAggregate[B]) - extends AggregateEvaluator[B] { + extends AggregateEvaluator[B] { private val evalRow = InternalRow(0) override def createEmptyColumn(): BufferRowsColumn[B] = { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala index 60ecfd103..d078528fb 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala @@ -36,11 +36,12 @@ import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType case class SparkUDFWrapperContext(serialized: ByteBuffer) extends Logging { - private val (expr, javaParamsSchema) = NativeConverters.deserializeExpression[Expression, StructType]({ - val bytes = new Array[Byte](serialized.remaining()) - serialized.get(bytes) - bytes - }) + private val (expr, javaParamsSchema) = + NativeConverters.deserializeExpression[Expression, StructType]({ + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + bytes + }) // initialize all nondeterministic children exprs expr.foreach { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala index dc387ddef..0946d3949 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala @@ -37,11 +37,12 @@ import 
org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType case class SparkUDTFWrapperContext(serialized: ByteBuffer) extends Logging { - private val (expr, javaParamsSchema) = NativeConverters.deserializeExpression[Generator, StructType]({ - val bytes = new Array[Byte](serialized.remaining()) - serialized.get(bytes) - bytes - }) + private val (expr, javaParamsSchema) = + NativeConverters.deserializeExpression[Generator, StructType]({ + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + bytes + }) // initialize all nondeterministic children exprs expr.foreach {
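A note on the index encoding crossing the JNI boundary above: update() and merge() receive one Long per (accumulator row, input/merging row) pair, with the accumulator row index in the high 32 bits and the other index in the low 32 bits, as the `((zippedIdx >> 32) & 0xffffffff).toInt` decoding in SparkUDAFWrapperContext.merge shows; the native side constructs these in partial_merge ("create zipped indices"). The patch does not show the encoder, so the following Scala sketch is illustrative only and the name ZippedIdx is a hypothetical helper, not part of the patch:

object ZippedIdx {
  // Pack an accumulator row index and an input/merging row index into one
  // Long: accumulator index in the high 32 bits, the other in the low 32.
  def zip(accIdx: Int, otherIdx: Int): Long =
    (accIdx.toLong << 32) | (otherIdx.toLong & 0xffffffffL)

  // Inverses, mirroring the decoding in SparkUDAFWrapperContext.merge.
  def accIdx(zipped: Long): Int = ((zipped >> 32) & 0xffffffff).toInt
  def otherIdx(zipped: Long): Int = (zipped & 0xffffffff).toInt
}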