Skip to content

Commit f5eeb07

Browse files
authored
Refactor hash key equality function (#7969)
This was previously done in several places in a somewhat convoluted way, which probably made sense at some point. As expected however, a few small local variations had emerged, and while nothing critical, this code wasn't very nice to work with. Some comments stated as the rationale for the design was avoiding allocations, but those were nowhere to be seen when measured now, meaning there was no good reason to have it remain this way! It was *quite* nice to be able to merge the numbers comparsion functions (in particular) together into one! We now have a unified way for comparing Number values throughout the AST package, and as an added bonus, `1.0 == 1` is now true consistently for Rego. See for example @srenatus example in #4797 ``` $ opa eval -fpretty 'count({1.0, 1})' 2 ``` Doing the same now gives: ``` $ go run main.go eval -fpretty 'count({1.0, 1})' 1 $ go run main.go eval -fpretty 'count({1.0, 1, 1.000, 1.00000})' 1 ``` What I have left out for now is however _presentation_. Meaning that even though 1.0 and 1 is now treated as the same value, you may still see either '1' or 1.0' (or whatever) displayed, depending on what was parsed. Should be easy to fix, but could perhaps be perceived as surprising... so holding off on that until we've had a discussion on the topic. Signed-off-by: Anders Eknert <anders@eknert.com>
1 parent 7c5ccbe commit f5eeb07

5 files changed

Lines changed: 286 additions & 527 deletions

File tree

internal/edittree/edittree.go

Lines changed: 1 addition & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ package edittree
148148
import (
149149
"errors"
150150
"fmt"
151-
"math/big"
152151
"sort"
153152
"strings"
154153

@@ -203,89 +202,13 @@ func NewEditTree(term *ast.Term) *EditTree {
203202
// it was found in the table already.
204203
func (e *EditTree) getKeyHash(key *ast.Term) (int, bool) {
205204
hash := key.Hash()
206-
// This `equal` utility is duplicated and manually inlined a number of
207-
// time in this file. Inlining it avoids heap allocations, so it makes
208-
// a big performance difference: some operations like lookup become twice
209-
// as slow without it.
210-
var equal func(v ast.Value) bool
211-
212-
switch x := key.Value.(type) {
213-
case ast.Null, ast.Boolean, ast.String, ast.Var:
214-
equal = func(y ast.Value) bool { return x == y }
215-
case ast.Number:
216-
if xi, ok := x.Int64(); ok {
217-
equal = func(y ast.Value) bool {
218-
if y, ok := y.(ast.Number); ok {
219-
if yi, ok := y.Int64(); ok {
220-
return xi == yi
221-
}
222-
}
223-
224-
return false
225-
}
226-
break
227-
}
228-
229-
// We use big.Rat for comparing big numbers.
230-
// It replaces big.Float due to following reason:
231-
// big.Float comes with a default precision of 64, and setting a
232-
// larger precision results in more memory being allocated
233-
// (regardless of the actual number we are parsing with SetString).
234-
//
235-
// Note: If we're so close to zero that big.Float says we are zero, do
236-
// *not* big.Rat).SetString on the original string it'll potentially
237-
// take very long.
238-
var a *big.Rat
239-
fa, ok := new(big.Float).SetString(string(x))
240-
if !ok {
241-
panic("illegal value")
242-
}
243-
if fa.IsInt() {
244-
if i, _ := fa.Int64(); i == 0 {
245-
a = new(big.Rat).SetInt64(0)
246-
}
247-
}
248-
if a == nil {
249-
a, ok = new(big.Rat).SetString(string(x))
250-
if !ok {
251-
panic("illegal value")
252-
}
253-
}
254-
255-
equal = func(b ast.Value) bool {
256-
if bNum, ok := b.(ast.Number); ok {
257-
var b *big.Rat
258-
fb, ok := new(big.Float).SetString(string(bNum))
259-
if !ok {
260-
panic("illegal value")
261-
}
262-
if fb.IsInt() {
263-
if i, _ := fb.Int64(); i == 0 {
264-
b = new(big.Rat).SetInt64(0)
265-
}
266-
}
267-
if b == nil {
268-
b, ok = new(big.Rat).SetString(string(bNum))
269-
if !ok {
270-
panic("illegal value")
271-
}
272-
}
273-
274-
return a.Cmp(b) == 0
275-
}
276-
return false
277-
}
278-
279-
default:
280-
equal = func(y ast.Value) bool { return ast.Compare(x, y) == 0 }
281-
}
282205

283206
// Look through childKeys, looking up the original hash
284207
// value first, and then use linear-probing to iter
285208
// through the keys until we either find the Term we're
286209
// after, or run out of candidates.
287210
for curr, ok := e.childKeys[hash]; ok; {
288-
if equal(curr.Value) {
211+
if ast.KeyHashEqual(curr.Value, key.Value) {
289212
return hash, true
290213
}
291214

v1/ast/compare.go

Lines changed: 102 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
package ast
66

77
import (
8-
"encoding/json"
8+
"cmp"
99
"fmt"
1010
"math/big"
11+
"strings"
1112
)
1213

1314
// Compare returns an integer indicating whether two AST values are less than,
@@ -77,73 +78,18 @@ func Compare(a, b any) int {
7778
case Null:
7879
return 0
7980
case Boolean:
80-
b := b.(Boolean)
81-
if a.Equal(b) {
81+
if a == b.(Boolean) {
8282
return 0
8383
}
8484
if !a {
8585
return -1
8686
}
8787
return 1
8888
case Number:
89-
if ai, err := json.Number(a).Int64(); err == nil {
90-
if bi, err := json.Number(b.(Number)).Int64(); err == nil {
91-
if ai == bi {
92-
return 0
93-
}
94-
if ai < bi {
95-
return -1
96-
}
97-
return 1
98-
}
99-
}
100-
101-
// We use big.Rat for comparing big numbers.
102-
// It replaces big.Float due to following reason:
103-
// big.Float comes with a default precision of 64, and setting a
104-
// larger precision results in more memory being allocated
105-
// (regardless of the actual number we are parsing with SetString).
106-
//
107-
// Note: If we're so close to zero that big.Float says we are zero, do
108-
// *not* big.Rat).SetString on the original string it'll potentially
109-
// take very long.
110-
var bigA, bigB *big.Rat
111-
fa, ok := new(big.Float).SetString(string(a))
112-
if !ok {
113-
panic("illegal value")
114-
}
115-
if fa.IsInt() {
116-
if i, _ := fa.Int64(); i == 0 {
117-
bigA = new(big.Rat).SetInt64(0)
118-
}
119-
}
120-
if bigA == nil {
121-
bigA, ok = new(big.Rat).SetString(string(a))
122-
if !ok {
123-
panic("illegal value")
124-
}
125-
}
126-
127-
fb, ok := new(big.Float).SetString(string(b.(Number)))
128-
if !ok {
129-
panic("illegal value")
130-
}
131-
if fb.IsInt() {
132-
if i, _ := fb.Int64(); i == 0 {
133-
bigB = new(big.Rat).SetInt64(0)
134-
}
135-
}
136-
if bigB == nil {
137-
bigB, ok = new(big.Rat).SetString(string(b.(Number)))
138-
if !ok {
139-
panic("illegal value")
140-
}
141-
}
142-
143-
return bigA.Cmp(bigB)
89+
return NumberCompare(a, b.(Number))
14490
case String:
14591
b := b.(String)
146-
if a.Equal(b) {
92+
if a == b {
14793
return 0
14894
}
14995
if a < b {
@@ -153,8 +99,7 @@ func Compare(a, b any) int {
15399
case Var:
154100
return VarCompare(a, b.(Var))
155101
case Ref:
156-
b := b.(Ref)
157-
return termSliceCompare(a, b)
102+
return termSliceCompare(a, b.(Ref))
158103
case *Array:
159104
b := b.(*Array)
160105
return termSliceCompare(a.elems, b.elems)
@@ -164,11 +109,9 @@ func Compare(a, b any) int {
164109
if x, ok := b.(*lazyObj); ok {
165110
b = x.force()
166111
}
167-
b := b.(*object)
168-
return a.Compare(b)
112+
return a.Compare(b.(*object))
169113
case Set:
170-
b := b.(Set)
171-
return a.Compare(b)
114+
return a.Compare(b.(Set))
172115
case *ArrayComprehension:
173116
b := b.(*ArrayComprehension)
174117
if cmp := Compare(a.Term, b.Term); cmp != 0 {
@@ -191,44 +134,31 @@ func Compare(a, b any) int {
191134
}
192135
return a.Body.Compare(b.Body)
193136
case Call:
194-
b := b.(Call)
195-
return termSliceCompare(a, b)
137+
return termSliceCompare(a, b.(Call))
196138
case *Expr:
197-
b := b.(*Expr)
198-
return a.Compare(b)
139+
return a.Compare(b.(*Expr))
199140
case *SomeDecl:
200-
b := b.(*SomeDecl)
201-
return a.Compare(b)
141+
return a.Compare(b.(*SomeDecl))
202142
case *Every:
203-
b := b.(*Every)
204-
return a.Compare(b)
143+
return a.Compare(b.(*Every))
205144
case *With:
206-
b := b.(*With)
207-
return a.Compare(b)
145+
return a.Compare(b.(*With))
208146
case Body:
209-
b := b.(Body)
210-
return a.Compare(b)
147+
return a.Compare(b.(Body))
211148
case *Head:
212-
b := b.(*Head)
213-
return a.Compare(b)
149+
return a.Compare(b.(*Head))
214150
case *Rule:
215-
b := b.(*Rule)
216-
return a.Compare(b)
151+
return a.Compare(b.(*Rule))
217152
case Args:
218-
b := b.(Args)
219-
return termSliceCompare(a, b)
153+
return termSliceCompare(a, b.(Args))
220154
case *Import:
221-
b := b.(*Import)
222-
return a.Compare(b)
155+
return a.Compare(b.(*Import))
223156
case *Package:
224-
b := b.(*Package)
225-
return a.Compare(b)
157+
return a.Compare(b.(*Package))
226158
case *Annotations:
227-
b := b.(*Annotations)
228-
return a.Compare(b)
159+
return a.Compare(b.(*Annotations))
229160
case *Module:
230-
b := b.(*Module)
231-
return a.Compare(b)
161+
return a.Compare(b.(*Module))
232162
}
233163
panic(fmt.Sprintf("illegal value: %T", a))
234164
}
@@ -427,3 +357,84 @@ func RefCompare(a, b Ref) int {
427357
func RefEqual(a, b Ref) bool {
428358
return termSliceEqual(a, b)
429359
}
360+
361+
func NumberCompare(x, y Number) int {
362+
xs, ys := string(x), string(y)
363+
364+
var xIsF, yIsF bool
365+
366+
// Treat "1" and "1.0", "1.00", etc as "1"
367+
if strings.Contains(xs, ".") {
368+
if tx := strings.TrimRight(xs, ".0"); tx != xs {
369+
// Still a float after trimming?
370+
xIsF = strings.Contains(tx, ".")
371+
xs = tx
372+
}
373+
}
374+
if strings.Contains(ys, ".") {
375+
if ty := strings.TrimRight(ys, ".0"); ty != ys {
376+
yIsF = strings.Contains(ty, ".")
377+
ys = ty
378+
}
379+
}
380+
if xs == ys {
381+
return 0
382+
}
383+
384+
var xi, yi int64
385+
var xf, yf float64
386+
var xiOK, yiOK, xfOK, yfOK bool
387+
388+
if xi, xiOK = x.Int64(); xiOK {
389+
if yi, yiOK = y.Int64(); yiOK {
390+
return cmp.Compare(xi, yi)
391+
}
392+
}
393+
394+
if xIsF && yIsF {
395+
if xf, xfOK = x.Float64(); xfOK {
396+
if yf, yfOK = y.Float64(); yfOK {
397+
if xf == yf {
398+
return 0
399+
}
400+
// could still be "equal" depending on precision, so we continue?
401+
}
402+
}
403+
}
404+
405+
var a *big.Rat
406+
fa, ok := new(big.Float).SetString(string(x))
407+
if !ok {
408+
panic("illegal value")
409+
}
410+
if fa.IsInt() {
411+
if i, _ := fa.Int64(); i == 0 {
412+
a = new(big.Rat).SetInt64(0)
413+
}
414+
}
415+
if a == nil {
416+
a, ok = new(big.Rat).SetString(string(x))
417+
if !ok {
418+
panic("illegal value")
419+
}
420+
}
421+
422+
var b *big.Rat
423+
fb, ok := new(big.Float).SetString(string(y))
424+
if !ok {
425+
panic("illegal value")
426+
}
427+
if fb.IsInt() {
428+
if i, _ := fb.Int64(); i == 0 {
429+
b = new(big.Rat).SetInt64(0)
430+
}
431+
}
432+
if b == nil {
433+
b, ok = new(big.Rat).SetString(string(y))
434+
if !ok {
435+
panic("illegal value")
436+
}
437+
}
438+
439+
return a.Cmp(b)
440+
}

v1/ast/compare_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,15 @@ func TestCompare(t *testing.T) {
3030
{"0", "1", -1},
3131
{"1", "0", 1},
3232
{"0", "0", 0},
33+
{"0.0", "0", 0},
34+
{"1.1", "1.10", 0},
35+
{"1", "1.0000000000000000000000000000000000000000000", 0},
36+
{"1.0", "1.0000000000000000000000000000000000000000000", 0},
37+
{"1.000000000000000", "1.0000000000000000000000000000000000000000000", 0},
38+
{"1.10", "1.11", -1},
3339
{"0", "1.5", -1},
3440
{"1.5", "0", 1},
41+
{"100000", "100", 1},
3542
{"123456789123456789123", "123456789123456789123", 0},
3643
{"123456789123456789123", "123456789123456789122", 1},
3744
{"123456789123456789122", "123456789123456789123", -1},

0 commit comments

Comments
 (0)