Fix foreach join inputs out-of-order for >4 branches #2974
When a foreach has more than 4 branches, _init_data uses TaskDataStoreSet, which calls get_task_datastores(). That function built its return list by iterating over a Python set (done_attempts & latest_started_attempts), which has no defined order, so inputs at join steps arrived in arbitrary order rather than matching the original foreach list.
Fix by sorting the result of get_task_datastores() to match the pathspecs order before returning.
Also implement the long-standing TODO in Inputs.__init__ to sort by _foreach_stack index as a defensive layer.
Add a comment in _init_data explaining that both code paths (sequential and TaskDataStoreSet) now guarantee pathspec-order preservation.
Add ForeachJoinOrderTest, which uses a non-monotonic 6-element foreach list (>4 to trigger TaskDataStoreSet) and asserts that inputs[i].index == i and that values match the original array positionally without sorting.
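The pathspec-order fix described above can be illustrated with a minimal sketch. It does not use Metaflow's real classes: the `order_by_pathspecs` helper, the dict-based "datastores", and the pathspec strings are all hypothetical stand-ins, assuming only that each loaded item knows its own pathspec.

```python
def order_by_pathspecs(pathspecs, loaded):
    """Reorder loaded items so they line up with the requested pathspecs.

    `loaded` is a list of dicts with a "pathspec" key (a hypothetical
    stand-in for task datastores); `pathspecs` is the order the caller
    asked for.
    """
    position = {pathspec: i for i, pathspec in enumerate(pathspecs)}
    return sorted(loaded, key=lambda item: position[item["pathspec"]])


# Made-up pathspecs in the order the join step expects them.
requested = ["MyFlow/1/work/%d" % task_id for task_id in (12, 7, 9, 15, 3, 8)]

# Iterating a set gives no ordering guarantee, which mirrors the bug.
loaded = [{"pathspec": p} for p in set(requested)]

ordered = order_by_pathspecs(requested, loaded)
assert [item["pathspec"] for item in ordered] == requested
```

Building the position map once keeps the reordering at one dictionary lookup per item, rather than a list search per item.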
Drop sorted() from BasicForeachTest and WideForeachTest join checks. Both already use range(N) as the foreach list, so removing sorted() turns them into direct ordering assertions — no extra test file needed.
Sorting by _foreach_stack[-1].index is the semantically correct fix: it uses the ground truth stored in the task itself, so ordering is guaranteed regardless of how runtime.py orders input_paths or how get_task_datastores() returns results. _foreach_stack is already prefetched in the foreach join path so the sort is cheap.
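As a rough illustration of sorting by the innermost foreach index, the sketch below uses a hypothetical `Frame` namedtuple in place of Metaflow's ForeachFrame and plain dicts in place of task datastores. Only the `index` field and the `"_foreach_stack"` key are taken from the PR; everything else is assumed for the example.

```python
from collections import namedtuple

# Hypothetical stand-in for ForeachFrame; only `index` matters here.
Frame = namedtuple("Frame", ["step", "var", "num_splits", "index"])


def sort_by_foreach_index(ds_list):
    """Return the datastores ordered by their innermost foreach index."""
    return sorted(ds_list, key=lambda ds: ds["_foreach_stack"][-1].index)


values = [5, 3, 1, 4, 2, 6]
datastores = [
    {"_foreach_stack": [Frame("start", "arr", len(values), i)], "my_input": v}
    for i, v in enumerate(values)
]

# Simulate an arbitrary arrival order at the join step.
arrived = [datastores[i] for i in (3, 0, 5, 1, 4, 2)]

ordered = sort_by_foreach_index(arrived)
assert [ds["my_input"] for ds in ordered] == values
```

Because the key is read from each task's own stack, the result does not depend on how the inputs were listed or loaded.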
Greptile Summary
This PR fixes a real ordering bug in Metaflow's foreach join step: when more than 4 branches are present, …
Key changes:
Minor concerns:
Confidence Score: 4/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[_init_data called\nwith input_paths] --> B{len input_paths > 4?}
B -- Yes --> C[Create TaskDataStoreSet\nwith pathspecs]
C --> D{join_type == foreach?}
D -- Yes --> E[Prefetch _foreach_stack\n_iteration_stack\n_foreach_num_splits\n_foreach_var]
D -- No --> F[No prefetch]
E --> G[Iterate set → ds_list\norder is ARBITRARY]
F --> G
G --> H{join_type == foreach?\nnew in this PR}
H -- Yes --> I[ds_list.sort by\n_foreach_stack-1.index\nCHEAP: already prefetched]
H -- No --> J[Keep arbitrary order]
I --> K[Validate len ds_list\n== len input_paths]
J --> K
B -- No, ≤4 --> L[Sequential load\nfor each input_path]
L --> M[Append to ds_list\nin input_paths order]
M --> K
K --> N[Return ordered ds_list\nto join step]
Last reviewed commit: bbc50a2
Uses a 6-element non-monotonic array [5, 3, 1, 4, 2, 6] (>4 to trigger the TaskDataStoreSet code path) and asserts both inputs[i].index == i and that values match the original array positionally. This catches partial-sort regressions that monotonic tests would miss.
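One way to see why a non-monotonic list makes a stronger test is sketched below. The `check_join_order` helper and the (index, value) pairs are hypothetical simplifications of what the join step sees. The "wrong" ordering is simulated by sorting on value instead of foreach index, which is just one possible kind of ordering regression.

```python
def check_join_order(original, inputs):
    """Return True if inputs arrive in foreach order.

    `inputs` is a list of (index, value) pairs in arrival order, a
    hypothetical simplification of the join step's inputs.
    """
    return all(
        index == i and value == original[i]
        for i, (index, value) in enumerate(inputs)
    )


arr = [5, 3, 1, 4, 2, 6]
in_order = list(enumerate(arr))

# A wrong ordering: sorted by value rather than by foreach index.
by_value = sorted(in_order, key=lambda pair: pair[1])

assert check_join_order(arr, in_order)
assert not check_join_order(arr, by_value)

# With a monotonic list, value order and index order coincide, so the
# same mistake goes unnoticed.
mono = list(range(6))
mono_by_value = sorted(enumerate(mono), key=lambda pair: pair[1])
assert check_join_order(mono, mono_by_value)
```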
if join_type == "foreach":
    # get_task_datastores() iterates over a set internally, so the
    # returned order is arbitrary. Sort by foreach index (the ground
    # truth stored in the task itself) so that inputs always arrive
    # at join steps in the same order as the original foreach list.
    # _foreach_stack is already prefetched above, so this is cheap.
    ds_list.sort(key=lambda ds: ds["_foreach_stack"][-1].index)
<=4-branch path left unsorted
The sort is applied only inside the len(input_paths) > 4 branch. The sequential path (≤ 4 inputs, lines 287–308) iterates input_paths in whatever order it is passed and has no corresponding sort-by-index step.
If any caller or metadata backend ever delivers input_paths in a non-foreach-index order for the small case, inputs at the join step would still arrive out of order. The PR description asserts this path is safe today because input_paths arrives in the correct order, but there's no in-code guarantee or comment to that effect for future readers.
Consider either:
- Adding a parallel sort after line 308 (cheap, always correct), or
- Adding an explicit comment documenting why the sequential path is guaranteed to be ordered (e.g., "input_paths is always provided in foreach-index order by the runtime scheduler"):
else:
    # initialize directly in the single input case.
    # input_paths is provided by the runtime in foreach-index order,
    # so no explicit sort is needed here (unlike the TaskDataStoreSet
    # path above which iterates a set internally).
    ds_list = []
…ions
Replace range(32) with a shuffled 32-element list so that any partial-sort bug in foreach join ordering is immediately visible. Also removes the separate NonMonotonicForeachTest since this test now provides the same coverage.
  def join(self, inputs):
-     got = sorted([inp.my_input for inp in inputs])
+     got = [inp.my_input for inp in inputs]
      assert_equals(list(range(1200)), got)
Ordering assertion relies on undocumented value == index coincidence
Removing sorted() here only turns this into a meaningful ordering assertion because self.arr = range(1200) means each branch's self.input == self.index. If a future developer changes self.arr to any other sequence (e.g., a non-monotonic list to match BasicForeachTest), the ordering assertion would break silently unless the join assertion is also updated. Without a comment explaining this, the test looks equivalent to its previous sorted() form.
Consider adding an explanatory comment to make the invariant explicit:
@steps(0, ["foreach-join-small"], required=True)
def join(self, inputs):
    # arr is range(1200), so value == index for each branch; this makes
    # the assertion an ordering test: any permutation of inputs yields a
    # got[i] != i mismatch.
    got = [inp.my_input for inp in inputs]
    assert_equals(list(range(1200)), got)
if join_type == "foreach":
    # get_task_datastores() iterates over a set internally, so the
    # returned order is arbitrary. Sort by foreach index (the ground
    # truth stored in the task itself) so that inputs always arrive
    # at join steps in the same order as the original foreach list.
    # _foreach_stack is already prefetched above, so this is cheap.
    ds_list.sort(key=lambda ds: ds["_foreach_stack"][-1].index)
ForeachFrame.index defaults to None — sort may crash on unexpected data
ForeachFrame is defined in tuple_util.py as a namedtuple_with_defaults where all fields, including index, default to None. In Python 3, comparing None values during a sort raises TypeError: '<' not supported between instances of 'NoneType' and 'int'.
In normal foreach execution index is always populated with split_index (an int) at task.py:410, so this is safe in practice. However, if a task's _foreach_stack artifact was written by a task in an unusual retry state, or deserialized from an older Metaflow version where the field was absent, the sort would crash with an unhelpful TypeError instead of a descriptive error.
Consider adding a guard to surface a clear error message in that edge case, for example by checking stack[-1].index is not None before sorting and raising a MetaflowDataMissing if the invariant is violated.
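A guard along these lines might look like the sketch below. It is only an illustration: `Frame` is a hypothetical stand-in for ForeachFrame with all-None defaults, and `MissingForeachIndex` is a made-up exception used here in place of whichever Metaflow exception the maintainers would prefer.

```python
from collections import namedtuple

# Hypothetical stand-in for ForeachFrame: every field defaults to None.
Frame = namedtuple(
    "Frame", ["step", "var", "num_splits", "index"], defaults=(None,) * 4
)


class MissingForeachIndex(Exception):
    """Made-up exception for this sketch; not part of Metaflow."""


def sort_foreach_inputs(ds_list):
    """Sort by innermost foreach index, failing clearly if it is absent."""

    def key(ds):
        stack = ds["_foreach_stack"]
        if not stack or stack[-1].index is None:
            raise MissingForeachIndex(
                "Input has no foreach index; cannot order join inputs."
            )
        return stack[-1].index

    return sorted(ds_list, key=key)


good = [{"_foreach_stack": [Frame(index=i)]} for i in (2, 0, 1)]
assert [ds["_foreach_stack"][-1].index for ds in sort_foreach_inputs(good)] == [0, 1, 2]
```

Checking inside the key function means the descriptive error is raised before any `None`-to-`int` comparison can occur.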
Summary
get_task_datastores() builds its return list by iterating a Python set, so inputs at foreach join steps arrived in arbitrary order rather than matching the original foreach list. This only affects foreach with >4 branches (the threshold where TaskDataStoreSet is used instead of sequential loading).
Fix: after loading datastores for a foreach join, sort by _foreach_stack[-1].index, the ground truth stored in each task. _foreach_stack is already prefetched in this code path so the sort is cheap.
Also removes sorted() from BasicForeachTest and WideForeachTest join checks, turning them into ordering assertions. BasicForeachTest now uses a shuffled (non-monotonic) 32-element array so that any partial-sort regression is immediately visible.