Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions arrow/array/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,13 @@ func (b *builder) resize(newBits int, init func(int)) {
}

func (b *builder) reserve(elements int, resize func(int)) {
if b.nullBitmap == nil {
b.nullBitmap = memory.NewResizableBuffer(b.mem)
}
if b.length+elements > b.capacity {
newCap := bitutil.NextPowerOf2(b.length + elements)
resize(newCap)
}
if b.nullBitmap == nil {
b.nullBitmap = memory.NewResizableBuffer(b.mem)
}
}

// unsafeAppendBoolsToBitmap appends the contents of valid to the validity bitmap.
Expand Down
12 changes: 0 additions & 12 deletions arrow/array/float16.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,6 @@ func (a *Float16) MarshalJSON() ([]byte, error) {
return json.Marshal(vals)
}

func arrayEqualFloat16(left, right *Float16) bool {
for i := 0; i < left.Len(); i++ {
if left.IsNull(i) {
continue
}
if left.Value(i) != right.Value(i) {
return false
}
}
return true
}

var (
_ arrow.Array = (*Float16)(nil)
_ arrow.TypedArray[float16.Num] = (*Float16)(nil)
Expand Down
2 changes: 1 addition & 1 deletion arrow/compute/arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ type BinaryArithmeticSuite[T arrow.NumericType] struct {
scalarEqualOpts []scalar.EqualOption
}

func (BinaryArithmeticSuite[T]) DataType() arrow.DataType {
func (*BinaryArithmeticSuite[T]) DataType() arrow.DataType {
return arrow.GetDataType[T]()
}

Expand Down
10 changes: 10 additions & 0 deletions arrow/compute/exec/kernel.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type NonAggKernel interface {
GetNullHandling() NullHandling
GetMemAlloc() MemAlloc
CanFillSlices() bool
Cleanup() error
}

// KernelCtx is a small struct holding the context for a kernel execution
Expand Down Expand Up @@ -604,6 +605,7 @@ type ScalarKernel struct {
CanWriteIntoSlices bool
NullHandling NullHandling
MemAlloc MemAlloc
CleanupFn func(KernelState) error
}

// NewScalarKernel constructs a new kernel for scalar execution, constructing
Expand All @@ -629,6 +631,13 @@ func NewScalarKernelWithSig(sig *KernelSignature, exec ArrayKernelExec, init Ker
}
}

func (s *ScalarKernel) Cleanup() error {
if s.CleanupFn != nil {
return s.CleanupFn(s.Data)
}
return nil
}

func (s *ScalarKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error {
return s.ExecFn(ctx, sp, out)
}
Expand Down Expand Up @@ -693,3 +702,4 @@ func (s *VectorKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error
func (s VectorKernel) GetNullHandling() NullHandling { return s.NullHandling }
func (s VectorKernel) GetMemAlloc() MemAlloc { return s.MemAlloc }
func (s VectorKernel) CanFillSlices() bool { return s.CanWriteIntoSlices }
func (s VectorKernel) Cleanup() error { return nil }
7 changes: 6 additions & 1 deletion arrow/compute/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package compute

import (
"context"
"errors"
"fmt"
"math"
"runtime"
Expand Down Expand Up @@ -579,6 +580,10 @@ func (s *scalarExecutor) WrapResults(ctx context.Context, out <-chan Datum, hasC
}

func (s *scalarExecutor) executeSpans(data chan<- Datum) (err error) {
defer func() {
err = errors.Join(err, s.kernel.Cleanup())
}()

var (
input exec.ExecSpan
output exec.ExecResult
Expand Down Expand Up @@ -645,7 +650,7 @@ func (s *scalarExecutor) executeSingleSpan(input *exec.ExecSpan, out *exec.ExecR
return s.kernel.Exec(s.ctx, input, out)
}

func (s *scalarExecutor) setupPrealloc(totalLen int64, args []Datum) error {
func (s *scalarExecutor) setupPrealloc(_ int64, args []Datum) error {
s.numOutBuf = len(s.outType.Layout().Buffers)
outTypeID := s.outType.ID()
// default to no validity pre-allocation for the following cases:
Expand Down
1 change: 1 addition & 0 deletions arrow/compute/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ func Cast(ex Expression, dt arrow.DataType) Expression {
return NewCall("cast", []Expression{ex}, opts)
}

// Deprecated: Use SetOptions instead
type SetLookupOptions struct {
ValueSet Datum `compute:"value_set"`
SkipNulls bool `compute:"skip_nulls"`
Expand Down
12 changes: 7 additions & 5 deletions arrow/compute/exprs/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,6 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
err error
allScalar = true
args = make([]compute.Datum, e.NArgs())
argTypes = make([]arrow.DataType, e.NArgs())
)
for i := 0; i < e.NArgs(); i++ {
switch v := e.Arg(i).(type) {
Expand All @@ -543,20 +542,23 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
default:
return nil, arrow.ErrNotImplemented
}

argTypes[i] = args[i].(compute.ArrayLikeDatum).Type()
}

_, conv, ok := ext.DecodeFunction(e.FuncRef())
if !ok {
return nil, arrow.ErrNotImplemented
return nil, fmt.Errorf("%w: %s", arrow.ErrNotImplemented, e.Name())
}

fname, opts, err := conv(e)
fname, args, opts, err := conv(e, args)
if err != nil {
return nil, err
}

argTypes := make([]arrow.DataType, len(args))
for i, arg := range args {
argTypes[i] = arg.(compute.ArrayLikeDatum).Type()
}

ectx := compute.GetExecCtx(ctx)
fn, ok := ectx.Registry.GetFunction(fname)
if !ok {
Expand Down
10 changes: 2 additions & 8 deletions arrow/compute/exprs/extension_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/extensions"
)

type simpleExtensionTypeFactory[P comparable] struct {
Expand Down Expand Up @@ -95,13 +96,6 @@ type simpleExtensionArrayFactory[P comparable] struct {
array.ExtensionArrayBase
}

type uuidExtParams struct{}

var uuidType = simpleExtensionTypeFactory[uuidExtParams]{
name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType {
return &arrow.FixedSizeBinaryType{ByteWidth: 16}
}}

type fixedCharExtensionParams struct {
Length int32 `json:"length"`
}
Expand Down Expand Up @@ -138,7 +132,7 @@ var intervalDayType = simpleExtensionTypeFactory[intervalDayExtensionParams]{
},
}

func uuid() arrow.DataType { return uuidType.CreateType(uuidExtParams{}) }
func uuid() arrow.DataType { return extensions.NewUUIDType() }
func fixedChar(length int32) arrow.DataType {
return fixedCharType.CreateType(fixedCharExtensionParams{Length: length})
}
Expand Down
52 changes: 44 additions & 8 deletions arrow/compute/exprs/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/scalar"
"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/extensions"
"github.com/substrait-io/substrait-go/v3/types"
Expand All @@ -41,7 +42,8 @@ const (
SubstraitComparisonFuncsURI = SubstraitDefaultURIPrefix + "functions_comparison.yaml"
SubstraitBooleanFuncsURI = SubstraitDefaultURIPrefix + "functions_boolean.yaml"

TimestampTzTimezone = "UTC"
SubstraitIcebergSetFuncURI = "https://github.com/apache/iceberg-go/blob/main/table/substrait/functions_set.yaml"
TimestampTzTimezone = "UTC"
)

var hashSeed maphash.Seed
Expand Down Expand Up @@ -127,6 +129,15 @@ func init() {
panic(err)
}
}

for _, fn := range []string{"is_in"} {
err := DefaultExtensionIDRegistry.AddSubstraitScalarToArrow(
extensions.ID{URI: SubstraitIcebergSetFuncURI, Name: fn},
setLookupFuncSubstraitToArrowFunc)
if err != nil {
panic(err)
}
}
}

type overflowBehavior string
Expand Down Expand Up @@ -178,7 +189,7 @@ func parseOption[typ ~string](sf *expr.ScalarFunction, optionName string, parser
return def, arrow.ErrNotImplemented
}

type substraitToArrow = func(*expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error)
type substraitToArrow = func(*expr.ScalarFunction, []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error)
type arrowToSubstrait = func(fname string) (extensions.ID, []*types.FunctionOption, error)

var substraitToArrowFuncMap = map[string]string{
Expand All @@ -199,7 +210,32 @@ var arrowToSubstraitFuncMap = map[string]string{
"or_kleene": "or",
}

func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) {
func setLookupFuncSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) {
fname, _, _ = strings.Cut(sf.Name(), ":")
f, ok := substraitToArrowFuncMap[fname]
if ok {
fname = f
}

setopts := &compute.SetOptions{
NullBehavior: compute.NullMatchingMatch,
}
switch input[1].Kind() {
case compute.KindArray, compute.KindChunked:
setopts.ValueSet = input[1]
case compute.KindScalar:
// should be a list scalar
setopts.ValueSet = compute.NewDatumWithoutOwning(
input[1].(*compute.ScalarDatum).Value.(*scalar.List).Value)
}

args, opts = input[0:1], setopts
return
}

func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) {
args = input

fname, _, _ = strings.Cut(sf.Name(), ":")
f, ok := substraitToArrowFuncMap[fname]
if ok {
Expand All @@ -219,19 +255,19 @@ func simpleMapArrowToSubstraitFunc(uri string) arrowToSubstrait {
}

func decodeOptionlessOverflowableArithmetic(n string) substraitToArrow {
return func(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) {
return func(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) {
overflow, err := parseOption(sf, "overflow", &overflowParser, []overflowBehavior{overflowSILENT, overflowERROR}, overflowSILENT)
if err != nil {
return n, nil, err
return n, input, nil, err
}

switch overflow {
case overflowSILENT:
return n + "_unchecked", nil, nil
return n + "_unchecked", input, nil, nil
case overflowERROR:
return n, nil, nil
return n, input, nil, nil
default:
return n, nil, arrow.ErrNotImplemented
return n, input, nil, arrow.ErrNotImplemented
}
}
}
Expand Down
Loading
Loading