Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions native-engine/blaze-serde/proto/blaze.proto
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,12 @@ message PhysicalAggExprNode {
AggFunction agg_function = 1;
AggUdaf udaf = 2;
repeated PhysicalExprNode children = 3;
ArrowType return_type = 4;
}

message AggUdaf {
bytes serialized = 1;
Schema input_schema = 2;
ArrowType return_type = 3;
bool return_nullable = 4;
}

message PhysicalIsNull {
Expand Down Expand Up @@ -535,6 +534,7 @@ message WindowExecNode {

message WindowExprNode {
Field field = 1;
ArrowType return_type = 1000;
WindowFunctionType func_type = 2;
WindowFunction window_func = 3;
AggFunction agg_func = 4;
Expand Down
17 changes: 11 additions & 6 deletions native-engine/blaze-serde/src/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,20 +459,19 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.iter()
.map(|expr| try_parse_physical_expr(expr, &input_schema))
.collect::<Result<Vec<_>, _>>()?;
let return_type = convert_required!(agg_node.return_type)?;

let agg = match AggFunction::from(agg_function) {
AggFunction::Udaf => {
let udaf = agg_node.udaf.as_ref().unwrap();
let serialized = udaf.serialized.clone();
create_udaf_agg(
serialized,
convert_required!(udaf.return_type)?,
agg_children_exprs,
)?
create_udaf_agg(serialized, return_type, agg_children_exprs)?
}
_ => create_agg(
AggFunction::from(agg_function),
&agg_children_exprs,
&input_schema,
return_type,
)?,
};

Expand Down Expand Up @@ -548,6 +547,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.iter()
.map(|expr| try_parse_physical_expr(expr, &input.schema()))
.collect::<Result<Vec<_>, Self::Error>>()?;
let return_type = convert_required!(w.return_type)?;

let window_func = match w.func_type() {
protobuf::WindowFunctionType::Window => match w.window_func() {
Expand Down Expand Up @@ -595,7 +595,12 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
}
},
};
Ok::<_, Self::Error>(WindowExpr::new(window_func, children, field))
Ok::<_, Self::Error>(WindowExpr::new(
window_func,
children,
field,
return_type,
))
})
.collect::<Result<Vec<_>, _>>()?;

Expand Down
71 changes: 29 additions & 42 deletions native-engine/datafusion-ext-plans/src/agg/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,18 @@ use datafusion_ext_commons::df_execution_err;
use datafusion_ext_exprs::cast::TryCastExpr;

use crate::agg::{
acc::AccColumnRef, avg, bloom_filter, brickhouse, collect, first, first_ignores_null, maxmin,
spark_udaf_wrapper::SparkUDAFWrapper, sum, AggFunction,
acc::AccColumnRef,
avg::AggAvg,
bloom_filter::AggBloomFilter,
brickhouse,
collect::{AggCollectList, AggCollectSet},
count::AggCount,
first::AggFirst,
first_ignores_null::AggFirstIgnoresNull,
maxmin::{AggMax, AggMin},
spark_udaf_wrapper::SparkUDAFWrapper,
sum::AggSum,
AggFunction,
};

pub trait Agg: Send + Sync + Debug {
Expand Down Expand Up @@ -161,12 +171,8 @@ pub fn create_agg(
agg_function: AggFunction,
children: &[Arc<dyn PhysicalExpr>],
input_schema: &SchemaRef,
return_type: DataType,
) -> Result<Arc<dyn Agg>> {
use arrow::datatypes::DataType;
use datafusion::logical_expr::type_coercion::aggregates::*;

use crate::agg::count;

Ok(match agg_function {
AggFunction::Count => {
let return_type = DataType::Int64;
Expand All @@ -178,48 +184,31 @@ pub fn create_agg(
})
.cloned()
.collect::<Vec<_>>();
Arc::new(count::AggCount::try_new(children, return_type)?)
}
AggFunction::Sum => {
let arg_type = children[0].data_type(input_schema)?;
let return_type = match arg_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
DataType::Int64
}
DataType::Float32 | DataType::Float64 => DataType::Float64,
other => sum_return_type(&other)?,
};
Arc::new(sum::AggSum::try_new(
Arc::new(TryCastExpr::new(children[0].clone(), return_type.clone())),
return_type,
)?)
}
AggFunction::Avg => {
let arg_type = children[0].data_type(input_schema)?;
let return_type = avg_return_type("avg", &arg_type)?;
Arc::new(avg::AggAvg::try_new(
Arc::new(TryCastExpr::new(children[0].clone(), return_type.clone())),
return_type,
)?)
Arc::new(AggCount::try_new(children, return_type)?)
}
AggFunction::Sum => Arc::new(AggSum::try_new(
Arc::new(TryCastExpr::new(children[0].clone(), return_type.clone())),
return_type,
)?),
AggFunction::Avg => Arc::new(AggAvg::try_new(
Arc::new(TryCastExpr::new(children[0].clone(), return_type.clone())),
return_type,
)?),
AggFunction::Max => {
let dt = children[0].data_type(input_schema)?;
Arc::new(maxmin::AggMax::try_new(children[0].clone(), dt)?)
Arc::new(AggMax::try_new(children[0].clone(), dt)?)
}
AggFunction::Min => {
let dt = children[0].data_type(input_schema)?;
Arc::new(maxmin::AggMin::try_new(children[0].clone(), dt)?)
Arc::new(AggMin::try_new(children[0].clone(), dt)?)
}
AggFunction::First => {
let dt = children[0].data_type(input_schema)?;
Arc::new(first::AggFirst::try_new(children[0].clone(), dt)?)
Arc::new(AggFirst::try_new(children[0].clone(), dt)?)
}
AggFunction::FirstIgnoresNull => {
let dt = children[0].data_type(input_schema)?;
Arc::new(first_ignores_null::AggFirstIgnoresNull::try_new(
children[0].clone(),
dt,
)?)
Arc::new(AggFirstIgnoresNull::try_new(children[0].clone(), dt)?)
}
AggFunction::BloomFilter => {
let dt = children[0].data_type(input_schema)?;
Expand All @@ -234,7 +223,7 @@ pub fn create_agg(
.into_array(1)?
.as_primitive::<Int64Type>()
.value(0);
Arc::new(bloom_filter::AggBloomFilter::new(
Arc::new(AggBloomFilter::new(
children[0].clone(),
dt,
estimated_num_items as usize,
Expand All @@ -243,17 +232,15 @@ pub fn create_agg(
}
AggFunction::CollectList => {
let arg_type = children[0].data_type(input_schema)?;
let return_type = DataType::new_list(arg_type.clone(), true);
Arc::new(collect::AggCollectList::try_new(
Arc::new(AggCollectList::try_new(
children[0].clone(),
return_type,
arg_type,
)?)
}
AggFunction::CollectSet => {
let arg_type = children[0].data_type(input_schema)?;
let return_type = DataType::new_list(arg_type.clone(), true);
Arc::new(collect::AggCollectSet::try_new(
Arc::new(AggCollectSet::try_new(
children[0].clone(),
return_type,
arg_type,
Expand Down
12 changes: 9 additions & 3 deletions native-engine/datafusion-ext-plans/src/agg/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use datafusion::{
physical_expr::PhysicalExpr,
};
use datafusion_ext_commons::{
downcast_any,
df_execution_err, downcast_any,
io::{read_bytes_slice, read_len, read_scalar, write_len, write_scalar},
};
use hashbrown::raw::RawTable;
Expand All @@ -49,6 +49,7 @@ pub struct AggGenericCollect<C: AccCollectionColumn> {
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
arg_type: DataType,
return_list_nullable: bool,
_phantom: PhantomData<C>,
}

Expand All @@ -58,10 +59,15 @@ impl<C: AccCollectionColumn> AggGenericCollect<C> {
data_type: DataType,
arg_type: DataType,
) -> Result<Self> {
// Take the element nullability from the declared return type, so the lists
// built later via `ScalarValue::new_list` use the same nullability as the
// field inside `data_type`. Any non-list return type is rejected up front.
let return_list_nullable = match &data_type {
    DataType::List(field) => field.is_nullable(),
    _ => {
        // The message previously had an unbalanced "(" after `DataType::List`.
        return df_execution_err!("expect DataType::List({arg_type:?}), got {data_type:?}");
    }
};
Ok(Self {
child,
data_type,
arg_type,
data_type,
return_list_nullable,
_phantom: Default::default(),
})
}
Expand Down Expand Up @@ -157,7 +163,7 @@ impl<C: AccCollectionColumn> Agg for AggGenericCollect<C> {
list.push(ScalarValue::List(ScalarValue::new_list(
&accs.take_values(acc_idx),
&self.arg_type,
true,
self.return_list_nullable,
)));
}
}
Expand Down
10 changes: 10 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,60 +500,70 @@ mod test {
AggFunction::Sum,
&[phys_expr::col("a", &input.schema())?],
&input.schema(),
DataType::Int64,
)?;

let agg_expr_avg = create_agg(
AggFunction::Avg,
&[phys_expr::col("b", &input.schema())?],
&input.schema(),
DataType::Float64,
)?;

let agg_expr_max = create_agg(
AggFunction::Max,
&[phys_expr::col("d", &input.schema())?],
&input.schema(),
DataType::Int32,
)?;

let agg_expr_min = create_agg(
AggFunction::Min,
&[phys_expr::col("e", &input.schema())?],
&input.schema(),
DataType::Int32,
)?;

let agg_expr_count = create_agg(
AggFunction::Count,
&[phys_expr::col("f", &input.schema())?],
&input.schema(),
DataType::Int64,
)?;

let agg_expr_collectlist = create_agg(
AggFunction::CollectList,
&[phys_expr::col("g", &input.schema())?],
&input.schema(),
DataType::new_list(DataType::Int32, false),
)?;

let agg_expr_collectset = create_agg(
AggFunction::CollectSet,
&[phys_expr::col("h", &input.schema())?],
&input.schema(),
DataType::new_list(DataType::Int32, false),
)?;

let agg_expr_collectlist_nil = create_agg(
AggFunction::CollectList,
&[Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(None)))],
&input.schema(),
DataType::new_list(DataType::Utf8, false),
)?;

let agg_expr_collectset_nil = create_agg(
AggFunction::CollectSet,
&[Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(None)))],
&input.schema(),
DataType::new_list(DataType::Utf8, false),
)?;

let agg_expr_firstign = create_agg(
AggFunction::FirstIgnoresNull,
&[phys_expr::col("h", &input.schema())?],
&input.schema(),
DataType::Int32,
)?;

let aggs_agg_expr = vec![
Expand Down
11 changes: 10 additions & 1 deletion native-engine/datafusion-ext-plans/src/window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::sync::Arc;

use arrow::{array::ArrayRef, datatypes::FieldRef, record_batch::RecordBatch};
use arrow_schema::DataType;
use datafusion::{common::Result, physical_expr::PhysicalExpr};

use crate::{
Expand Down Expand Up @@ -53,18 +54,21 @@ pub struct WindowExpr {
field: FieldRef,
func: WindowFunction,
children: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
}

impl WindowExpr {
pub fn new(
func: WindowFunction,
children: Vec<Arc<dyn PhysicalExpr>>,
field: FieldRef,
return_type: DataType,
) -> Self {
Self {
field,
func,
children,
return_type,
}
}

Expand All @@ -83,7 +87,12 @@ impl WindowExpr {
Ok(Box::new(RankProcessor::new(true)))
}
WindowFunction::Agg(agg_func) => {
let agg = create_agg(agg_func, &self.children, &context.input_schema)?;
let agg = create_agg(
agg_func.clone(),
&self.children,
&context.input_schema,
self.return_type.clone(),
)?;
Ok(Box::new(AggProcessor::try_new(agg)?))
}
}
Expand Down
Loading
Loading