Skip to content

Commit 9227bff

Browse files
richoxzhangli20
andauthored
fix spark_xxhash64 + literal error (#920)
Co-authored-by: zhangli20 <zhangli20@kuaishou.com>
1 parent a3c117c commit 9227bff

3 files changed

Lines changed: 81 additions & 111 deletions

File tree

native-engine/datafusion-ext-functions/src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@ mod brickhouse;
2323
mod spark_check_overflow;
2424
mod spark_dates;
2525
pub mod spark_get_json_object;
26+
mod spark_hash;
2627
mod spark_make_array;
2728
mod spark_make_decimal;
28-
mod spark_murmur3_hash;
2929
mod spark_null_if;
3030
mod spark_strings;
3131
mod spark_unscaled_value;
32-
mod spark_xxhash64;
3332

3433
pub fn create_spark_ext_function(name: &str) -> Result<ScalarFunctionImplementation> {
3534
Ok(match name {
@@ -39,8 +38,8 @@ pub fn create_spark_ext_function(name: &str) -> Result<ScalarFunctionImplementat
3938
"UnscaledValue" => Arc::new(spark_unscaled_value::spark_unscaled_value),
4039
"MakeDecimal" => Arc::new(spark_make_decimal::spark_make_decimal),
4140
"CheckOverflow" => Arc::new(spark_check_overflow::spark_check_overflow),
42-
"Murmur3Hash" => Arc::new(spark_murmur3_hash::spark_murmur3_hash),
43-
"XxHash64" => Arc::new(spark_xxhash64::spark_xxhash64),
41+
"Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
42+
"XxHash64" => Arc::new(spark_hash::spark_xxhash64),
4443
"GetJsonObject" => Arc::new(spark_get_json_object::spark_get_json_object),
4544
"GetParsedJsonObject" => Arc::new(spark_get_json_object::spark_get_parsed_json_object),
4645
"ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),

native-engine/datafusion-ext-functions/src/spark_xxhash64.rs renamed to native-engine/datafusion-ext-functions/src/spark_hash.rs

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,47 @@
1515
use std::sync::Arc;
1616

1717
use arrow::array::*;
18-
use datafusion::{common::Result, physical_plan::ColumnarValue};
19-
use datafusion_ext_commons::spark_hash::create_xxhash64_hashes;
18+
use datafusion::{
19+
common::{Result, ScalarValue},
20+
physical_plan::ColumnarValue,
21+
};
22+
use datafusion_ext_commons::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes};
23+
24+
/// implements org.apache.spark.sql.catalyst.expressions.Murmur3Hash
25+
pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue> {
26+
spark_hash(args, |len, is_scalar, arrays| {
27+
// use identical seed as spark hash partition
28+
let spark_murmur3_default_seed = 42i32;
29+
let hash_buffer = create_murmur3_hashes(len, &arrays, spark_murmur3_default_seed);
30+
if is_scalar {
31+
ColumnarValue::Scalar(ScalarValue::from(hash_buffer[0]))
32+
} else {
33+
ColumnarValue::Array(Arc::new(Int32Array::from(hash_buffer)))
34+
}
35+
})
36+
}
2037

2138
/// implements org.apache.spark.sql.catalyst.expressions.XxHash64
2239
pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result<ColumnarValue> {
40+
spark_hash(args, |len, is_scalar, arrays| {
41+
// use identical seed as spark hash partition
42+
let spark_xxhash64_default_seed = 42i64;
43+
let hash_buffer = create_xxhash64_hashes(len, arrays, spark_xxhash64_default_seed);
44+
if is_scalar {
45+
ColumnarValue::Scalar(ScalarValue::from(hash_buffer[0]))
46+
} else {
47+
ColumnarValue::Array(Arc::new(Int64Array::from(hash_buffer)))
48+
}
49+
})
50+
}
51+
52+
pub fn spark_hash(
53+
args: &[ColumnarValue],
54+
hash_impl: impl Fn(usize, bool, &[ArrayRef]) -> ColumnarValue,
55+
) -> Result<ColumnarValue> {
56+
let is_scalar = args
57+
.iter()
58+
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
2359
let len = args
2460
.iter()
2561
.map(|arg| match arg {
@@ -38,14 +74,7 @@ pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result<ColumnarValue> {
3874
})
3975
})
4076
.collect::<Result<Vec<_>>>()?;
41-
42-
// use identical seed as spark hash partition
43-
let spark_xxhash64_default_seed = 42i64;
44-
let hash_buffer = create_xxhash64_hashes(len, &arrays, spark_xxhash64_default_seed);
45-
46-
Ok(ColumnarValue::Array(Arc::new(Int64Array::from(
47-
hash_buffer,
48-
))))
77+
Ok(hash_impl(len, is_scalar, &arrays))
4978
}
5079

5180
#[cfg(test)]
@@ -57,6 +86,45 @@ mod test {
5786

5887
use super::*;
5988

89+
#[test]
90+
fn test_murmur3_hash_int64() -> Result<(), Box<dyn Error>> {
91+
let result = spark_murmur3_hash(&vec![ColumnarValue::Array(Arc::new(Int64Array::from(
92+
vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)],
93+
)))])?
94+
.into_array(5)?;
95+
96+
let expected = Int32Array::from(vec![
97+
Some(-1712319331),
98+
Some(-1670924195),
99+
Some(-939490007),
100+
Some(-1604625029),
101+
Some(-853646085),
102+
]);
103+
let expected: ArrayRef = Arc::new(expected);
104+
105+
assert_eq!(&result, &expected);
106+
Ok(())
107+
}
108+
109+
#[test]
110+
fn test_murmur3_hash_string() -> Result<(), Box<dyn Error>> {
111+
let result = spark_murmur3_hash(&vec![ColumnarValue::Array(Arc::new(
112+
StringArray::from_iter_values(["hello", "bar", "", "😁", "天地"]),
113+
))])?
114+
.into_array(5)?;
115+
116+
let expected = Int32Array::from(vec![
117+
Some(-1008564952),
118+
Some(-1808790533),
119+
Some(142593372),
120+
Some(885025535),
121+
Some(-1899966402),
122+
]);
123+
let expected: ArrayRef = Arc::new(expected);
124+
125+
assert_eq!(&result, &expected);
126+
Ok(())
127+
}
60128
#[test]
61129
fn test_xxhash64_int64() -> Result<(), Box<dyn Error>> {
62130
let result = spark_xxhash64(&vec![ColumnarValue::Array(Arc::new(Int64Array::from(

native-engine/datafusion-ext-functions/src/spark_murmur3_hash.rs

Lines changed: 0 additions & 97 deletions
This file was deleted.

0 commit comments

Comments
 (0)