Skip to content

Commit ce50edf

Browse files
committed
make MD5 output UTF-8 again
1 parent 1539192 commit ce50edf

File tree

6 files changed

+63
-21
lines changed

6 files changed

+63
-21
lines changed

native-engine/auron-serde/proto/auron.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ enum ScalarFunction {
234234
Lpad=32;
235235
Lower=33;
236236
Ltrim=34;
237-
MD5=35;
237+
// MD5=35;
238238
// NullIf=36;
239239
OctetLength=37;
240240
Random=38;

native-engine/auron-serde/src/from_proto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
786786
ScalarFunction::NullIf => f::core::nullif(),
787787
ScalarFunction::DatePart => f::datetime::date_part(),
788788
ScalarFunction::DateTrunc => f::datetime::date_trunc(),
789-
ScalarFunction::Md5 => f::crypto::md5(),
789+
// ScalarFunction::Md5 => f::crypto::md5(),
790790
// ScalarFunction::Sha224 => f::crypto::sha224(),
791791
// ScalarFunction::Sha256 => f::crypto::sha256(),
792792
// ScalarFunction::Sha384 => f::crypto::sha384(),

native-engine/datafusion-ext-functions/src/lib.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use datafusion_ext_commons::df_unimplemented_err;
2020

2121
mod brickhouse;
2222
mod spark_check_overflow;
23+
mod spark_crypto;
2324
mod spark_dates;
2425
pub mod spark_get_json_object;
2526
mod spark_hash;
@@ -28,7 +29,6 @@ mod spark_make_decimal;
2829
mod spark_normalize_nan_and_zero;
2930
mod spark_null_if;
3031
mod spark_round;
31-
mod spark_sha2;
3232
mod spark_strings;
3333
mod spark_unscaled_value;
3434

@@ -42,10 +42,11 @@ pub fn create_spark_ext_function(name: &str) -> Result<ScalarFunctionImplementat
4242
"CheckOverflow" => Arc::new(spark_check_overflow::spark_check_overflow),
4343
"Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
4444
"XxHash64" => Arc::new(spark_hash::spark_xxhash64),
45-
"Sha224" => Arc::new(spark_sha2::spark_sha224),
46-
"Sha256" => Arc::new(spark_sha2::spark_sha256),
47-
"Sha384" => Arc::new(spark_sha2::spark_sha384),
48-
"Sha512" => Arc::new(spark_sha2::spark_sha512),
45+
"Sha224" => Arc::new(spark_crypto::spark_sha224),
46+
"Sha256" => Arc::new(spark_crypto::spark_sha256),
47+
"Sha384" => Arc::new(spark_crypto::spark_sha384),
48+
"Sha512" => Arc::new(spark_crypto::spark_sha512),
49+
"MD5" => Arc::new(spark_crypto::spark_md5),
4950
"GetJsonObject" => Arc::new(spark_get_json_object::spark_get_json_object),
5051
"GetParsedJsonObject" => Arc::new(spark_get_json_object::spark_get_parsed_json_object),
5152
"ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),

native-engine/datafusion-ext-functions/src/spark_sha2.rs renamed to native-engine/datafusion-ext-functions/src/spark_crypto.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ use arrow::{
2020
datatypes::{DataType, Field},
2121
};
2222
use datafusion::{
23-
common::{Result, ScalarValue, cast::as_binary_array},
24-
functions::crypto::{sha224, sha256, sha384, sha512},
23+
common::{Result, ScalarValue, cast::as_binary_array, utils::take_function_args},
24+
functions::crypto::{
25+
basic::{DigestAlgorithm, digest_process},
26+
sha224, sha256, sha384, sha512,
27+
},
2528
logical_expr::{ScalarFunctionArgs, ScalarUDF},
2629
physical_plan::ColumnarValue,
2730
};
@@ -30,39 +33,46 @@ use datafusion_ext_commons::df_execution_err;
3033
/// `sha224` function that simulates Spark's `sha2` expression with bit width
3134
/// 224
3235
pub fn spark_sha224(args: &[ColumnarValue]) -> Result<ColumnarValue> {
33-
wrap_digest_result_as_hex_string(args, sha224())
36+
digest_and_wrap_as_hex(args, sha224())
3437
}
3538

3639
/// `sha256` function that simulates Spark's `sha2` expression with bit width 0
3740
/// or 256
3841
pub fn spark_sha256(args: &[ColumnarValue]) -> Result<ColumnarValue> {
39-
wrap_digest_result_as_hex_string(args, sha256())
42+
digest_and_wrap_as_hex(args, sha256())
4043
}
4144

4245
/// `sha384` function that simulates Spark's `sha2` expression with bit width
4346
/// 384
4447
pub fn spark_sha384(args: &[ColumnarValue]) -> Result<ColumnarValue> {
45-
wrap_digest_result_as_hex_string(args, sha384())
48+
digest_and_wrap_as_hex(args, sha384())
4649
}
4750

4851
/// `sha512` function that simulates Spark's `sha2` expression with bit width
4952
/// 512
5053
pub fn spark_sha512(args: &[ColumnarValue]) -> Result<ColumnarValue> {
51-
wrap_digest_result_as_hex_string(args, sha512())
54+
digest_and_wrap_as_hex(args, sha512())
5255
}
5356

54-
/// Spark requires hex string as the result of sha2 functions, we have to wrap
55-
/// the result of digest functions as hex string
56-
fn wrap_digest_result_as_hex_string(
57-
args: &[ColumnarValue],
58-
digest: Arc<ScalarUDF>,
59-
) -> Result<ColumnarValue> {
57+
pub fn spark_md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
58+
let [data] = take_function_args("md5", args)?;
59+
let value = digest_process(data, DigestAlgorithm::Md5)?;
60+
to_hex_string(value)
61+
}
62+
63+
/// Spark requires hex string as the result of sha2 and md5 functions, we have
64+
/// to wrap the result of digest functions as hex string
65+
fn digest_and_wrap_as_hex(args: &[ColumnarValue], digest: Arc<ScalarUDF>) -> Result<ColumnarValue> {
6066
let value = digest.inner().invoke_with_args(ScalarFunctionArgs {
6167
args: args.to_vec(),
6268
arg_fields: vec![Arc::new(Field::new("arg", DataType::Binary, true))],
6369
number_rows: 0,
6470
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
6571
})?;
72+
to_hex_string(value)
73+
}
74+
75+
fn to_hex_string(value: ColumnarValue) -> Result<ColumnarValue> {
6676
Ok(match value {
6777
ColumnarValue::Array(array) => {
6878
let binary_array = as_binary_array(&array)?;
@@ -102,7 +112,7 @@ mod tests {
102112
common::ScalarValue, error::Result as DataFusionResult, physical_plan::ColumnarValue,
103113
};
104114

105-
use crate::spark_sha2::{spark_sha224, spark_sha256, spark_sha384, spark_sha512};
115+
use crate::spark_crypto::{spark_md5, spark_sha224, spark_sha256, spark_sha384, spark_sha512};
106116

107117
/// Helper function to run a test for a given hash function and scalar
108118
/// input.
@@ -181,4 +191,18 @@ mod tests {
181191
let expected = "178d767c364244ede054ebb3cc4af0ac2b307a86fba6a32706ce4f692642674d2ab8f51ee738ecb09bc296918aa85db48abe28fcaef7aa2da81a618cc6d891c3";
182192
run_scalar_test(spark_sha512, input, expected)
183193
}
194+
195+
#[test]
196+
fn test_md5_scalar_utf8() -> Result<(), Box<dyn Error>> {
197+
let input = ColumnarValue::Scalar(ScalarValue::Utf8(Some("ABC".to_string())));
198+
let expected = "902fbdd2b1df0c4f70b4a5d23525e932";
199+
run_scalar_test(spark_md5, input, expected)
200+
}
201+
202+
#[test]
203+
fn test_md5_scalar_binary() -> Result<(), Box<dyn Error>> {
204+
let input = ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![1, 2, 3, 4, 5, 6])));
205+
let expected = "6ac1e56bc78f031059be7be854522c4c";
206+
run_scalar_test(spark_md5, input, expected)
207+
}
184208
}

spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ class AuronFunctionSuite
6262
}
6363
}
6464

65+
test("md5 function") {
66+
withTable("t1") {
67+
sql("create table t1 using parquet as select 'spark' as c1, '3.x' as version")
68+
val functions =
69+
"""
70+
|select b.md5
71+
|from (
72+
| select c1, version from t1
73+
|) a join (
74+
| select md5(concat(c1, version)) as md5 from t1
75+
|) b on md5(concat(a.c1, a.version)) = b.md5
76+
|""".stripMargin
77+
val df = sql(functions)
78+
checkAnswer(df, Seq(Row("9ff36a3857e29335d03cf6bef2147119")))
79+
}
80+
}
81+
6582
test("spark hash function") {
6683
withTable("t1") {
6784
sql("create table t1 using parquet as select array(1, 2) as arr")

spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,7 +855,7 @@ object NativeConverters extends Logging {
855855
case e @ NullIf(left, right, _) =>
856856
buildExtScalarFunction("NullIf", left :: right :: Nil, e.dataType)
857857
case Md5(_1) =>
858-
buildScalarFunction(pb.ScalarFunction.MD5, Seq(unpackBinaryTypeCast(_1)), StringType)
858+
buildExtScalarFunction("MD5", Seq(unpackBinaryTypeCast(_1)), StringType)
859859
case Reverse(_1) =>
860860
buildScalarFunction(pb.ScalarFunction.Reverse, Seq(unpackBinaryTypeCast(_1)), StringType)
861861
case InitCap(_1) =>

0 commit comments

Comments
 (0)