@@ -20,8 +20,11 @@ use arrow::{
2020 datatypes:: { DataType , Field } ,
2121} ;
2222use datafusion:: {
23- common:: { Result , ScalarValue , cast:: as_binary_array} ,
24- functions:: crypto:: { sha224, sha256, sha384, sha512} ,
23+ common:: { Result , ScalarValue , cast:: as_binary_array, utils:: take_function_args} ,
24+ functions:: crypto:: {
25+ basic:: { DigestAlgorithm , digest_process} ,
26+ sha224, sha256, sha384, sha512,
27+ } ,
2528 logical_expr:: { ScalarFunctionArgs , ScalarUDF } ,
2629 physical_plan:: ColumnarValue ,
2730} ;
@@ -30,39 +33,46 @@ use datafusion_ext_commons::df_execution_err;
3033/// `sha224` function that simulates Spark's `sha2` expression with bit width
3134/// 224
3235pub fn spark_sha224 ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
33- wrap_digest_result_as_hex_string ( args, sha224 ( ) )
36+ digest_and_wrap_as_hex ( args, sha224 ( ) )
3437}
3538
3639/// `sha256` function that simulates Spark's `sha2` expression with bit width 0
3740/// or 256
3841pub fn spark_sha256 ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
39- wrap_digest_result_as_hex_string ( args, sha256 ( ) )
42+ digest_and_wrap_as_hex ( args, sha256 ( ) )
4043}
4144
4245/// `sha384` function that simulates Spark's `sha2` expression with bit width
4346/// 384
4447pub fn spark_sha384 ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
45- wrap_digest_result_as_hex_string ( args, sha384 ( ) )
48+ digest_and_wrap_as_hex ( args, sha384 ( ) )
4649}
4750
4851/// `sha512` function that simulates Spark's `sha2` expression with bit width
4952/// 512
5053pub fn spark_sha512 ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
51- wrap_digest_result_as_hex_string ( args, sha512 ( ) )
54+ digest_and_wrap_as_hex ( args, sha512 ( ) )
5255}
5356
54- /// Spark requires hex string as the result of sha2 functions, we have to wrap
55- /// the result of digest functions as hex string
56- fn wrap_digest_result_as_hex_string (
57- args : & [ ColumnarValue ] ,
58- digest : Arc < ScalarUDF > ,
59- ) -> Result < ColumnarValue > {
57+ pub fn spark_md5 ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
58+ let [ data] = take_function_args ( "md5" , args) ?;
59+ let value = digest_process ( data, DigestAlgorithm :: Md5 ) ?;
60+ to_hex_string ( value)
61+ }
62+
63+ /// Spark requires hex string as the result of sha2 and md5 functions, we have
64+ /// to wrap the result of digest functions as hex string
65+ fn digest_and_wrap_as_hex ( args : & [ ColumnarValue ] , digest : Arc < ScalarUDF > ) -> Result < ColumnarValue > {
6066 let value = digest. inner ( ) . invoke_with_args ( ScalarFunctionArgs {
6167 args : args. to_vec ( ) ,
6268 arg_fields : vec ! [ Arc :: new( Field :: new( "arg" , DataType :: Binary , true ) ) ] ,
6369 number_rows : 0 ,
6470 return_field : Arc :: new ( Field :: new ( "result" , DataType :: Utf8 , true ) ) ,
6571 } ) ?;
72+ to_hex_string ( value)
73+ }
74+
75+ fn to_hex_string ( value : ColumnarValue ) -> Result < ColumnarValue > {
6676 Ok ( match value {
6777 ColumnarValue :: Array ( array) => {
6878 let binary_array = as_binary_array ( & array) ?;
@@ -102,7 +112,7 @@ mod tests {
102112 common:: ScalarValue , error:: Result as DataFusionResult , physical_plan:: ColumnarValue ,
103113 } ;
104114
105- use crate :: spark_sha2 :: { spark_sha224, spark_sha256, spark_sha384, spark_sha512} ;
115+ use crate :: spark_crypto :: { spark_md5 , spark_sha224, spark_sha256, spark_sha384, spark_sha512} ;
106116
107117 /// Helper function to run a test for a given hash function and scalar
108118 /// input.
@@ -181,4 +191,18 @@ mod tests {
181191 let expected = "178d767c364244ede054ebb3cc4af0ac2b307a86fba6a32706ce4f692642674d2ab8f51ee738ecb09bc296918aa85db48abe28fcaef7aa2da81a618cc6d891c3" ;
182192 run_scalar_test ( spark_sha512, input, expected)
183193 }
194+
195+ #[ test]
196+ fn test_md5_scalar_utf8 ( ) -> Result < ( ) , Box < dyn Error > > {
197+ let input = ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( "ABC" . to_string ( ) ) ) ) ;
198+ let expected = "902fbdd2b1df0c4f70b4a5d23525e932" ;
199+ run_scalar_test ( spark_md5, input, expected)
200+ }
201+
202+ #[ test]
203+ fn test_md5_scalar_binary ( ) -> Result < ( ) , Box < dyn Error > > {
204+ let input = ColumnarValue :: Scalar ( ScalarValue :: Binary ( Some ( vec ! [ 1 , 2 , 3 , 4 , 5 , 6 ] ) ) ) ;
205+ let expected = "6ac1e56bc78f031059be7be854522c4c" ;
206+ run_scalar_test ( spark_md5, input, expected)
207+ }
184208}
0 commit comments