Skip to content

Commit 87dca5f

Browse files
goldmedalfindepi
authored andcommitted
Convert stddev and stddev_pop to UDAF (apache#10834)
* add stddev and stddev_pot udaf * remove aggregation function stddev and stddev_pop * register func and modified return type * cargo fmt * regen proto * cargo clippy * fix window function support * cargo fmt * throw not_impl_err instead * use default sliding accumulator
1 parent 50a1c6c commit 87dca5f

17 files changed

Lines changed: 389 additions & 469 deletions

File tree

datafusion/core/src/dataframe/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ use datafusion_common::{
5050
};
5151
use datafusion_expr::lit;
5252
use datafusion_expr::{
53-
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
54-
TableProviderFilterPushDown, UNNAMED_TABLE,
53+
avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
54+
UNNAMED_TABLE,
5555
};
5656
use datafusion_expr::{case, is_null};
57-
use datafusion_functions_aggregate::expr_fn::median;
5857
use datafusion_functions_aggregate::expr_fn::sum;
58+
use datafusion_functions_aggregate::expr_fn::{median, stddev};
5959

6060
use async_trait::async_trait;
6161

datafusion/expr/src/aggregate_function.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,6 @@ pub enum AggregateFunction {
4949
NthValue,
5050
/// Variance (Population)
5151
VariancePop,
52-
/// Standard Deviation (Sample)
53-
Stddev,
54-
/// Standard Deviation (Population)
55-
StddevPop,
5652
/// Correlation
5753
Correlation,
5854
/// Slope from linear regression
@@ -107,8 +103,6 @@ impl AggregateFunction {
107103
ArrayAgg => "ARRAY_AGG",
108104
NthValue => "NTH_VALUE",
109105
VariancePop => "VAR_POP",
110-
Stddev => "STDDEV",
111-
StddevPop => "STDDEV_POP",
112106
Correlation => "CORR",
113107
RegrSlope => "REGR_SLOPE",
114108
RegrIntercept => "REGR_INTERCEPT",
@@ -159,9 +153,6 @@ impl FromStr for AggregateFunction {
159153
"string_agg" => AggregateFunction::StringAgg,
160154
// statistical
161155
"corr" => AggregateFunction::Correlation,
162-
"stddev" => AggregateFunction::Stddev,
163-
"stddev_pop" => AggregateFunction::StddevPop,
164-
"stddev_samp" => AggregateFunction::Stddev,
165156
"var_pop" => AggregateFunction::VariancePop,
166157
"regr_slope" => AggregateFunction::RegrSlope,
167158
"regr_intercept" => AggregateFunction::RegrIntercept,
@@ -231,8 +222,6 @@ impl AggregateFunction {
231222
AggregateFunction::Correlation => {
232223
correlation_return_type(&coerced_data_types[0])
233224
}
234-
AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
235-
AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
236225
AggregateFunction::RegrSlope
237226
| AggregateFunction::RegrIntercept
238227
| AggregateFunction::RegrCount
@@ -304,8 +293,6 @@ impl AggregateFunction {
304293
}
305294
AggregateFunction::Avg
306295
| AggregateFunction::VariancePop
307-
| AggregateFunction::Stddev
308-
| AggregateFunction::StddevPop
309296
| AggregateFunction::ApproxMedian => {
310297
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
311298
}

datafusion/expr/src/expr_fn.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -383,18 +383,6 @@ pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
383383
})
384384
}
385385

386-
/// Create an expression to represent the stddev() aggregate function
387-
pub fn stddev(expr: Expr) -> Expr {
388-
Expr::AggregateFunction(AggregateFunction::new(
389-
aggregate_function::AggregateFunction::Stddev,
390-
vec![expr],
391-
false,
392-
None,
393-
None,
394-
None,
395-
))
396-
}
397-
398386
/// Create a grouping set
399387
pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
400388
Expr::GroupingSet(GroupingSet::GroupingSets(exprs))

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,6 @@ pub fn coerce_types(
161161
}
162162
Ok(vec![Float64, Float64])
163163
}
164-
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
165-
if !is_stddev_support_arg_type(&input_types[0]) {
166-
return plan_err!(
167-
"The function {:?} does not support inputs of type {:?}.",
168-
agg_fun,
169-
input_types[0]
170-
);
171-
}
172-
Ok(vec![Float64])
173-
}
174164
AggregateFunction::Correlation => {
175165
if !is_correlation_support_arg_type(&input_types[0]) {
176166
return plan_err!(
@@ -408,15 +398,6 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
408398
}
409399
}
410400

411-
/// function return type of standard deviation
412-
pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
413-
if NUMERICS.contains(arg_type) {
414-
Ok(DataType::Float64)
415-
} else {
416-
plan_err!("STDDEV does not support {arg_type:?}")
417-
}
418-
}
419-
420401
/// function return type of an average
421402
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
422403
match arg_type {
@@ -511,13 +492,6 @@ pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
511492
)
512493
}
513494

514-
pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
515-
matches!(
516-
arg_type,
517-
arg_type if NUMERICS.contains(arg_type)
518-
)
519-
}
520-
521495
pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
522496
matches!(
523497
arg_type,
@@ -664,17 +638,6 @@ mod tests {
664638
Ok(())
665639
}
666640

667-
#[test]
668-
fn test_stddev_return_data_type() -> Result<()> {
669-
let data_type = DataType::Float64;
670-
let result_type = stddev_return_type(&data_type)?;
671-
assert_eq!(DataType::Float64, result_type);
672-
673-
let data_type = DataType::Decimal128(36, 10);
674-
assert!(stddev_return_type(&data_type).is_err());
675-
Ok(())
676-
}
677-
678641
#[test]
679642
fn test_covariance_return_data_type() -> Result<()> {
680643
let data_type = DataType::Float64;

datafusion/functions-aggregate/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub mod macros;
5858
pub mod covariance;
5959
pub mod first_last;
6060
pub mod median;
61+
pub mod stddev;
6162
pub mod sum;
6263
pub mod variance;
6364

@@ -74,6 +75,8 @@ pub mod expr_fn {
7475
pub use super::first_last::first_value;
7576
pub use super::first_last::last_value;
7677
pub use super::median::median;
78+
pub use super::stddev::stddev;
79+
pub use super::stddev::stddev_pop;
7780
pub use super::sum::sum;
7881
pub use super::variance::var_sample;
7982
}
@@ -88,6 +91,8 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
8891
covariance::covar_pop_udaf(),
8992
median::median_udaf(),
9093
variance::var_samp_udaf(),
94+
stddev::stddev_udaf(),
95+
stddev::stddev_pop_udaf(),
9196
]
9297
}
9398

0 commit comments

Comments
 (0)