Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions datafusion/core/tests/physical_optimizer/enforce_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ use datafusion_physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_physical_plan::execution_plan::ExecutionPlan;
use datafusion_physical_plan::expressions::col;
use datafusion_physical_plan::filter::FilterExec;
Expand Down Expand Up @@ -3471,3 +3472,47 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> {

Ok(())
}

/// Checks that `replace_order_preserving_variants` keeps the `fetch` limit when it
/// swaps a `SortPreservingMergeExec` for a `CoalescePartitionsExec`.
#[test]
fn test_replace_order_preserving_variants_with_fetch() -> Result<()> {
    // Input plan produced by the `parquet_exec()` test helper; it is shared by the
    // merge node and by the child distribution context below.
    let source = parquet_exec();

    // Single-key ordering on column `id` (index 0) with default sort options.
    let ordering = LexOrdering::new(vec![PhysicalSortExpr {
        expr: Arc::new(Column::new("id", 0)),
        options: SortOptions::default(),
    }]);

    // Sort-preserving merge over the source, limited to 5 rows via `fetch`.
    let merge =
        SortPreservingMergeExec::new(ordering, source.clone()).with_fetch(Some(5));

    // Context tree: the merge node (flag = true) on top of the source (flag = false).
    let child_context = DistributionContext::new(source, false, vec![]);
    let root_context =
        DistributionContext::new(Arc::new(merge), true, vec![child_context]);

    // Run the rewrite under test.
    let rewritten = replace_order_preserving_variants(root_context)?;
    let plan = &rewritten.plan;

    // The top-level operator must now be a `CoalescePartitionsExec`.
    plan.as_any()
        .downcast_ref::<CoalescePartitionsExec>()
        .expect("Expected CoalescePartitionsExec");

    // The fetch limit must carry over from the replaced merge operator.
    assert_eq!(
        plan.fetch(),
        Some(5),
        "Fetch value was not preserved after transformation"
    );

    Ok(())
}
7 changes: 5 additions & 2 deletions datafusion/physical-optimizer/src/enforce_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ fn remove_dist_changing_operators(
/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2",
/// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet",
/// ```
fn replace_order_preserving_variants(
pub fn replace_order_preserving_variants(
mut context: DistributionContext,
) -> Result<DistributionContext> {
context.children = context
Expand All @@ -1035,7 +1035,10 @@ fn replace_order_preserving_variants(

if is_sort_preserving_merge(&context.plan) {
let child_plan = Arc::clone(&context.children[0].plan);
context.plan = Arc::new(CoalescePartitionsExec::new(child_plan));
// It's safe to unwrap because `CoalescePartitionsExec` supports `fetch`.
context.plan = CoalescePartitionsExec::new(child_plan)
.with_fetch(context.plan.fetch())
.unwrap();
return Ok(context);
} else if let Some(repartition) =
context.plan.as_any().downcast_ref::<RepartitionExec>()
Expand Down