Skip to content

Commit 42c448e

Browse files
authored
feat(arrow/compute): implement "is_in" function (#319)
### Rationale for this change

Since we use arrow-go for iceberg-go and utilize the compute libraries for filtering, we need to ensure we add support for the minimal set of functions that iceberg requires.

### What changes are included in this PR?

Implementing the `is_in` function for the function registry, registering it by default, and ensuring we also allow using `is_in` from substrait.

### Are these changes tested?

Yes, unit tests are included.

### Are there any user-facing changes?

There shouldn't be any user-facing changes.
1 parent a0206ec commit 42c448e

22 files changed

Lines changed: 1441 additions & 152 deletions

arrow/array/builder.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,13 @@ func (b *builder) resize(newBits int, init func(int)) {
176176
}
177177

178178
func (b *builder) reserve(elements int, resize func(int)) {
179-
if b.nullBitmap == nil {
180-
b.nullBitmap = memory.NewResizableBuffer(b.mem)
181-
}
182179
if b.length+elements > b.capacity {
183180
newCap := bitutil.NextPowerOf2(b.length + elements)
184181
resize(newCap)
185182
}
183+
if b.nullBitmap == nil {
184+
b.nullBitmap = memory.NewResizableBuffer(b.mem)
185+
}
186186
}
187187

188188
// unsafeAppendBoolsToBitmap appends the contents of valid to the validity bitmap.

arrow/array/float16.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,6 @@ func (a *Float16) MarshalJSON() ([]byte, error) {
106106
return json.Marshal(vals)
107107
}
108108

109-
func arrayEqualFloat16(left, right *Float16) bool {
110-
for i := 0; i < left.Len(); i++ {
111-
if left.IsNull(i) {
112-
continue
113-
}
114-
if left.Value(i) != right.Value(i) {
115-
return false
116-
}
117-
}
118-
return true
119-
}
120-
121109
var (
122110
_ arrow.Array = (*Float16)(nil)
123111
_ arrow.TypedArray[float16.Num] = (*Float16)(nil)

arrow/compute/arithmetic_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ type BinaryArithmeticSuite[T arrow.NumericType] struct {
204204
scalarEqualOpts []scalar.EqualOption
205205
}
206206

207-
func (BinaryArithmeticSuite[T]) DataType() arrow.DataType {
207+
func (*BinaryArithmeticSuite[T]) DataType() arrow.DataType {
208208
return arrow.GetDataType[T]()
209209
}
210210

arrow/compute/exec/kernel.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ type NonAggKernel interface {
6868
GetNullHandling() NullHandling
6969
GetMemAlloc() MemAlloc
7070
CanFillSlices() bool
71+
Cleanup() error
7172
}
7273

7374
// KernelCtx is a small struct holding the context for a kernel execution
@@ -604,6 +605,7 @@ type ScalarKernel struct {
604605
CanWriteIntoSlices bool
605606
NullHandling NullHandling
606607
MemAlloc MemAlloc
608+
CleanupFn func(KernelState) error
607609
}
608610

609611
// NewScalarKernel constructs a new kernel for scalar execution, constructing
@@ -629,6 +631,13 @@ func NewScalarKernelWithSig(sig *KernelSignature, exec ArrayKernelExec, init Ker
629631
}
630632
}
631633

634+
func (s *ScalarKernel) Cleanup() error {
635+
if s.CleanupFn != nil {
636+
return s.CleanupFn(s.Data)
637+
}
638+
return nil
639+
}
640+
632641
func (s *ScalarKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error {
633642
return s.ExecFn(ctx, sp, out)
634643
}
@@ -693,3 +702,4 @@ func (s *VectorKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error
693702
func (s VectorKernel) GetNullHandling() NullHandling { return s.NullHandling }
694703
func (s VectorKernel) GetMemAlloc() MemAlloc { return s.MemAlloc }
695704
func (s VectorKernel) CanFillSlices() bool { return s.CanWriteIntoSlices }
705+
func (s VectorKernel) Cleanup() error { return nil }

arrow/compute/executor.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package compute
2020

2121
import (
2222
"context"
23+
"errors"
2324
"fmt"
2425
"math"
2526
"runtime"
@@ -579,6 +580,10 @@ func (s *scalarExecutor) WrapResults(ctx context.Context, out <-chan Datum, hasC
579580
}
580581

581582
func (s *scalarExecutor) executeSpans(data chan<- Datum) (err error) {
583+
defer func() {
584+
err = errors.Join(err, s.kernel.Cleanup())
585+
}()
586+
582587
var (
583588
input exec.ExecSpan
584589
output exec.ExecResult
@@ -645,7 +650,7 @@ func (s *scalarExecutor) executeSingleSpan(input *exec.ExecSpan, out *exec.ExecR
645650
return s.kernel.Exec(s.ctx, input, out)
646651
}
647652

648-
func (s *scalarExecutor) setupPrealloc(totalLen int64, args []Datum) error {
653+
func (s *scalarExecutor) setupPrealloc(_ int64, args []Datum) error {
649654
s.numOutBuf = len(s.outType.Layout().Buffers)
650655
outTypeID := s.outType.ID()
651656
// default to no validity pre-allocation for the following cases:

arrow/compute/expression.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ func Cast(ex Expression, dt arrow.DataType) Expression {
490490
return NewCall("cast", []Expression{ex}, opts)
491491
}
492492

493+
// Deprecated: Use SetOptions instead
493494
type SetLookupOptions struct {
494495
ValueSet Datum `compute:"value_set"`
495496
SkipNulls bool `compute:"skip_nulls"`

arrow/compute/exprs/exec.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,6 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
524524
err error
525525
allScalar = true
526526
args = make([]compute.Datum, e.NArgs())
527-
argTypes = make([]arrow.DataType, e.NArgs())
528527
)
529528
for i := 0; i < e.NArgs(); i++ {
530529
switch v := e.Arg(i).(type) {
@@ -543,20 +542,23 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
543542
default:
544543
return nil, arrow.ErrNotImplemented
545544
}
546-
547-
argTypes[i] = args[i].(compute.ArrayLikeDatum).Type()
548545
}
549546

550547
_, conv, ok := ext.DecodeFunction(e.FuncRef())
551548
if !ok {
552-
return nil, arrow.ErrNotImplemented
549+
return nil, fmt.Errorf("%w: %s", arrow.ErrNotImplemented, e.Name())
553550
}
554551

555-
fname, opts, err := conv(e)
552+
fname, args, opts, err := conv(e, args)
556553
if err != nil {
557554
return nil, err
558555
}
559556

557+
argTypes := make([]arrow.DataType, len(args))
558+
for i, arg := range args {
559+
argTypes[i] = arg.(compute.ArrayLikeDatum).Type()
560+
}
561+
560562
ectx := compute.GetExecCtx(ctx)
561563
fn, ok := ectx.Registry.GetFunction(fname)
562564
if !ok {

arrow/compute/exprs/extension_types.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626

2727
"github.com/apache/arrow-go/v18/arrow"
2828
"github.com/apache/arrow-go/v18/arrow/array"
29+
"github.com/apache/arrow-go/v18/arrow/extensions"
2930
)
3031

3132
type simpleExtensionTypeFactory[P comparable] struct {
@@ -95,13 +96,6 @@ type simpleExtensionArrayFactory[P comparable] struct {
9596
array.ExtensionArrayBase
9697
}
9798

98-
type uuidExtParams struct{}
99-
100-
var uuidType = simpleExtensionTypeFactory[uuidExtParams]{
101-
name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType {
102-
return &arrow.FixedSizeBinaryType{ByteWidth: 16}
103-
}}
104-
10599
type fixedCharExtensionParams struct {
106100
Length int32 `json:"length"`
107101
}
@@ -138,7 +132,7 @@ var intervalDayType = simpleExtensionTypeFactory[intervalDayExtensionParams]{
138132
},
139133
}
140134

141-
func uuid() arrow.DataType { return uuidType.CreateType(uuidExtParams{}) }
135+
func uuid() arrow.DataType { return extensions.NewUUIDType() }
142136
func fixedChar(length int32) arrow.DataType {
143137
return fixedCharType.CreateType(fixedCharExtensionParams{Length: length})
144138
}

arrow/compute/exprs/types.go

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626

2727
"github.com/apache/arrow-go/v18/arrow"
2828
"github.com/apache/arrow-go/v18/arrow/compute"
29+
"github.com/apache/arrow-go/v18/arrow/scalar"
2930
"github.com/substrait-io/substrait-go/v3/expr"
3031
"github.com/substrait-io/substrait-go/v3/extensions"
3132
"github.com/substrait-io/substrait-go/v3/types"
@@ -41,7 +42,8 @@ const (
4142
SubstraitComparisonFuncsURI = SubstraitDefaultURIPrefix + "functions_comparison.yaml"
4243
SubstraitBooleanFuncsURI = SubstraitDefaultURIPrefix + "functions_boolean.yaml"
4344

44-
TimestampTzTimezone = "UTC"
45+
SubstraitIcebergSetFuncURI = "https://github.com/apache/iceberg-go/blob/main/table/substrait/functions_set.yaml"
46+
TimestampTzTimezone = "UTC"
4547
)
4648

4749
var hashSeed maphash.Seed
@@ -127,6 +129,15 @@ func init() {
127129
panic(err)
128130
}
129131
}
132+
133+
for _, fn := range []string{"is_in"} {
134+
err := DefaultExtensionIDRegistry.AddSubstraitScalarToArrow(
135+
extensions.ID{URI: SubstraitIcebergSetFuncURI, Name: fn},
136+
setLookupFuncSubstraitToArrowFunc)
137+
if err != nil {
138+
panic(err)
139+
}
140+
}
130141
}
131142

132143
type overflowBehavior string
@@ -178,7 +189,7 @@ func parseOption[typ ~string](sf *expr.ScalarFunction, optionName string, parser
178189
return def, arrow.ErrNotImplemented
179190
}
180191

181-
type substraitToArrow = func(*expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error)
192+
type substraitToArrow = func(*expr.ScalarFunction, []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error)
182193
type arrowToSubstrait = func(fname string) (extensions.ID, []*types.FunctionOption, error)
183194

184195
var substraitToArrowFuncMap = map[string]string{
@@ -199,7 +210,32 @@ var arrowToSubstraitFuncMap = map[string]string{
199210
"or_kleene": "or",
200211
}
201212

202-
func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) {
213+
func setLookupFuncSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) {
214+
fname, _, _ = strings.Cut(sf.Name(), ":")
215+
f, ok := substraitToArrowFuncMap[fname]
216+
if ok {
217+
fname = f
218+
}
219+
220+
setopts := &compute.SetOptions{
221+
NullBehavior: compute.NullMatchingMatch,
222+
}
223+
switch input[1].Kind() {
224+
case compute.KindArray, compute.KindChunked:
225+
setopts.ValueSet = input[1]
226+
case compute.KindScalar:
227+
// should be a list scalar
228+
setopts.ValueSet = compute.NewDatumWithoutOwning(
229+
input[1].(*compute.ScalarDatum).Value.(*scalar.List).Value)
230+
}
231+
232+
args, opts = input[0:1], setopts
233+
return
234+
}
235+
236+
func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) {
237+
args = input
238+
203239
fname, _, _ = strings.Cut(sf.Name(), ":")
204240
f, ok := substraitToArrowFuncMap[fname]
205241
if ok {
@@ -219,19 +255,19 @@ func simpleMapArrowToSubstraitFunc(uri string) arrowToSubstrait {
219255
}
220256

221257
func decodeOptionlessOverflowableArithmetic(n string) substraitToArrow {
222-
return func(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) {
258+
return func(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) {
223259
overflow, err := parseOption(sf, "overflow", &overflowParser, []overflowBehavior{overflowSILENT, overflowERROR}, overflowSILENT)
224260
if err != nil {
225-
return n, nil, err
261+
return n, input, nil, err
226262
}
227263

228264
switch overflow {
229265
case overflowSILENT:
230-
return n + "_unchecked", nil, nil
266+
return n + "_unchecked", input, nil, nil
231267
case overflowERROR:
232-
return n, nil, nil
268+
return n, input, nil, nil
233269
default:
234-
return n, nil, arrow.ErrNotImplemented
270+
return n, input, nil, arrow.ErrNotImplemented
235271
}
236272
}
237273
}

0 commit comments

Comments
 (0)