diff --git a/native-engine/datafusion-ext-commons/src/arrow/boolean.rs b/native-engine/datafusion-ext-commons/src/arrow/boolean.rs new file mode 100644 index 000000000..d1407016e --- /dev/null +++ b/native-engine/datafusion-ext-commons/src/arrow/boolean.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow::array::{Array, BooleanArray}; + +/// Returns a BooleanArray where nulls are converted to `false` and the result +/// has no null bitmap (all values are valid). +#[inline] +pub fn nulls_to_false(is_boolean: &BooleanArray) -> BooleanArray { + match is_boolean.nulls() { + Some(nulls) => { + let is_not_null = nulls.inner(); + BooleanArray::new(is_boolean.values() & is_not_null, None) + } + None => is_boolean.clone(), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, BooleanArray}; + + use super::nulls_to_false; + + #[test] + fn converts_nulls_to_false() { + let input = BooleanArray::from(vec![Some(true), None, Some(false)]); + let output = nulls_to_false(&input); + + assert!(output.nulls().is_none()); + + let got: Vec> = output.iter().collect(); + let expected = vec![Some(true), Some(false), Some(false)]; + assert_eq!(got, expected); + } + + #[test] + fn preserves_when_no_nulls() { + let input = BooleanArray::from(vec![Some(false), Some(true)]); + let output = nulls_to_false(&input); + + assert!(output.nulls().is_none()); + let got: Vec> = output.iter().collect(); + let expected = vec![Some(false), Some(true)]; + assert_eq!(got, expected); + } +} diff --git a/native-engine/datafusion-ext-commons/src/arrow/mod.rs b/native-engine/datafusion-ext-commons/src/arrow/mod.rs index afa9df902..b4e8c180f 100644 --- a/native-engine/datafusion-ext-commons/src/arrow/mod.rs +++ b/native-engine/datafusion-ext-commons/src/arrow/mod.rs @@ -14,6 +14,7 @@ // limitations under the License. pub mod array_size; +pub mod boolean; pub mod cast; pub mod coalesce; pub mod eq_comparator; diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index b99e406de..cad5198df 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -24,6 +24,7 @@ mod spark_crypto; mod spark_dates; pub mod spark_get_json_object; mod spark_hash; +mod spark_isnan; mod spark_make_array; mod spark_make_decimal; mod spark_normalize_nan_and_zero; @@ -75,6 +76,7 @@ pub fn create_auron_ext_function(name: &str) -> Result { Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero) } - _ => df_unimplemented_err!("auron ext function not implemented: {name}")?, + "Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan), + _ => df_unimplemented_err!("spark ext function not implemented: {name}")?, }) } diff --git a/native-engine/datafusion-ext-functions/src/spark_isnan.rs b/native-engine/datafusion-ext-functions/src/spark_isnan.rs new file mode 100644 index 000000000..3de97b5c9 --- /dev/null +++ b/native-engine/datafusion-ext-functions/src/spark_isnan.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::{ + array::{Array, BooleanArray, Float32Array, Float64Array}, + datatypes::DataType, +}; +use datafusion::{ + common::{Result, ScalarValue}, + logical_expr::ColumnarValue, +}; +use datafusion_ext_commons::arrow::boolean::nulls_to_false; + +pub fn spark_isnan(args: &[ColumnarValue]) -> Result { + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float64 => { + let array = array.as_any().downcast_ref::().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + let cleaned = nulls_to_false(&is_nan); + Ok(ColumnarValue::Array(Arc::new(cleaned))) + } + DataType::Float32 => { + let array = array.as_any().downcast_ref::().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + let cleaned = nulls_to_false(&is_nan); + Ok(ColumnarValue::Array(Arc::new(cleaned))) + } + _other => { + // For non-float arrays, Spark's isnan is effectively false. + let len = array.len(); + let out = ScalarValue::Boolean(Some(false)).to_array_of_size(len)?; + Ok(ColumnarValue::Array(out)) + } + }, + ColumnarValue::Scalar(sv) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + match sv { + ScalarValue::Float64(a) => a.map(|x| x.is_nan()).unwrap_or(false), + ScalarValue::Float32(a) => a.map(|x| x.is_nan()).unwrap_or(false), + _ => false, + }, + )))), + } +} + +#[cfg(test)] +mod test { + use std::{error::Error, sync::Arc}; + + use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; + use datafusion::{common::ScalarValue, logical_expr::ColumnarValue}; + + use crate::spark_isnan::spark_isnan; + + #[test] + fn test_isnan_array_f64() -> Result<(), Box> { + let input_data = vec![ + Some(12345678.0), + Some(f64::NAN), + Some(-0.0), + None, + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + ]; + let input_columnar_value = ColumnarValue::Array(Arc::new(Float64Array::from(input_data))); + + let result = spark_isnan(&vec![input_columnar_value])?.into_array(6)?; + + let expected_data = vec![ + Some(false), + Some(true), + Some(false), + Some(false), // null returns false in Spark + Some(false), + Some(false), + ]; + let expected: ArrayRef = Arc::new(BooleanArray::from(expected_data)); + assert_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn test_isnan_array_f32() -> Result<(), Box> { + let input_data = vec![ + Some(12345678.0f32), + Some(f32::NAN), + Some(-0.0f32), + None, + Some(f32::INFINITY), + Some(f32::NEG_INFINITY), + ]; + let input_columnar_value = ColumnarValue::Array(Arc::new(Float32Array::from(input_data))); + + let result = spark_isnan(&vec![input_columnar_value])?.into_array(6)?; + + let expected_data = vec![ + Some(false), + Some(true), + Some(false), + Some(false), // null returns false in Spark + Some(false), + Some(false), + ]; + let expected: ArrayRef = Arc::new(BooleanArray::from(expected_data)); + assert_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn test_isnan_scalar_f64_nan() -> Result<(), Box> { + let input_columnar_value = ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))); + let result = spark_isnan(&vec![input_columnar_value])?.into_array(1)?; + let expected: ArrayRef = Arc::new(BooleanArray::from(vec![Some(true)])); + assert_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn test_isnan_scalar_f64_null() -> Result<(), Box> { + let input_columnar_value = ColumnarValue::Scalar(ScalarValue::Float64(None)); + let result = spark_isnan(&vec![input_columnar_value])?.into_array(1)?; + let expected: ArrayRef = Arc::new(BooleanArray::from(vec![Some(false)])); + assert_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn test_isnan_scalar_f32_null() -> Result<(), Box> { + let input_columnar_value = ColumnarValue::Scalar(ScalarValue::Float32(None)); + let result = spark_isnan(&vec![input_columnar_value])?.into_array(1)?; + let expected: ArrayRef = Arc::new(BooleanArray::from(vec![Some(false)])); + assert_eq!(&result, &expected); + Ok(()) + } +} diff --git a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronFunctionSuite.scala index 07725e80c..78bddbba9 100644 --- a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronFunctionSuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronFunctionSuite.scala @@ -381,27 +381,31 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite { } } - ignore("DISABLED: isNaN native semantics mismatch (null -> false)") { - /* TODO: enable once Spark-compatible isNaN lands https://github.com/apache/auron/issues/1646 */ - - test("test function IsNaN") { - withTable("t1") { - sql( - "create table test_is_nan using parquet as select cast('NaN' as double) as c1, cast('NaN' as float) as c2, log(-3) as c3, cast(null as double) as c4, 5.5f as c5") - val functions = - """ - |select - | isnan(c1), - | isnan(c2), - | isnan(c3), - | isnan(c4), - | isnan(c5) - |from - | test_is_nan + test("test function IsNaN") { + withTable("t1") { + sql(""" + |create table test_is_nan using parquet as select + | cast('NaN' as double) as c1, + | cast('NaN' as float) as c2, + | cast(null as double) as c3, + | cast(null as double) as c4, + | cast(5.5 as float) as c5, + | cast(null as float) as c6 + |""".stripMargin) + val functions = + """ + |select + | isnan(c1), + | isnan(c2), + | isnan(c3), + | isnan(c4), + | isnan(c5), + | isnan(c6) + |from + | test_is_nan """.stripMargin - checkSparkAnswerAndOperator(functions) - } + checkSparkAnswerAndOperator(functions) } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 11cdcc4b2..65d6040d2 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -844,7 +844,7 @@ object NativeConverters extends Logging { buildScalarFunction(pb.ScalarFunction.Factorial, e.children, e.dataType) case e: Hex => buildScalarFunction(pb.ScalarFunction.Hex, e.children, e.dataType) case e: IsNaN => - buildScalarFunction(pb.ScalarFunction.IsNaN, e.children, e.dataType) + buildExtScalarFunction("Spark_IsNaN", e.children, e.dataType) case e: Round => e.scale match { case Literal(n: Int, _) =>