Skip to content

Commit 4ff615f

Browse files
authored
fix(arrow/compute): Fix scalar comparison batches (#465)
### Rationale for this change Fixes #464 ### What changes are included in this PR? Ensure the size of the slice used in scalar comparisons stays within the expected batch size ### Are these changes tested? Yes a test is added. ### Are there any user-facing changes? A condition that previously resulted in panic will now succeed
1 parent 4c2983c commit 4ff615f

2 files changed

Lines changed: 83 additions & 1 deletion

File tree

arrow/compute/exprs/exec_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,85 @@ func TestLargeTypes(t *testing.T) {
699699
defer result.Release()
700700
})
701701
}
702+
703+
func TestDecimalFilterLarge(t *testing.T) {
704+
t.Parallel()
705+
706+
tt := []struct {
707+
name string
708+
n int
709+
}{
710+
{
711+
name: "arrow.DECIMAL128 - number of records < 33 ok",
712+
n: 32,
713+
},
714+
{
715+
name: "arrow.DECIMAL128 - number of records >= 33 panic",
716+
n: 33,
717+
},
718+
}
719+
720+
for _, tc := range tt {
721+
tc := tc
722+
t.Run(tc.name, func(t *testing.T) {
723+
t.Parallel()
724+
725+
ctx := context.Background()
726+
rq := require.New(t)
727+
728+
typ := &arrow.Decimal128Type{Precision: 3, Scale: 1}
729+
field := arrow.Field{
730+
Name: "col",
731+
Type: typ,
732+
Nullable: true,
733+
}
734+
schema := arrow.NewSchema([]arrow.Field{field}, nil)
735+
736+
db := array.NewDecimal128Builder(memory.DefaultAllocator, typ)
737+
defer db.Release()
738+
739+
for i := 0; i < tc.n; i++ {
740+
d, err := decimal.Decimal128FromFloat(float64(i), 3, 1)
741+
rq.NoError(err, "Failed to create Decimal128 value")
742+
743+
db.Append(d)
744+
}
745+
746+
rec := array.NewRecord(schema, []arrow.Array{db.NewArray()}, int64(tc.n))
747+
748+
extSet := exprs.GetExtensionIDSet(ctx)
749+
builder := exprs.NewExprBuilder(extSet)
750+
751+
err := builder.SetInputSchema(schema)
752+
rq.NoError(err, "Failed to set input schema")
753+
754+
v, p, s, err := expr.DecimalStringToBytes("10.0")
755+
rq.NoError(err, "Failed to convert decimal string to bytes")
756+
757+
lit, err := expr.NewLiteral(&types.Decimal{
758+
Value: v[:16],
759+
Precision: p,
760+
Scale: s,
761+
}, true)
762+
rq.NoError(err, "Failed to create Decimal128 literal")
763+
764+
b, err := builder.CallScalar("less", nil,
765+
builder.FieldRef("col"),
766+
builder.Literal(lit),
767+
)
768+
769+
rq.NoError(err, "Failed to call scalar")
770+
771+
e, err := b.BuildExpr()
772+
rq.NoError(err, "Failed to build expression")
773+
774+
ctx = exprs.WithExtensionIDSet(ctx, extSet)
775+
776+
dr := compute.NewDatum(rec)
777+
defer dr.Release()
778+
779+
_, err = exprs.ExecuteScalarExpression(ctx, schema, e, dr)
780+
rq.NoError(err, "Failed to execute scalar expression")
781+
})
782+
}
783+
}

arrow/compute/internal/kernels/scalar_comparisons.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func comparePrimitiveScalarArray[T arrow.FixedWidthType](op cmpScalarLeft[T, T])
147147
}
148148

149149
for j := 0; j < nbatches; j++ {
150-
op(leftVal, right, tmpOutSlice)
150+
op(leftVal, right[:batchSize], tmpOutSlice)
151151
right = right[batchSize:]
152152
packBits(tmpOutput, out)
153153
out = out[batchSize/8:]

0 commit comments

Comments
 (0)