diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala new file mode 100644 index 000000000..cc62e5b68 --- /dev/null +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +class AuronInstrSuite extends QueryTest with SparkQueryTestsBase { + + test("test instr function - basic functionality") { + val data = Seq( + ("hello world", "world"), + ("hello world", "hello"), + ("hello world", "o"), + ("hello world", "z"), + (null, "test"), + ("test", null) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 7, "instr('hello world', 'world') should return 7") + assert(result(1) == 1, "instr('hello world', 'hello') should return 1") + assert(result(2) == 5, "instr('hello world', 'o') should return 5") + assert(result(3) == 0, "instr('hello world', 'z') should return 0") + assert(result(4) == 0, "instr(null, 'test') should return null") + assert(result(5) == 0, "instr('test', null) should return null") + } + + test("test instr function - multiple occurrences") { + val data = Seq( + ("banana", "a"), + ("testtesttest", "test"), + ("abcabcabc", "abc") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 2, "instr('banana', 'a') should return 2") + assert(result(1) == 1, "instr('testtesttest', 'test') should return 1") + assert(result(2) == 1, "instr('abcabcabc', 'abc') should return 1") + } + + test("test instr function - case sensitive") { + val data = Seq( + ("Hello", "hello"), + ("HELLO", "hello"), + ("Hello", "Hello"), + ("hElLo", "hello") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 0, "instr('Hello', 'hello') should return 0 (case sensitive)") + assert(result(1) == 0, "instr('HELLO', 'hello') should return 0 (case sensitive)") + assert(result(2) == 1, "instr('Hello', 'Hello') should return 1") + assert(result(3) == 0, "instr('hElLo', 'hello') should return 0 (case sensitive)") + } + + test("test instr function - with filter") { + val data = Seq( + ("hello world", "world", 1), + ("hello", "world", 0), + ("hello", "hello", 1), + ("test", "abc", 0) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr", "expected") + val result = df + .filter("instr(str, substr) > 0") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length == 2, "Should find 2 matching strings") + assert(result.contains("hello world")) + assert(result.contains("hello")) + } + + test("test instr function - in group by") { + val data = Seq( + ("test1", "test"), + ("test2", "test"), + ("hello", "world"), + ("testing", "test") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .groupBy("substr") + .count() + .filter("count > 0") + .orderBy("substr") + .collect() + + assert(result.length >= 1) + } + + test("test instr function - in where clause") { + val data = Seq( + ("hello world", "world"), + ("hello", "world"), + ("testing", "test"), + ("abc", "def") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .filter("instr(str, substr) = 1") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length >= 1) + } +} diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index a65dc0d44..2eeb8d36b 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -26,6 +26,7 @@ mod spark_dates; pub mod spark_get_json_object; mod spark_hash; mod spark_initcap; +mod spark_instr; mod spark_isnan; mod spark_make_array; mod spark_make_decimal; @@ -85,6 +86,7 @@ pub fn create_auron_ext_function( Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero) } "Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan), + "Spark_Instr" => Arc::new(spark_instr::spark_instr), _ => df_unimplemented_err!("spark ext function not implemented: {name}")?, }) } diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs new file mode 100644 index 000000000..b970cc38f --- /dev/null +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, StringArray}; +use datafusion::{ + common::{ + Result, ScalarValue, + cast::{as_int32_array, as_string_array}, + }, + physical_plan::ColumnarValue, +}; +use datafusion_ext_commons::df_execution_err; + +/// instr(str, substr) - Returns the (1-based) index of the first occurrence of +/// substr in str. Compatible with Spark's instr function. +/// Returns 0 if substr is not found or if substr is empty. +/// Returns null if str is null or substr is null. +pub fn spark_instr(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + df_execution_err!("instr requires exactly 2 arguments")?; + } + + let is_scalar = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let len = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }) + .max() + .unwrap_or(0); + + let arrays = args + .iter() + .map(|arg| { + Ok(match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len)?, + }) + }) + .collect::>>()?; + + let str_array = as_string_array(&arrays[0])?; + let substr_array = as_string_array(&arrays[1])?; + + let result_array: ArrayRef = Arc::new(Int32Array::from_iter( + str_array + .iter() + .zip(substr_array.iter()) + .map(|(s, substr)| match (s, substr) { + (Some(_), None) => None, // substr is null + (None, _) => None, // str is null + (Some(s), Some(substr)) => { + if substr.is_empty() { + Some(0) + } else { + Some(s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) + } + } + }), + )); + + if is_scalar { + let scalar = as_int32_array(&result_array)?.value(0); + Ok(ColumnarValue::Scalar(if result_array.is_null(0) { + ScalarValue::Int32(None) + } else { + ScalarValue::Int32(Some(scalar)) + })) + } else { + Ok(ColumnarValue::Array(result_array)) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Int32Array, StringArray}; + use datafusion::{ + common::{Result, ScalarValue, cast::as_int32_array}, + physical_plan::ColumnarValue, + }; + + use super::spark_instr; + + #[test] + fn test_spark_instr() -> Result<()> { + // Test basic functionality with scalar substring + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello world".to_string()), + Some("abc".to_string()), + Some("abcabc".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from("world")), + ])?; + let s = r.into_array(4)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(7), Some(0), Some(0), None,] + ); + + // Test with empty substring should return 0 + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello".to_string()), + Some("world".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from("")), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(0), None,] + ); + + // Test with null substring + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( + "hello".to_string(), + )]))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ])?; + let s = r.into_array(1)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![None,] + ); + + // Test with array substring (element-wise) + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello world".to_string()), + Some("hello".to_string()), + Some("test".to_string()), + ]))), + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("world".to_string()), + Some("test".to_string()), + Some("test".to_string()), + ]))), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(7), Some(0), Some(1),] + ); + + // Test with both scalars + let r = spark_instr(&vec![ + ColumnarValue::Scalar(ScalarValue::from("hello world")), + ColumnarValue::Scalar(ScalarValue::from("world")), + ])?; + assert!(matches!( + r, + ColumnarValue::Scalar(ScalarValue::Int32(Some(7))) + )); + + Ok(()) + } + + #[test] + fn test_spark_instr_multiple_matches() -> Result<()> { + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("banana".to_string()), + Some("testtesttest".to_string()), + ]))), + ColumnarValue::Scalar(ScalarValue::from("test")), + ])?; + let s = r.into_array(2)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(1),] + ); + Ok(()) + } + + #[test] + fn test_spark_instr_case_sensitive() -> Result<()> { + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("Hello".to_string()), + Some("HELLO".to_string()), + ]))), + ColumnarValue::Scalar(ScalarValue::from("hello")), + ])?; + let s = r.into_array(2)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(0),] + ); + Ok(()) + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 7a3bde2c8..11ad3797f 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -922,6 +922,8 @@ object NativeConverters extends Logging { case e: Levenshtein => buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType) + case e: StringInstr => + buildExtScalarFunction("Spark_Instr", e.children, e.dataType) case e: Hour if datetimeExtractEnabled => buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr, fallback)