1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: util:: pretty:: pretty_format_columns;
18+ use arrow:: util:: pretty:: { pretty_format_batches , pretty_format_columns} ;
1919use arrow_array:: builder:: { ListBuilder , StringBuilder } ;
20- use arrow_array:: { ArrayRef , RecordBatch , StringArray , StructArray } ;
20+ use arrow_array:: { ArrayRef , Int64Array , RecordBatch , StringArray , StructArray } ;
2121use arrow_schema:: { DataType , Field } ;
2222use datafusion:: prelude:: * ;
23- use datafusion_common:: { DFSchema , ScalarValue } ;
23+ use datafusion_common:: { assert_contains, DFSchema , ScalarValue } ;
24+ use datafusion_expr:: AggregateExt ;
2425use datafusion_functions:: core:: expr_ext:: FieldAccessor ;
26+ use datafusion_functions_aggregate:: first_last:: first_value_udaf;
27+ use datafusion_functions_aggregate:: sum:: sum_udaf;
2528use datafusion_functions_array:: expr_ext:: { IndexAccessor , SliceAccessor } ;
29+ use sqlparser:: ast:: NullTreatment ;
2630/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan
2731use std:: sync:: { Arc , OnceLock } ;
2832
@@ -162,6 +166,183 @@ fn test_list_range() {
162166 ) ;
163167}
164168
169+ #[ tokio:: test]
170+ async fn test_aggregate_error ( ) {
171+ let err = first_value_udaf ( )
172+ . call ( vec ! [ col( "props" ) ] )
173+ // not a sort column
174+ . order_by ( vec ! [ col( "id" ) ] )
175+ . build ( )
176+ . unwrap_err ( )
177+ . to_string ( ) ;
178+ assert_contains ! (
179+ err,
180+ "Error during planning: ORDER BY expressions must be Expr::Sort"
181+ ) ;
182+ }
183+
184+ #[ tokio:: test]
185+ async fn test_aggregate_ext_order_by ( ) {
186+ let agg = first_value_udaf ( ) . call ( vec ! [ col( "props" ) ] ) ;
187+
188+ // ORDER BY id ASC
189+ let agg_asc = agg
190+ . clone ( )
191+ . order_by ( vec ! [ col( "id" ) . sort( true , true ) ] )
192+ . build ( )
193+ . unwrap ( )
194+ . alias ( "asc" ) ;
195+
196+ // ORDER BY id DESC
197+ let agg_desc = agg
198+ . order_by ( vec ! [ col( "id" ) . sort( false , true ) ] )
199+ . build ( )
200+ . unwrap ( )
201+ . alias ( "desc" ) ;
202+
203+ evaluate_agg_test (
204+ agg_asc,
205+ vec ! [
206+ "+-----------------+" ,
207+ "| asc |" ,
208+ "+-----------------+" ,
209+ "| {a: 2021-02-01} |" ,
210+ "+-----------------+" ,
211+ ] ,
212+ )
213+ . await ;
214+
215+ evaluate_agg_test (
216+ agg_desc,
217+ vec ! [
218+ "+-----------------+" ,
219+ "| desc |" ,
220+ "+-----------------+" ,
221+ "| {a: 2021-02-03} |" ,
222+ "+-----------------+" ,
223+ ] ,
224+ )
225+ . await ;
226+ }
227+
228+ #[ tokio:: test]
229+ async fn test_aggregate_ext_filter ( ) {
230+ let agg = first_value_udaf ( )
231+ . call ( vec ! [ col( "i" ) ] )
232+ . order_by ( vec ! [ col( "i" ) . sort( true , true ) ] )
233+ . filter ( col ( "i" ) . is_not_null ( ) )
234+ . build ( )
235+ . unwrap ( )
236+ . alias ( "val" ) ;
237+
238+ #[ rustfmt:: skip]
239+ evaluate_agg_test (
240+ agg,
241+ vec ! [
242+ "+-----+" ,
243+ "| val |" ,
244+ "+-----+" ,
245+ "| 5 |" ,
246+ "+-----+" ,
247+ ] ,
248+ )
249+ . await ;
250+ }
251+
252+ #[ tokio:: test]
253+ async fn test_aggregate_ext_distinct ( ) {
254+ let agg = sum_udaf ( )
255+ . call ( vec ! [ lit( 5 ) ] )
256+ // distinct sum should be 5, not 15
257+ . distinct ( )
258+ . build ( )
259+ . unwrap ( )
260+ . alias ( "distinct" ) ;
261+
262+ evaluate_agg_test (
263+ agg,
264+ vec ! [
265+ "+----------+" ,
266+ "| distinct |" ,
267+ "+----------+" ,
268+ "| 5 |" ,
269+ "+----------+" ,
270+ ] ,
271+ )
272+ . await ;
273+ }
274+
275+ #[ tokio:: test]
276+ async fn test_aggregate_ext_null_treatment ( ) {
277+ let agg = first_value_udaf ( )
278+ . call ( vec ! [ col( "i" ) ] )
279+ . order_by ( vec ! [ col( "i" ) . sort( true , true ) ] ) ;
280+
281+ let agg_respect = agg
282+ . clone ( )
283+ . null_treatment ( NullTreatment :: RespectNulls )
284+ . build ( )
285+ . unwrap ( )
286+ . alias ( "respect" ) ;
287+
288+ let agg_ignore = agg
289+ . null_treatment ( NullTreatment :: IgnoreNulls )
290+ . build ( )
291+ . unwrap ( )
292+ . alias ( "ignore" ) ;
293+
294+ evaluate_agg_test (
295+ agg_respect,
296+ vec ! [
297+ "+---------+" ,
298+ "| respect |" ,
299+ "+---------+" ,
300+ "| |" ,
301+ "+---------+" ,
302+ ] ,
303+ )
304+ . await ;
305+
306+ evaluate_agg_test (
307+ agg_ignore,
308+ vec ! [
309+ "+--------+" ,
310+ "| ignore |" ,
311+ "+--------+" ,
312+ "| 5 |" ,
313+ "+--------+" ,
314+ ] ,
315+ )
316+ . await ;
317+ }
318+
319+ /// Evaluates the specified expr as an aggregate and compares the result to the
320+ /// expected result.
321+ async fn evaluate_agg_test ( expr : Expr , expected_lines : Vec < & str > ) {
322+ let batch = test_batch ( ) ;
323+
324+ let ctx = SessionContext :: new ( ) ;
325+ let group_expr = vec ! [ ] ;
326+ let agg_expr = vec ! [ expr] ;
327+ let result = ctx
328+ . read_batch ( batch)
329+ . unwrap ( )
330+ . aggregate ( group_expr, agg_expr)
331+ . unwrap ( )
332+ . collect ( )
333+ . await
334+ . unwrap ( ) ;
335+
336+ let result = pretty_format_batches ( & result) . unwrap ( ) . to_string ( ) ;
337+ let actual_lines = result. lines ( ) . collect :: < Vec < _ > > ( ) ;
338+
339+ assert_eq ! (
340+ expected_lines, actual_lines,
341+ "\n \n expected:\n \n {:#?}\n actual:\n \n {:#?}\n \n " ,
342+ expected_lines, actual_lines
343+ ) ;
344+ }
345+
165346/// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided
166347/// `RecordBatch` and compares the result to the expected result.
167348fn evaluate_expr_test ( expr : Expr , expected_lines : Vec < & str > ) {
@@ -189,6 +370,8 @@ fn test_batch() -> RecordBatch {
189370 TEST_BATCH
190371 . get_or_init ( || {
191372 let string_array: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ "1" , "2" , "3" ] ) ) ;
373+ let int_array: ArrayRef =
374+ Arc :: new ( Int64Array :: from_iter ( vec ! [ Some ( 10 ) , None , Some ( 5 ) ] ) ) ;
192375
193376 // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" }
194377 let struct_array: ArrayRef = Arc :: from ( StructArray :: from ( vec ! [ (
@@ -209,6 +392,7 @@ fn test_batch() -> RecordBatch {
209392
210393 RecordBatch :: try_from_iter ( vec ! [
211394 ( "id" , string_array) ,
395+ ( "i" , int_array) ,
212396 ( "props" , struct_array) ,
213397 ( "list" , list_array) ,
214398 ] )
0 commit comments