@@ -26,6 +26,7 @@ import (
2626
2727 "github.com/apache/arrow-go/v18/arrow"
2828 "github.com/apache/arrow-go/v18/arrow/compute"
29+ "github.com/apache/arrow-go/v18/arrow/scalar"
2930 "github.com/substrait-io/substrait-go/v3/expr"
3031 "github.com/substrait-io/substrait-go/v3/extensions"
3132 "github.com/substrait-io/substrait-go/v3/types"
@@ -41,7 +42,8 @@ const (
4142 SubstraitComparisonFuncsURI = SubstraitDefaultURIPrefix + "functions_comparison.yaml"
4243 SubstraitBooleanFuncsURI = SubstraitDefaultURIPrefix + "functions_boolean.yaml"
4344
44- TimestampTzTimezone = "UTC"
45+ SubstraitIcebergSetFuncURI = "https://github.com/apache/iceberg-go/blob/main/table/substrait/functions_set.yaml"
46+ TimestampTzTimezone = "UTC"
4547)
4648
4749var hashSeed maphash.Seed
@@ -127,6 +129,15 @@ func init() {
127129 panic (err )
128130 }
129131 }
132+
133+ for _ , fn := range []string {"is_in" } {
134+ err := DefaultExtensionIDRegistry .AddSubstraitScalarToArrow (
135+ extensions.ID {URI : SubstraitIcebergSetFuncURI , Name : fn },
136+ setLookupFuncSubstraitToArrowFunc )
137+ if err != nil {
138+ panic (err )
139+ }
140+ }
130141}
131142
132143type overflowBehavior string
@@ -178,7 +189,7 @@ func parseOption[typ ~string](sf *expr.ScalarFunction, optionName string, parser
178189 return def , arrow .ErrNotImplemented
179190}
180191
181- type substraitToArrow = func (* expr.ScalarFunction ) (fname string , opts compute.FunctionOptions , err error )
192+ type substraitToArrow = func (* expr.ScalarFunction , []compute. Datum ) (fname string , args []compute. Datum , opts compute.FunctionOptions , err error )
182193type arrowToSubstrait = func (fname string ) (extensions.ID , []* types.FunctionOption , error )
183194
184195var substraitToArrowFuncMap = map [string ]string {
@@ -199,7 +210,32 @@ var arrowToSubstraitFuncMap = map[string]string{
199210 "or_kleene" : "or" ,
200211}
201212
202- func simpleMapSubstraitToArrowFunc (sf * expr.ScalarFunction ) (fname string , opts compute.FunctionOptions , err error ) {
213+ func setLookupFuncSubstraitToArrowFunc (sf * expr.ScalarFunction , input []compute.Datum ) (fname string , args []compute.Datum , opts compute.FunctionOptions , err error ) {
214+ fname , _ , _ = strings .Cut (sf .Name (), ":" )
215+ f , ok := substraitToArrowFuncMap [fname ]
216+ if ok {
217+ fname = f
218+ }
219+
220+ setopts := & compute.SetOptions {
221+ NullBehavior : compute .NullMatchingMatch ,
222+ }
223+ switch input [1 ].Kind () {
224+ case compute .KindArray , compute .KindChunked :
225+ setopts .ValueSet = input [1 ]
226+ case compute .KindScalar :
227+ // should be a list scalar
228+ setopts .ValueSet = compute .NewDatumWithoutOwning (
229+ input [1 ].(* compute.ScalarDatum ).Value .(* scalar.List ).Value )
230+ }
231+
232+ args , opts = input [0 :1 ], setopts
233+ return
234+ }
235+
236+ func simpleMapSubstraitToArrowFunc (sf * expr.ScalarFunction , input []compute.Datum ) (fname string , args []compute.Datum , opts compute.FunctionOptions , err error ) {
237+ args = input
238+
203239 fname , _ , _ = strings .Cut (sf .Name (), ":" )
204240 f , ok := substraitToArrowFuncMap [fname ]
205241 if ok {
@@ -219,19 +255,19 @@ func simpleMapArrowToSubstraitFunc(uri string) arrowToSubstrait {
219255}
220256
221257func decodeOptionlessOverflowableArithmetic (n string ) substraitToArrow {
222- return func (sf * expr.ScalarFunction ) (fname string , opts compute.FunctionOptions , err error ) {
258+ return func (sf * expr.ScalarFunction , input []compute. Datum ) (fname string , args []compute. Datum , opts compute.FunctionOptions , err error ) {
223259 overflow , err := parseOption (sf , "overflow" , & overflowParser , []overflowBehavior {overflowSILENT , overflowERROR }, overflowSILENT )
224260 if err != nil {
225- return n , nil , err
261+ return n , input , nil , err
226262 }
227263
228264 switch overflow {
229265 case overflowSILENT :
230- return n + "_unchecked" , nil , nil
266+ return n + "_unchecked" , input , nil , nil
231267 case overflowERROR :
232- return n , nil , nil
268+ return n , input , nil , nil
233269 default :
234- return n , nil , arrow .ErrNotImplemented
270+ return n , input , nil , arrow .ErrNotImplemented
235271 }
236272 }
237273}
0 commit comments