diff --git a/native-engine/auron-serde/proto/auron.proto b/native-engine/auron-serde/proto/auron.proto index 29e9f1134..788be3526 100644 --- a/native-engine/auron-serde/proto/auron.proto +++ b/native-engine/auron-serde/proto/auron.proto @@ -468,6 +468,7 @@ message BroadcastJoinExecNode { JoinType join_type = 5; JoinSide broadcast_side = 6; string cached_build_hash_map_id = 7; + bool is_null_aware_anti_join = 8; } message RenameColumnsExecNode { diff --git a/native-engine/auron-serde/src/from_proto.rs b/native-engine/auron-serde/src/from_proto.rs index 0caaad6ca..82a4b03b8 100644 --- a/native-engine/auron-serde/src/from_proto.rs +++ b/native-engine/auron-serde/src/from_proto.rs @@ -219,6 +219,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .map_err(|_| proto_error("invalid BuildSide"))?, false, None, + false, )?)) } PhysicalPlanType::SortMergeJoin(sort_merge_join) => { @@ -354,6 +355,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .expect("invalid BroadcastSide"); let cached_build_hash_map_id = broadcast_join.cached_build_hash_map_id.clone(); + let is_null_aware_anti_join = broadcast_join.is_null_aware_anti_join; Ok(Arc::new(BroadcastJoinExec::try_new( schema, @@ -368,6 +370,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .map_err(|_| proto_error("invalid BroadcastSide"))?, true, Some(cached_build_hash_map_id), + is_null_aware_anti_join, )?)) } PhysicalPlanType::Union(union) => { diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs index 276f5e094..fef3397bf 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs @@ -88,6 +88,7 @@ pub struct BroadcastJoinExec { schema: SchemaRef, is_built: bool, // true for BroadcastHashJoin, false for ShuffledHashJoin cached_build_hash_map_id: Option, + is_null_aware_anti_join: bool, metrics: ExecutionPlanMetricsSet, props: OnceCell, } @@ -102,6 +103,7 @@ impl BroadcastJoinExec { broadcast_side: JoinSide, is_built: bool, cached_build_hash_map_id: Option, + is_null_aware_anti_join: bool, ) -> Result { Ok(Self { left, @@ -112,6 +114,7 @@ impl BroadcastJoinExec { schema, is_built, cached_build_hash_map_id, + is_null_aware_anti_join, metrics: ExecutionPlanMetricsSet::new(), props: OnceCell::new(), }) @@ -176,6 +179,7 @@ impl BroadcastJoinExec { sort_options: vec![SortOptions::default(); self.on.len()], projection, key_data_types, + is_null_aware_anti_join: self.is_null_aware_anti_join, }) } @@ -279,6 +283,7 @@ impl ExecutionPlan for BroadcastJoinExec { self.broadcast_side, self.is_built, None, + self.is_null_aware_anti_join, )?)) } diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs index 41ebcf6fd..1018b72ae 100644 --- a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs @@ -193,7 +193,11 @@ impl Joiner for SemiJoiner

{ .as_ref() .map(|nb| nb.is_valid(row_idx)) .unwrap_or(true); - if P.mode == Anti && P.probe_is_join_side && !key_is_valid { + if P.mode == Anti + && P.probe_is_join_side + && !key_is_valid + && self.join_params.is_null_aware_anti_join + { probed_joined.set(row_idx, true); continue; } diff --git a/native-engine/datafusion-ext-plans/src/joins/mod.rs b/native-engine/datafusion-ext-plans/src/joins/mod.rs index 5f8ae9973..6ccc40866 100644 --- a/native-engine/datafusion-ext-plans/src/joins/mod.rs +++ b/native-engine/datafusion-ext-plans/src/joins/mod.rs @@ -46,6 +46,7 @@ pub struct JoinParams { pub sort_options: Vec, pub projection: JoinProjection, pub batch_size: usize, + pub is_null_aware_anti_join: bool, } #[derive(Debug, Clone)] diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs index 671ecd732..2c50cabb9 100644 --- a/native-engine/datafusion-ext-plans/src/joins/test.rs +++ b/native-engine/datafusion-ext-plans/src/joins/test.rs @@ -219,6 +219,7 @@ mod tests { JoinSide::Right, true, None, + false, )?) } BHJRightProbed => { @@ -235,6 +236,7 @@ mod tests { JoinSide::Left, true, None, + false, )?) } SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new( @@ -246,6 +248,7 @@ mod tests { JoinSide::Right, false, None, + false, )?), SHJRightProbed => Arc::new(BroadcastJoinExec::try_new( schema, @@ -256,6 +259,7 @@ mod tests { JoinSide::Left, false, None, + false, )?), }; let columns = columns(&join.schema()); @@ -617,21 +621,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?), )]; - for test_type in [BHJLeftProbed, SHJLeftProbed] { - let (_, batches) = - join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?; - let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| | 6 | 9 |", - "| 5 | 8 | 11 |", - "+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - } - - for test_type in [SMJ, BHJRightProbed, SHJRightProbed] { + for test_type in ALL_TEST_TYPE { let (_, batches) = join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?; let expected = vec![ diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs index 78eda5b62..91496c497 100644 --- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs @@ -128,6 +128,7 @@ impl SortMergeJoinExec { sort_options: self.sort_options.clone(), projection, batch_size: batch_size(), + is_null_aware_anti_join: false, }) } diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala index 3acbbed92..9cecb8692 100644 --- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala +++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala @@ -229,7 +229,8 @@ class ShimsImpl extends Shims with Logging { leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - broadcastSide: BroadcastSide): NativeBroadcastJoinBase = + broadcastSide: BroadcastSide, + isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase = NativeBroadcastJoinExec( left, right, @@ -237,7 +238,8 @@ class ShimsImpl extends Shims with Logging { leftKeys, rightKeys, joinType, - broadcastSide) + broadcastSide, + isNullAwareAntiJoin) override def createNativeSortMergeJoinExec( left: SparkPlan, diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala index d0c2cea8a..9ac6e893e 100644 --- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala +++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala @@ -35,7 +35,8 @@ case class NativeBroadcastJoinExec( override val leftKeys: Seq[Expression], override val rightKeys: Seq[Expression], override val joinType: JoinType, - broadcastSide: BroadcastSide) + broadcastSide: BroadcastSide, + isNullAwareAntiJoin: Boolean) extends NativeBroadcastJoinBase( left, right, @@ -43,7 +44,8 @@ case class NativeBroadcastJoinExec( leftKeys, rightKeys, joinType, - broadcastSide) + broadcastSide, + isNullAwareAntiJoin) with HashJoin { override val condition: Option[Expression] = None diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala index 3a2cc9cfa..8fe2a3e31 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala @@ -581,4 +581,85 @@ class AuronQuerySuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQ } } } + + test("standard LEFT ANTI JOIN includes NULL keys") { + // This test verifies that standard LEFT ANTI JOIN correctly includes NULL keys + // NULL keys should be in the result because NULL never matches anything + withTable("left_table", "right_table") { + sql(""" + |CREATE TABLE left_table using parquet AS + |SELECT * FROM VALUES + | (1, 2.0), + | (1, 2.0), + | (2, 1.0), + | (2, 1.0), + | (3, 3.0), + | (null, null), + | (null, 5.0), + | (6, null) + |AS t(a, b) + |""".stripMargin) + + sql(""" + |CREATE TABLE right_table using parquet AS + |SELECT * FROM VALUES + | (2, 3.0), + | (2, 3.0), + | (3, 2.0), + | (4, 1.0), + | (null, null), + | (null, 5.0), + | (6, null) + |AS t(c, d) + |""".stripMargin) + + // Standard LEFT ANTI JOIN should include rows with NULL keys + // Expected: (1, 2.0), (1, 2.0), (null, null), (null, 5.0) + checkSparkAnswer( + "SELECT * FROM left_table LEFT ANTI JOIN right_table ON left_table.a = right_table.c") + } + } + + test("left join with NOT IN subquery should filter NULL values") { + // This test verifies the fix for the NULL handling issue in Anti join. + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { + val query = + """ + |WITH t2 AS ( + | -- Large table: 100000 rows (0..99999) + | SELECT id AS loan_req_no + | FROM range(0, 100000) + |), + |t1 AS ( + | -- Small table: 10 rows that can match t2 + | SELECT * FROM VALUES + | (1, 'A'), + | (2, 'B'), + | (3, 'C'), + | (4, 'D'), + | (5, 'E'), + | (6, 'F'), + | (7, 'G'), + | (8, 'H'), + | (9, 'I'), + | (10,'J') + | AS t1(loan_req_no, partner_code) + |), + |blk AS ( + | SELECT * FROM VALUES + | ('B'), + | ('Z') + | AS blk(code) + |) + |SELECT + | COUNT(*) AS cnt + |FROM t2 + |LEFT JOIN t1 + | ON t1.loan_req_no = t2.loan_req_no + |WHERE t1.partner_code NOT IN (SELECT code FROM blk) + |""".stripMargin + + checkSparkAnswer(query) + } + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala index 413ad7be5..491f85da2 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala @@ -664,16 +664,23 @@ object AuronConverters extends Logging { } } + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") + def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = exec.isNullAwareAntiJoin + + @sparkver("3.0") + def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = false + def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = { try { - val (leftKeys, rightKeys, joinType, buildSide, condition, left, right) = ( + val (leftKeys, rightKeys, joinType, buildSide, condition, left, right, naaj) = ( exec.leftKeys, exec.rightKeys, exec.joinType, exec.buildSide, exec.condition, exec.left, - exec.right) + exec.right, + isNullAwareAntiJoin(exec)) logDebugPlanConversion( exec, Seq( @@ -702,7 +709,8 @@ object AuronConverters extends Logging { buildSide match { case BuildLeft => BroadcastLeft case BuildRight => BroadcastRight - }) + }, + naaj) } catch { case e @ (_: NotImplementedError | _: Exception) => @@ -744,7 +752,8 @@ object AuronConverters extends Logging { buildSide match { case BuildLeft => BroadcastLeft case BuildRight => BroadcastRight - }) + }, + isNullAwareAntiJoin = false) } catch { case e @ (_: NotImplementedError | _: Exception) => diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala index a192e1982..a0dd37ae2 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala @@ -86,7 +86,8 @@ abstract class Shims { leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - broadcastSide: BroadcastSide): NativeBroadcastJoinBase + broadcastSide: BroadcastSide, + isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase def createNativeSortMergeJoinExec( left: SparkPlan, diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala index dabeba3f2..3281947c8 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala @@ -52,7 +52,8 @@ abstract class NativeBroadcastJoinBase( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - broadcastSide: BroadcastSide) + broadcastSide: BroadcastSide, + isNullAwareAntiJoin: Boolean) extends BinaryExecNode with NativeSupports { @@ -174,6 +175,7 @@ abstract class NativeBroadcastJoinBase( .setJoinType(nativeJoinType) .setBroadcastSide(nativeBroadcastSide) .setCachedBuildHashMapId(cachedBuildHashMapId) + .setIsNullAwareAntiJoin(isNullAwareAntiJoin) .addAllOn(nativeJoinOn.asJava) pb.PhysicalPlanNode.newBuilder().setBroadcastJoin(broadcastJoinExec).build()