Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions native-engine/datafusion-ext-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ mod brickhouse;
mod spark_check_overflow;
mod spark_dates;
pub mod spark_get_json_object;
mod spark_hash;
mod spark_make_array;
mod spark_make_decimal;
mod spark_murmur3_hash;
mod spark_null_if;
mod spark_strings;
mod spark_unscaled_value;
mod spark_xxhash64;

pub fn create_spark_ext_function(name: &str) -> Result<ScalarFunctionImplementation> {
Ok(match name {
Expand All @@ -39,8 +38,8 @@ pub fn create_spark_ext_function(name: &str) -> Result<ScalarFunctionImplementat
"UnscaledValue" => Arc::new(spark_unscaled_value::spark_unscaled_value),
"MakeDecimal" => Arc::new(spark_make_decimal::spark_make_decimal),
"CheckOverflow" => Arc::new(spark_check_overflow::spark_check_overflow),
"Murmur3Hash" => Arc::new(spark_murmur3_hash::spark_murmur3_hash),
"XxHash64" => Arc::new(spark_xxhash64::spark_xxhash64),
"Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
"XxHash64" => Arc::new(spark_hash::spark_xxhash64),
"GetJsonObject" => Arc::new(spark_get_json_object::spark_get_json_object),
"GetParsedJsonObject" => Arc::new(spark_get_json_object::spark_get_parsed_json_object),
"ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,47 @@
use std::sync::Arc;

use arrow::array::*;
use datafusion::{common::Result, physical_plan::ColumnarValue};
use datafusion_ext_commons::spark_hash::create_xxhash64_hashes;
use datafusion::{
common::{Result, ScalarValue},
physical_plan::ColumnarValue,
};
use datafusion_ext_commons::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes};

/// Implements `org.apache.spark.sql.catalyst.expressions.Murmur3Hash`.
///
/// Argument handling is delegated to [`spark_hash`]; this function only
/// supplies the hashing step, which runs `create_murmur3_hashes` over the
/// prepared arrays with a fixed seed of 42.
///
/// # Returns
///
/// * a `ColumnarValue::Scalar` built from the first computed hash when every
///   argument is a scalar,
/// * otherwise a `ColumnarValue::Array` wrapping an `Int32Array` of hashes.
///
/// # Errors
///
/// Propagates any error produced by [`spark_hash`].
///
/// # Panics
///
/// The all-scalar path indexes element 0 of the hash buffer, which panics if
/// that buffer is empty.
pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    spark_hash(args, |len, is_scalar, arrays| {
        // use identical seed as spark hash partition
        let spark_murmur3_default_seed = 42i32;
        // `arrays` is already a slice reference; pass it through as-is rather
        // than borrowing it a second time (mirrors `spark_xxhash64`).
        let hash_buffer = create_murmur3_hashes(len, arrays, spark_murmur3_default_seed);
        if is_scalar {
            ColumnarValue::Scalar(ScalarValue::from(hash_buffer[0]))
        } else {
            ColumnarValue::Array(Arc::new(Int32Array::from(hash_buffer)))
        }
    })
}

/// Implements `org.apache.spark.sql.catalyst.expressions.XxHash64`.
///
/// Hands the arguments to [`spark_hash`] together with a closure that hashes
/// the prepared arrays via `create_xxhash64_hashes`. When all arguments are
/// scalars the result is a scalar built from the first hash; otherwise it is
/// an `Int64Array` holding every hash.
pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    // use identical seed as spark hash partition
    const SPARK_XXHASH64_DEFAULT_SEED: i64 = 42;

    spark_hash(args, |num_rows, all_scalar, columns| {
        let hashes = create_xxhash64_hashes(num_rows, columns, SPARK_XXHASH64_DEFAULT_SEED);
        if !all_scalar {
            return ColumnarValue::Array(Arc::new(Int64Array::from(hashes)));
        }
        // Every input was a scalar, so only the first hash is reported.
        ColumnarValue::Scalar(ScalarValue::from(hashes[0]))
    })
}

pub fn spark_hash(
args: &[ColumnarValue],
hash_impl: impl Fn(usize, bool, &[ArrayRef]) -> ColumnarValue,
) -> Result<ColumnarValue> {
let is_scalar = args
.iter()
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
let len = args
.iter()
.map(|arg| match arg {
Expand All @@ -38,14 +74,7 @@ pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result<ColumnarValue> {
})
})
.collect::<Result<Vec<_>>>()?;

// use identical seed as spark hash partition
let spark_xxhash64_default_seed = 42i64;
let hash_buffer = create_xxhash64_hashes(len, &arrays, spark_xxhash64_default_seed);

Ok(ColumnarValue::Array(Arc::new(Int64Array::from(
hash_buffer,
))))
Ok(hash_impl(len, is_scalar, &arrays))
}

#[cfg(test)]
Expand All @@ -57,6 +86,45 @@ mod test {

use super::*;

#[test]
fn test_murmur3_hash_int64() -> Result<(), Box<dyn Error>> {
    // Inputs: 1, 0, -1 and both extremes of the i64 range.
    let input: ArrayRef = Arc::new(Int64Array::from(vec![
        Some(1),
        Some(0),
        Some(-1),
        Some(i64::MAX),
        Some(i64::MIN),
    ]));

    let actual = spark_murmur3_hash(&[ColumnarValue::Array(input)])?.into_array(5)?;

    // Expected hash for each input row, in order.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(-1712319331),
        Some(-1670924195),
        Some(-939490007),
        Some(-1604625029),
        Some(-853646085),
    ]));

    assert_eq!(&actual, &expected);
    Ok(())
}

#[test]
fn test_murmur3_hash_string() -> Result<(), Box<dyn Error>> {
    // Inputs: ASCII words, the empty string, an emoji and CJK text.
    let input: ArrayRef = Arc::new(StringArray::from_iter_values([
        "hello", "bar", "", "😁", "天地",
    ]));

    let actual = spark_murmur3_hash(&[ColumnarValue::Array(input)])?.into_array(5)?;

    // Expected hash for each input row, in order.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(-1008564952),
        Some(-1808790533),
        Some(142593372),
        Some(885025535),
        Some(-1899966402),
    ]));

    assert_eq!(&actual, &expected);
    Ok(())
}
#[test]
fn test_xxhash64_int64() -> Result<(), Box<dyn Error>> {
let result = spark_xxhash64(&vec![ColumnarValue::Array(Arc::new(Int64Array::from(
Expand Down
97 changes: 0 additions & 97 deletions native-engine/datafusion-ext-functions/src/spark_murmur3_hash.rs

This file was deleted.

Loading