Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion arrow/compute/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ func (c *CastSuite) TestToIntDowncastUnsafe() {
}

func (c *CastSuite) TestFloatingToInt() {
for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64} {
for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64, arrow.FixedWidthTypes.Float16} {
for _, to := range []arrow.DataType{arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int64} {
// float to int no truncation
c.checkCast(from, to, `[1.0, null, 0.0, -1.0, 5.0]`, `[1, null, 0, -1, 5]`)
Expand All @@ -590,6 +590,12 @@ func (c *CastSuite) TestFloatingToInt() {
}
}

// TestFloat16ToFloating checks that float16 values, including a null slot,
// are cast unchanged to each of the wider floating point types.
func (c *CastSuite) TestFloat16ToFloating() {
	// The chosen values are expected to come out of the cast exactly as they
	// went in, so one JSON literal serves as both input and expectation.
	const values = `[1.5, null, 0.0, -1.5, 5.5]`

	targets := []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64}
	for idx := range targets {
		c.checkCast(arrow.FixedWidthTypes.Float16, targets[idx], values, values)
	}
}

func (c *CastSuite) TestIntToFloating() {
for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Int32} {
two24 := `[16777216, 16777217]`
Expand Down
46 changes: 45 additions & 1 deletion arrow/compute/internal/kernels/cast_numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"unsafe"

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/float16"
)

// castNumericUnsafe is a package-level function variable used to cast a raw
// buffer of itype values (in) into a raw buffer of otype values (out) for len
// elements. It is initialized to the pure Go implementation castNumericGo.
// NOTE(review): presumably declared as a variable so platform-specific builds
// can substitute another implementation — confirm.
var castNumericUnsafe func(itype, otype arrow.Type, in, out []byte, len int) = castNumericGo
Expand All @@ -32,7 +33,19 @@ func DoStaticCast[InT, OutT numeric](in []InT, out []OutT) {
}
}

func reinterpret[T numeric](b []byte, len int) (res []T) {
// DoFloat16Cast converts every element of in to a float16 value and stores it
// at the same index of out. Each element is first converted to float32 and
// then encoded with float16.New. out must be at least as long as in.
func DoFloat16Cast[InT numeric](in []InT, out []float16.Num) {
	for idx := range in {
		widened := float32(in[idx])
		out[idx] = float16.New(widened)
	}
}

// DoFloat16CastToNumber converts every float16 element of in to OutT and
// stores it at the same index of out. Each element is decoded to float32 and
// then converted with an ordinary Go numeric conversion. out must be at least
// as long as in.
func DoFloat16CastToNumber[OutT numeric](in []float16.Num, out []OutT) {
	for idx := range in {
		decoded := in[idx].Float32()
		out[idx] = OutT(decoded)
	}
}

// reinterpret returns a []T of len elements that aliases the memory of the
// byte buffer b, without copying. Writes through the result are visible in b.
//
// The caller is responsible for b holding at least len elements' worth of
// bytes. Note that the parameter len shadows the builtin of the same name
// inside this function.
//
// A request for zero elements yields a nil slice without touching b, so an
// empty (or nil) buffer is accepted in that case.
func reinterpret[T numeric | float16.Num](b []byte, len int) (res []T) {
	// Taking &b[0] panics on an empty buffer, which is exactly what a
	// zero-length array may carry; nothing needs to be viewed in that case.
	if len == 0 {
		return nil
	}
	return unsafe.Slice((*T)(unsafe.Pointer(&b[0])), len)
}

Expand All @@ -54,13 +67,42 @@ func castNumberToNumberUnsafeImpl[T numeric](outT arrow.Type, in []T, out []byte
DoStaticCast(in, reinterpret[int64](out, len(in)))
case arrow.UINT64:
DoStaticCast(in, reinterpret[uint64](out, len(in)))
case arrow.FLOAT16:
DoFloat16Cast(in, reinterpret[float16.Num](out, len(in)))
case arrow.FLOAT32:
DoStaticCast(in, reinterpret[float32](out, len(in)))
case arrow.FLOAT64:
DoStaticCast(in, reinterpret[float64](out, len(in)))
}
}

// castFloat16ToNumberUnsafeImpl casts the float16 values of in into the raw
// output buffer out, interpreting out as a slice of the Go type that
// corresponds to outT. One output element is written per input element.
//
// A float16 target is a plain copy. Any outT not listed below is left
// untouched: the function returns without writing to out.
func castFloat16ToNumberUnsafeImpl(outT arrow.Type, in []float16.Num, out []byte) {
	n := len(in)

	switch outT {
	// Signed integer targets.
	case arrow.INT8:
		DoFloat16CastToNumber(in, reinterpret[int8](out, n))
	case arrow.INT16:
		DoFloat16CastToNumber(in, reinterpret[int16](out, n))
	case arrow.INT32:
		DoFloat16CastToNumber(in, reinterpret[int32](out, n))
	case arrow.INT64:
		DoFloat16CastToNumber(in, reinterpret[int64](out, n))

	// Unsigned integer targets.
	case arrow.UINT8:
		DoFloat16CastToNumber(in, reinterpret[uint8](out, n))
	case arrow.UINT16:
		DoFloat16CastToNumber(in, reinterpret[uint16](out, n))
	case arrow.UINT32:
		DoFloat16CastToNumber(in, reinterpret[uint32](out, n))
	case arrow.UINT64:
		DoFloat16CastToNumber(in, reinterpret[uint64](out, n))

	// Floating point targets; same-type casts need no conversion.
	case arrow.FLOAT16:
		copy(reinterpret[float16.Num](out, n), in)
	case arrow.FLOAT32:
		DoFloat16CastToNumber(in, reinterpret[float32](out, n))
	case arrow.FLOAT64:
		DoFloat16CastToNumber(in, reinterpret[float64](out, n))
	}
}

func castNumericGo(itype, otype arrow.Type, in, out []byte, len int) {
switch itype {
case arrow.INT8:
Expand All @@ -79,6 +121,8 @@ func castNumericGo(itype, otype arrow.Type, in, out []byte, len int) {
castNumberToNumberUnsafeImpl(otype, reinterpret[int64](in, len), out)
case arrow.UINT64:
castNumberToNumberUnsafeImpl(otype, reinterpret[uint64](in, len), out)
case arrow.FLOAT16:
castFloat16ToNumberUnsafeImpl(otype, reinterpret[float16.Num](in, len), out)
case arrow.FLOAT32:
castNumberToNumberUnsafeImpl(otype, reinterpret[float32](in, len), out)
case arrow.FLOAT64:
Expand Down
6 changes: 5 additions & 1 deletion arrow/compute/internal/kernels/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,11 @@ func castNumberToNumberUnsafe(in, out *exec.ArraySpan) {

inputOffset := in.Type.(arrow.FixedWidthDataType).Bytes() * int(in.Offset)
outputOffset := out.Type.(arrow.FixedWidthDataType).Bytes() * int(out.Offset)
castNumericUnsafe(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
if in.Type.ID() == arrow.FLOAT16 || out.Type.ID() == arrow.FLOAT16 {
castNumericGo(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
} else {
castNumericUnsafe(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
}
}

func MaxDecimalDigitsForInt(id arrow.Type) (int32, error) {
Expand Down
133 changes: 117 additions & 16 deletions arrow/compute/internal/kernels/numeric_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/decimal256"
"github.com/apache/arrow-go/v18/arrow/float16"
"github.com/apache/arrow-go/v18/arrow/internal/debug"
"github.com/apache/arrow-go/v18/internal/bitutils"
"golang.org/x/exp/constraints"
Expand Down Expand Up @@ -506,6 +507,27 @@ func CastFloat64ToDecimal(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.E
return executor(ctx, batch, out)
}

// CastDecimalToFloat16 is the kernel exec function that casts a decimal128 or
// decimal256 array to float16. Each non-null value is converted to float32
// using the input type's scale and then encoded with float16.New.
//
// The input type is taken from the first value of batch. If it is neither
// *arrow.Decimal128Type nor *arrow.Decimal256Type, an error wrapping
// arrow.ErrInvalid is returned instead of invoking an unset executor.
func CastDecimalToFloat16(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
	var executor exec.ArrayKernelExec

	switch dt := batch.Values[0].Array.Type.(type) {
	case *arrow.Decimal128Type:
		scale := dt.Scale
		executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v decimal128.Num, _ *error) float16.Num {
			return float16.New(v.ToFloat32(scale))
		})
	case *arrow.Decimal256Type:
		scale := dt.Scale
		executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v decimal256.Num, _ *error) float16.Num {
			return float16.New(v.ToFloat32(scale))
		})
	default:
		// Without this branch executor would still be nil here and the call
		// below would panic with a nil function dereference.
		return fmt.Errorf("%w: cannot cast non-decimal type %s to float16", arrow.ErrInvalid, dt)
	}

	return executor(ctx, batch, out)
}

func CastDecimalToFloating[OutT constraints.Float](ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
var (
executor exec.ArrayKernelExec
Expand Down Expand Up @@ -543,13 +565,49 @@ func boolToNum[T numeric](_ *exec.KernelCtx, in []byte, out []T) error {
return nil
}

func checkFloatTrunc[InT constraints.Float, OutT arrow.IntType | arrow.UintType](in, out *exec.ArraySpan) error {
wasTrunc := func(out OutT, in InT) bool {
return InT(out) != in
func boolToFloat16(_ *exec.KernelCtx, in []byte, out []float16.Num) error {
var (
zero float16.Num
one = float16.New(1)
)

for i := range out {
if bitutil.BitIsSet(in, i) {
out[i] = one
} else {
out[i] = zero
}
}
wasTruncMaybeNull := func(out OutT, in InT, isValid bool) bool {
return isValid && (InT(out) != in)
return nil
}

// wasTrunc reports whether converting the floating point value in to the
// integer out lost information, i.e. whether out converted back to floating
// point differs from in.
//
// float16 inputs are compared numerically after decoding to float32, the same
// way float32 and float64 inputs are compared. Comparing re-encoded float16
// bit patterns instead would misreport negative zero as truncated (its bits
// differ from those of +0) and could miss an infinity whose out-of-range
// integer conversion happens to re-encode to the same infinity.
//
// Inputs whose dynamic type is none of float16.Num, float32 or float64 are
// reported as not truncated.
func wasTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType | arrow.UintType](out OutT, in InT) bool {
	switch v := any(in).(type) {
	case float16.Num:
		return float32(out) != v.Float32()
	case float32:
		return float32(out) != v
	case float64:
		return float64(out) != v
	default:
		return false
	}
}

// wasTruncMaybeNull reports whether converting the floating point value in to
// the integer out lost information, ignoring slots that are null. It returns
// false whenever isValid is false.
//
// float16 inputs are compared numerically after decoding to float32, the same
// way float32 and float64 inputs are compared. Comparing re-encoded float16
// bit patterns instead would misreport negative zero as truncated (its bits
// differ from those of +0) and could miss an infinity whose out-of-range
// integer conversion happens to re-encode to the same infinity.
//
// Inputs whose dynamic type is none of float16.Num, float32 or float64 are
// reported as not truncated.
func wasTruncMaybeNull[InT constraints.Float | float16.Num, OutT arrow.IntType | arrow.UintType](out OutT, in InT, isValid bool) bool {
	// Null slots carry no meaningful value, so they can never be truncated.
	if !isValid {
		return false
	}

	switch v := any(in).(type) {
	case float16.Num:
		return float32(out) != v.Float32()
	case float32:
		return float32(out) != v
	case float64:
		return float64(out) != v
	default:
		return false
	}
}

func checkFloatTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType | arrow.UintType](in, out *exec.ArraySpan) error {
getError := func(val InT) error {
return fmt.Errorf("%w: float value %f was truncated converting to %s",
arrow.ErrInvalid, val, out.Type)
Expand Down Expand Up @@ -598,7 +656,7 @@ func checkFloatTrunc[InT constraints.Float, OutT arrow.IntType | arrow.UintType]
return nil
}

func checkFloatToIntTruncImpl[T constraints.Float](in, out *exec.ArraySpan) error {
func checkFloatToIntTruncImpl[T constraints.Float | float16.Num](in, out *exec.ArraySpan) error {
switch out.Type.ID() {
case arrow.INT8:
return checkFloatTrunc[T, int8](in, out)
Expand All @@ -623,6 +681,8 @@ func checkFloatToIntTruncImpl[T constraints.Float](in, out *exec.ArraySpan) erro

func checkFloatToIntTrunc(in, out *exec.ArraySpan) error {
switch in.Type.ID() {
case arrow.FLOAT16:
return checkFloatToIntTruncImpl[float16.Num](in, out)
case arrow.FLOAT32:
return checkFloatToIntTruncImpl[float32](in, out)
case arrow.FLOAT64:
Expand Down Expand Up @@ -729,6 +789,26 @@ func getParseStringExec[OffsetT int32 | int64](out arrow.Type) exec.ArrayKernelE
panic("invalid type for getParseStringExec")
}

// addFloat16Casts appends to kernels the cast kernels that produce outTy and
// returns the extended slice. In order, it adds: the common cast kernels for
// outTy, a boolean input kernel backed by boolToFloat16, and string-parsing
// kernels for Binary, String, LargeBinary and LargeString inputs.
func addFloat16Casts(outTy arrow.DataType, kernels []exec.ScalarKernel) []exec.ScalarKernel {
	outID := outTy.ID()

	kernels = append(kernels, GetCommonCastKernels(outID, exec.NewOutputType(outTy))...)

	boolKernel := exec.NewScalarKernel(
		[]exec.InputType{exec.NewExactInput(arrow.FixedWidthTypes.Boolean)},
		exec.NewOutputType(outTy), ScalarUnaryBoolArg(boolToFloat16), nil)
	kernels = append(kernels, boolKernel)

	// register appends one parsing kernel that accepts exactly inTy.
	register := func(inTy arrow.DataType, parse exec.ArrayKernelExec) {
		kernels = append(kernels, exec.NewScalarKernel(
			[]exec.InputType{exec.NewExactInput(inTy)}, exec.NewOutputType(outTy),
			parse, nil))
	}

	// Binary and String use the int32 instantiation of the parser.
	register(arrow.BinaryTypes.Binary, getParseStringExec[int32](outID))
	register(arrow.BinaryTypes.String, getParseStringExec[int32](outID))

	// The Large variants use the int64 instantiation.
	register(arrow.BinaryTypes.LargeBinary, getParseStringExec[int64](outID))
	register(arrow.BinaryTypes.LargeString, getParseStringExec[int64](outID))

	return kernels
}

func addCommonNumberCasts[T numeric](outTy arrow.DataType, kernels []exec.ScalarKernel) []exec.ScalarKernel {
kernels = append(kernels, GetCommonCastKernels(outTy.ID(), exec.NewOutputType(outTy))...)

Expand Down Expand Up @@ -759,7 +839,7 @@ func GetCastToInteger[T arrow.IntType | arrow.UintType](outType arrow.DataType)
CastIntToInt, nil))
}

for _, inTy := range floatingTypes {
for _, inTy := range append(floatingTypes, arrow.FixedWidthTypes.Float16) {
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewExactInput(inTy)}, output,
CastFloatingToInteger, nil))
Expand All @@ -775,7 +855,7 @@ func GetCastToInteger[T arrow.IntType | arrow.UintType](outType arrow.DataType)
return kernels
}

func GetCastToFloating[T constraints.Float](outType arrow.DataType) []exec.ScalarKernel {
func GetCastToFloating[T constraints.Float | float16.Num](outType arrow.DataType) []exec.ScalarKernel {
kernels := make([]exec.ScalarKernel, 0)

output := exec.NewOutputType(outType)
Expand All @@ -785,19 +865,40 @@ func GetCastToFloating[T constraints.Float](outType arrow.DataType) []exec.Scala
CastIntegerToFloating, nil))
}

for _, inTy := range floatingTypes {
for _, inTy := range append(floatingTypes, arrow.FixedWidthTypes.Float16) {
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewExactInput(inTy)}, output,
CastFloatingToFloating, nil))
}

kernels = addCommonNumberCasts[T](outType, kernels)
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output,
CastDecimalToFloating[T], nil))
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output,
CastDecimalToFloating[T], nil))
var z T
switch any(z).(type) {
case float16.Num:
kernels = addFloat16Casts(outType, kernels)
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output,
CastDecimalToFloat16, nil))
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output,
CastDecimalToFloat16, nil))
case float32:
kernels = addCommonNumberCasts[float32](outType, kernels)
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output,
CastDecimalToFloating[float32], nil))
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output,
CastDecimalToFloating[float32], nil))
case float64:
kernels = addCommonNumberCasts[float64](outType, kernels)
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output,
CastDecimalToFloating[float64], nil))
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output,
CastDecimalToFloating[float64], nil))
}

return kernels
}

Expand Down
7 changes: 6 additions & 1 deletion arrow/float16/float16.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package float16

import (
"encoding/binary"
"fmt"
"math"
"strconv"
)
Expand Down Expand Up @@ -58,6 +59,10 @@ func New(f float32) Num {
return Num{bits: (sn << 15) | uint16(res<<10) | fc}
}

// Format implements fmt.Formatter. It rebuilds the formatting directive that
// was applied to f (flags, width, precision and verb) and applies it to the
// float32 value of f, writing the result to s.
func (f Num) Format(s fmt.State, verb rune) {
	directive := fmt.FormatString(s, verb)
	value := f.Float32()
	fmt.Fprintf(s, directive, value)
}

func (f Num) Float32() float32 {
sn := uint32((f.bits >> 15) & 0x1)
exp := (f.bits >> 10) & 0x1f
Expand Down Expand Up @@ -179,7 +184,7 @@ func (n Num) IsInf() bool { return (n.bits & 0x7c00) == 0x7c00 }

// IsZero reports whether every bit of n other than the top (sign) bit is
// clear, so both the positive and the negative zero encodings qualify.
func (n Num) IsZero() bool {
	const magnitudeMask = 0x7fff
	return n.bits&magnitudeMask == 0
}

// Uint16 returns the raw 16-bit encoding of f. The bits field already has
// type uint16, so it is returned directly without a conversion.
func (f Num) Uint16() uint16 { return f.bits }
// String returns the textual form of f: its float32 value formatted with the
// 'g' verb, the shortest precision that round-trips, and a 32-bit source size.
func (f Num) String() string {
	widened := float64(f.Float32())
	return strconv.FormatFloat(widened, 'g', -1, 32)
}

// Inf returns the Num whose encoding is 0x7c00: sign bit clear, all five
// exponent bits set and an all-zero fraction.
func Inf() Num {
	const infBits = 0x7c00
	return Num{bits: infBits}
}
Expand Down
2 changes: 1 addition & 1 deletion arrow/float16/float16_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestFloat16(t *testing.T) {
f := k.Float32()
assert.Equal(t, v, f, "float32 values should be the same")
i := New(v)
assert.Equal(t, k.bits, i.bits, "float16 values should be the same")
assert.Equal(t, k, i, "float16 values should be the same")
assert.Equal(t, k.Uint16(), i.Uint16(), "float16 values should be the same")
assert.Equal(t, k.String(), fmt.Sprintf("%v", v), "string representation differ")
}
Expand Down
Loading