Skip to content

Commit 00bc3c2

Browse files
authored
feat(extensions): add std_dev and variance with distribution enum arg (#1011)
1 parent 7b39d4c commit 00bc3c2

4 files changed

Lines changed: 179 additions & 6 deletions

File tree

extensions/functions_arithmetic.yaml

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1403,13 +1403,33 @@ aggregate_functions:
14031403
nullability: DECLARED_OUTPUT
14041404
return: fp32?
14051405
- args:
1406+
- name: distribution
1407+
options: [ SAMPLE, POPULATION]
14061408
- name: x
1407-
value: fp64
1409+
value: fp32
14081410
options:
14091411
rounding:
14101412
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
1413+
nullability: DECLARED_OUTPUT
1414+
return: fp32?
1415+
- args:
1416+
- name: x
1417+
value: fp64
1418+
options:
14111419
distribution:
14121420
values: [ SAMPLE, POPULATION]
1421+
rounding:
1422+
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
1423+
nullability: DECLARED_OUTPUT
1424+
return: fp64?
1425+
- args:
1426+
- name: distribution
1427+
options: [ SAMPLE, POPULATION]
1428+
- name: x
1429+
value: fp64
1430+
options:
1431+
rounding:
1432+
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
14131433
nullability: DECLARED_OUTPUT
14141434
return: fp64?
14151435
- name: "variance"
@@ -1419,20 +1439,40 @@ aggregate_functions:
14191439
- name: x
14201440
value: fp32
14211441
options:
1422-
rounding:
1423-
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
14241442
distribution:
14251443
values: [ SAMPLE, POPULATION]
1444+
rounding:
1445+
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
14261446
nullability: DECLARED_OUTPUT
14271447
return: fp32?
14281448
- args:
1449+
- name: distribution
1450+
options: [ SAMPLE, POPULATION]
14291451
- name: x
1430-
value: fp64
1452+
value: fp32
14311453
options:
14321454
rounding:
14331455
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
1456+
nullability: DECLARED_OUTPUT
1457+
return: fp32?
1458+
- args:
1459+
- name: x
1460+
value: fp64
1461+
options:
14341462
distribution:
14351463
values: [ SAMPLE, POPULATION]
1464+
rounding:
1465+
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
1466+
nullability: DECLARED_OUTPUT
1467+
return: fp64?
1468+
- args:
1469+
- name: distribution
1470+
options: [ SAMPLE, POPULATION]
1471+
- name: x
1472+
value: fp64
1473+
options:
1474+
rounding:
1475+
values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ]
14361476
nullability: DECLARED_OUTPUT
14371477
return: fp64?
14381478
- name: "corr"
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
### SUBSTRAIT_AGGREGATE_TEST: v1.0
2+
### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic.yaml'
3+
4+
# basic: Basic examples without any special cases
5+
((1.0, 2.0, 3.0, 4.0, 5.0)) std_dev(SAMPLE::enum, col0::fp32) = 1.5811388::fp32?
6+
((1.0, 2.0, 3.0, 4.0, 5.0)) std_dev(SAMPLE::enum, col0::fp64) = 1.5811388300841898::fp64?
7+
((1.0, 2.0, 3.0, 4.0, 5.0)) std_dev(POPULATION::enum, col0::fp32) = 1.4142135::fp32?
8+
((1.0, 2.0, 3.0, 4.0, 5.0)) std_dev(POPULATION::enum, col0::fp64) = 1.4142135623730951::fp64?
9+
10+
# uniform_values: Standard deviation of uniform values
11+
((5.0, 5.0, 5.0, 5.0)) std_dev(SAMPLE::enum, col0::fp32) = 0.0::fp32?
12+
((5.0, 5.0, 5.0, 5.0)) std_dev(POPULATION::enum, col0::fp64) = 0.0::fp64?
13+
14+
# single_value: Standard deviation with single value
15+
((42.0)) std_dev(SAMPLE::enum, col0::fp32) = Null::fp32?
16+
((42.0)) std_dev(POPULATION::enum, col0::fp64) = 0.0::fp64?
17+
18+
# negative_values: Standard deviation with negative values
19+
((-5.0, -3.0, -1.0, 1.0, 3.0, 5.0)) std_dev(SAMPLE::enum, col0::fp32) = 3.7416575::fp32?
20+
((-5.0, -3.0, -1.0, 1.0, 3.0, 5.0)) std_dev(SAMPLE::enum, col0::fp64) = 3.7416573867739413::fp64?
21+
((-10.0, -5.0, 0.0, 5.0, 10.0)) std_dev(POPULATION::enum, col0::fp32) = 7.0710678::fp32?
22+
((-10.0, -5.0, 0.0, 5.0, 10.0)) std_dev(POPULATION::enum, col0::fp64) = 7.0710678118654755::fp64?
23+
24+
# decimal_precision: Standard deviation with decimal values
25+
((1.5, 2.5, 3.5, 4.5, 5.5)) std_dev(SAMPLE::enum, col0::fp32) = 1.5811388::fp32?
26+
((1.5, 2.5, 3.5, 4.5, 5.5)) std_dev(SAMPLE::enum, col0::fp64) = 1.5811388300841898::fp64?
27+
((0.1, 0.2, 0.3, 0.4, 0.5)) std_dev(POPULATION::enum, col0::fp64) = 0.14142135623730953::fp64?
28+
29+
# large_values: Standard deviation with large values
30+
((1000.0, 2000.0, 3000.0, 4000.0, 5000.0)) std_dev(SAMPLE::enum, col0::fp32) = 1581.1388::fp32?
31+
((1000.0, 2000.0, 3000.0, 4000.0, 5000.0)) std_dev(SAMPLE::enum, col0::fp64) = 1581.1388300841898::fp64?
32+
33+
# small_values: Standard deviation with small values
34+
((0.001, 0.002, 0.003, 0.004, 0.005)) std_dev(SAMPLE::enum, col0::fp64) = 0.0015811388300841896::fp64?
35+
((0.001, 0.002, 0.003, 0.004, 0.005)) std_dev(POPULATION::enum, col0::fp64) = 0.0014142135623730951::fp64?
36+
37+
# null_handling: Examples with null as input or output
38+
((Null, Null, Null)) std_dev(SAMPLE::enum, col0::fp32?) = Null::fp32?
39+
(()) std_dev(SAMPLE::enum, col0::fp32) = Null::fp32?
40+
((1.0, Null, 3.0, Null, 5.0)) std_dev(SAMPLE::enum, col0::fp32?) = 2.0::fp32?
41+
((1.0, Null, 3.0, Null, 5.0)) std_dev(POPULATION::enum, col0::fp64?) = 1.632993161855452::fp64?
42+
43+
# rounding: Examples with different rounding modes
44+
((1.1, 2.2, 3.3, 4.4, 5.5)) std_dev(SAMPLE::enum, col0::fp32) [rounding:TIE_TO_EVEN] = 1.7392527::fp32?
45+
((1.1, 2.2, 3.3, 4.4, 5.5)) std_dev(SAMPLE::enum, col0::fp64) [rounding:TRUNCATE] = 1.7392527130926086::fp64?
46+
47+
# two_values: Standard deviation with two values
48+
((10.0, 20.0)) std_dev(SAMPLE::enum, col0::fp32) = 7.071068::fp32?
49+
((10.0, 20.0)) std_dev(POPULATION::enum, col0::fp64) = 5.0::fp64?
50+
51+
# mixed_range: Standard deviation with mixed range values
52+
((0.0, 100.0, 50.0, 25.0, 75.0)) std_dev(SAMPLE::enum, col0::fp32) = 39.52847::fp32?
53+
((0.0, 100.0, 50.0, 25.0, 75.0)) std_dev(POPULATION::enum, col0::fp64) = 35.35533905932738::fp64?
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
### SUBSTRAIT_AGGREGATE_TEST: v1.0
2+
### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic.yaml'
3+
4+
# basic: Basic examples without any special cases
5+
((1.0, 2.0, 3.0, 4.0, 5.0)) variance(SAMPLE::enum, col0::fp32) = 2.5::fp32?
6+
((1.0, 2.0, 3.0, 4.0, 5.0)) variance(SAMPLE::enum, col0::fp64) = 2.5::fp64?
7+
((1.0, 2.0, 3.0, 4.0, 5.0)) variance(POPULATION::enum, col0::fp32) = 2.0::fp32?
8+
((1.0, 2.0, 3.0, 4.0, 5.0)) variance(POPULATION::enum, col0::fp64) = 2.0::fp64?
9+
10+
# uniform_values: Variance of uniform values
11+
((5.0, 5.0, 5.0, 5.0)) variance(SAMPLE::enum, col0::fp32) = 0.0::fp32?
12+
((5.0, 5.0, 5.0, 5.0)) variance(POPULATION::enum, col0::fp64) = 0.0::fp64?
13+
14+
# single_value: Variance with single value
15+
((42.0)) variance(SAMPLE::enum, col0::fp32) = Null::fp32?
16+
((42.0)) variance(POPULATION::enum, col0::fp64) = 0.0::fp64?
17+
18+
# negative_values: Variance with negative values
19+
((-5.0, -3.0, -1.0, 1.0, 3.0, 5.0)) variance(SAMPLE::enum, col0::fp32) = 14.0::fp32?
20+
((-5.0, -3.0, -1.0, 1.0, 3.0, 5.0)) variance(SAMPLE::enum, col0::fp64) = 14.0::fp64?
21+
((-10.0, -5.0, 0.0, 5.0, 10.0)) variance(POPULATION::enum, col0::fp32) = 50.0::fp32?
22+
((-10.0, -5.0, 0.0, 5.0, 10.0)) variance(POPULATION::enum, col0::fp64) = 50.0::fp64?
23+
24+
# decimal_precision: Variance with decimal values
25+
((1.5, 2.5, 3.5, 4.5, 5.5)) variance(SAMPLE::enum, col0::fp32) = 2.5::fp32?
26+
((1.5, 2.5, 3.5, 4.5, 5.5)) variance(SAMPLE::enum, col0::fp64) = 2.5::fp64?
27+
((0.1, 0.2, 0.3, 0.4, 0.5)) variance(POPULATION::enum, col0::fp64) = 0.020000000000000004::fp64?
28+
29+
# large_values: Variance with large values
30+
((1000.0, 2000.0, 3000.0, 4000.0, 5000.0)) variance(SAMPLE::enum, col0::fp32) = 2500000.0::fp32?
31+
((1000.0, 2000.0, 3000.0, 4000.0, 5000.0)) variance(SAMPLE::enum, col0::fp64) = 2500000.0::fp64?
32+
33+
# small_values: Variance with small values
34+
((0.001, 0.002, 0.003, 0.004, 0.005)) variance(SAMPLE::enum, col0::fp64) = 0.0000025::fp64?
35+
((0.001, 0.002, 0.003, 0.004, 0.005)) variance(POPULATION::enum, col0::fp64) = 0.000002::fp64?
36+
37+
# null_handling: Examples with null as input or output
38+
((Null, Null, Null)) variance(SAMPLE::enum, col0::fp32?) = Null::fp32?
39+
(()) variance(SAMPLE::enum, col0::fp32) = Null::fp32?
40+
((1.0, Null, 3.0, Null, 5.0)) variance(SAMPLE::enum, col0::fp32?) = 4.0::fp32?
41+
((1.0, Null, 3.0, Null, 5.0)) variance(POPULATION::enum, col0::fp64?) = 2.666666666666667::fp64?
42+
43+
# rounding: Examples with different rounding modes
44+
((1.1, 2.2, 3.3, 4.4, 5.5)) variance(SAMPLE::enum, col0::fp32) [rounding:TIE_TO_EVEN] = 3.025::fp32?
45+
((1.1, 2.2, 3.3, 4.4, 5.5)) variance(SAMPLE::enum, col0::fp64) [rounding:TRUNCATE] = 3.025::fp64?
46+
47+
# two_values: Variance with two values
48+
((10.0, 20.0)) variance(SAMPLE::enum, col0::fp32) = 50.0::fp32?
49+
((10.0, 20.0)) variance(POPULATION::enum, col0::fp64) = 25.0::fp64?
50+
51+
# mixed_range: Variance with mixed range values
52+
((0.0, 100.0, 50.0, 25.0, 75.0)) variance(SAMPLE::enum, col0::fp32) = 1562.5::fp32?
53+
((0.0, 100.0, 50.0, 25.0, 75.0)) variance(POPULATION::enum, col0::fp64) = 1250.0::fp64?
54+
55+
# zero_mean: Variance with values around zero
56+
((-2.0, -1.0, 0.0, 1.0, 2.0)) variance(SAMPLE::enum, col0::fp32) = 2.5::fp32?
57+
((-2.0, -1.0, 0.0, 1.0, 2.0)) variance(POPULATION::enum, col0::fp64) = 2.0::fp64?
58+
59+
# three_values: Variance with three values
60+
((10.0, 20.0, 30.0)) variance(SAMPLE::enum, col0::fp32) = 100.0::fp32?
61+
((10.0, 20.0, 30.0)) variance(POPULATION::enum, col0::fp64) = 66.66666666666667::fp64?

tests/coverage/nodes.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,29 @@ def is_return_type_error(self):
8585
return isinstance(self.result, SubstraitError)
8686

8787
def get_arg_types(self):
88-
return [arg.get_base_type() for arg in self.args]
88+
types = []
89+
for arg in self.args:
90+
if isinstance(arg, CaseLiteral):
91+
types.append(arg.get_base_type())
92+
elif isinstance(arg, AggregateArgument):
93+
# For aggregate arguments, use column_type if available, otherwise extract from scalar_value
94+
if arg.column_type:
95+
types.append(arg.column_type)
96+
elif arg.scalar_value:
97+
types.append(arg.scalar_value.get_base_type())
98+
return types
8999

90100
def get_signature(self):
91-
return f"{self.func_name}({', '.join([arg.type for arg in self.args])}) = {self.get_return_type()}"
101+
arg_types = []
102+
for arg in self.args:
103+
if isinstance(arg, CaseLiteral):
104+
arg_types.append(arg.type)
105+
elif isinstance(arg, AggregateArgument):
106+
if arg.column_type:
107+
arg_types.append(arg.column_type)
108+
elif arg.scalar_value:
109+
arg_types.append(arg.scalar_value.type)
110+
return f"{self.func_name}({', '.join(arg_types)}) = {self.get_return_type()}"
92111

93112

94113
@dataclass

0 commit comments

Comments
 (0)