Skip to content

Commit 74591c1

Browse files
authored
[BUG] Fix merge-task bug that allowed merged scan tasks to exceed the configured maximum size (#1882)
<img width="1186" alt="image" src="https://github.com/Eventual-Inc/Daft/assets/2550285/b44277b8-a768-4618-82ff-b50b3a61b8a3">
1 parent cf77fd2 commit 74591c1

2 files changed

Lines changed: 61 additions & 66 deletions

File tree

src/daft-scan/src/scan_task_iters.rs

Lines changed: 57 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -45,74 +45,69 @@ struct MergeByFileSize {
4545
accumulator: Option<ScanTaskRef>,
4646
}
4747

48+
impl MergeByFileSize {
49+
fn accumulator_ready(&self) -> bool {
50+
if let Some(acc) = &self.accumulator && let Some(acc_bytes) = acc.size_bytes() && acc_bytes >= self.min_size_bytes {
51+
true
52+
} else {
53+
false
54+
}
55+
}
56+
57+
fn can_merge(&self, other: &ScanTask) -> bool {
58+
let accumulator = self
59+
.accumulator
60+
.as_ref()
61+
.expect("accumulator should be populated");
62+
let child_matches_accumulator = other.partition_spec() == accumulator.partition_spec()
63+
&& other.file_format_config == accumulator.file_format_config
64+
&& other.schema == accumulator.schema
65+
&& other.storage_config == accumulator.storage_config
66+
&& other.pushdowns == accumulator.pushdowns;
67+
68+
let sum_smaller_than_max_size_bytes = if let Some(child_bytes) = other.size_bytes()
69+
&& let Some(accumulator_bytes) = accumulator.size_bytes() {child_bytes + accumulator_bytes <= self.max_size_bytes} else {false};
70+
71+
child_matches_accumulator && sum_smaller_than_max_size_bytes
72+
}
73+
}
74+
4875
impl Iterator for MergeByFileSize {
4976
type Item = DaftResult<ScanTaskRef>;
5077

5178
fn next(&mut self) -> Option<Self::Item> {
5279
loop {
53-
// Grabs the accumulator, leaving a `None` in its place
54-
let accumulator = self.accumulator.take();
55-
56-
match (self.iter.next(), accumulator) {
57-
// When no accumulator exists, trivially place the ScanTask into the accumulator
58-
(Some(Ok(child_item)), None) => {
59-
self.accumulator = Some(child_item);
60-
continue;
61-
}
62-
// When an accumulator exists, attempt a merge and yield the result
63-
(Some(Ok(child_item)), Some(accumulator)) => {
64-
// Whether or not the accumulator and the current item should be merged
65-
let should_merge = {
66-
let child_matches_accumulator = child_item.partition_spec()
67-
== accumulator.partition_spec()
68-
&& child_item.file_format_config == accumulator.file_format_config
69-
&& child_item.schema == accumulator.schema
70-
&& child_item.storage_config == accumulator.storage_config
71-
&& child_item.pushdowns == accumulator.pushdowns;
72-
let smaller_than_max_size_bytes = matches!(
73-
(child_item.size_bytes(), accumulator.size_bytes()),
74-
(Some(child_item_size), Some(buffered_item_size)) if child_item_size + buffered_item_size <= self.max_size_bytes
75-
);
76-
child_matches_accumulator && smaller_than_max_size_bytes
77-
};
78-
79-
if should_merge {
80-
let merged_result = Some(Arc::new(
81-
ScanTask::merge(accumulator.as_ref(), child_item.as_ref())
82-
.expect("ScanTasks should be mergeable in MergeByFileSize"),
83-
));
84-
85-
// Whether or not we should immediately yield the merged result, or keep accumulating
86-
let should_yield = matches!(
87-
(child_item.size_bytes(), accumulator.size_bytes()),
88-
(Some(child_item_size), Some(buffered_item_size)) if child_item_size + buffered_item_size >= self.min_size_bytes
89-
);
90-
91-
// Either yield eagerly, or keep looping with a merged accumulator
92-
if should_yield {
93-
return Ok(merged_result).transpose();
94-
} else {
95-
self.accumulator = merged_result;
96-
continue;
97-
}
98-
} else {
99-
self.accumulator = Some(child_item);
100-
return Some(Ok(accumulator));
101-
}
102-
}
103-
// Bubble up errors from child iterator, making sure to replace the accumulator which we moved
104-
(Some(Err(e)), acc) => {
105-
self.accumulator = acc;
106-
return Some(Err(e));
107-
}
108-
// Iterator ran out of elements: ensure that we flush the last buffered ScanTask
109-
(None, Some(last_scan_task)) => {
110-
return Some(Ok(last_scan_task));
111-
}
112-
(None, None) => {
113-
return None;
114-
}
80+
if self.accumulator.is_none() {
81+
self.accumulator = match self.iter.next() {
82+
Some(Ok(item)) => Some(item),
83+
e @ Some(Err(_)) => return e,
84+
None => return None,
85+
};
86+
}
87+
88+
if self.accumulator_ready() {
89+
return self.accumulator.take().map(Ok);
11590
}
91+
92+
let next_item = match self.iter.next() {
93+
Some(Ok(item)) => item,
94+
e @ Some(Err(_)) => return e,
95+
None => return self.accumulator.take().map(Ok),
96+
};
97+
98+
if next_item.size_bytes().is_none() || !self.can_merge(&next_item) {
99+
return self.accumulator.replace(next_item).map(Ok);
100+
}
101+
102+
self.accumulator = Some(Arc::new(
103+
ScanTask::merge(
104+
self.accumulator
105+
.as_ref()
106+
.expect("accumulator should be populated"),
107+
next_item.as_ref(),
108+
)
109+
.expect("ScanTasks should be mergeable in MergeByFileSize"),
110+
));
116111
}
117112
}
118113
}

tests/io/test_merge_scan_tasks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ def test_merge_scan_task_exceed_max(csv_files):
4444

4545
@pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
def test_merge_scan_task_below_max(csv_files):
    # With min=21 / max=22 bytes, CSV1+CSV2 can merge, but folding in CSV3
    # would push the merged task past the maximum, so CSV3 stays alone.
    with override_merge_scan_tasks_configs(21, 22):
        df = daft.read_csv(str(csv_files))
        assert (
            df.num_partitions() == 2
        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>22 bytes)"
5252

5353

5454
@pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
def test_merge_scan_task_above_min(csv_files):
    # With min=19 / max=40 bytes, merging CSV1+CSV2 already crosses the
    # minimum, so that task is emitted eagerly and CSV3 forms its own
    # partition.
    with override_merge_scan_tasks_configs(19, 40):
        df = daft.read_csv(str(csv_files))
        assert (
            df.num_partitions() == 2
        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>19 bytes)"
6161

6262

6363
@pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")

0 commit comments

Comments
 (0)