Skip to content

Fix foreach join inputs out-of-order for >4 branches#2974

Open
npow wants to merge 6 commits intomasterfrom
worktree-fix-foreach-join-ordering
Open

Fix foreach join inputs out-of-order for >4 branches#2974
npow wants to merge 6 commits intomasterfrom
worktree-fix-foreach-join-ordering

Conversation

@npow
Copy link
Collaborator

@npow npow commented Mar 6, 2026

Summary

get_task_datastores() builds its return list by iterating a Python set, so inputs at foreach join steps arrived in arbitrary order rather than matching the original foreach list. Only affects foreach with >4 branches (the threshold where TaskDataStoreSet is used instead of sequential loading).

Fix: after loading datastores for a foreach join, sort by _foreach_stack[-1].index — the ground truth stored in each task. _foreach_stack is already prefetched in this code path so the sort is cheap.

Also removes sorted() from BasicForeachTest and WideForeachTest join checks, turning them into ordering assertions. BasicForeachTest now uses a shuffled (non-monotonic) 32-element array so that any partial-sort regression is immediately visible.

Nissan Pow added 4 commits March 6, 2026 12:42
When a foreach has more than 4 branches, _init_data uses TaskDataStoreSet
which calls get_task_datastores(). That function built its return list by
iterating over a Python set (done_attempts & latest_started_attempts), which
has no defined order — so inputs at join steps arrived in arbitrary order
rather than matching the original foreach list.

Fix by sorting the result of get_task_datastores() to match the pathspecs
order before returning. Also implement the long-standing TODO in
Inputs.__init__ to sort by _foreach_stack index as a defensive layer.

Add a comment in _init_data explaining that both code paths (sequential and
TaskDataStoreSet) now guarantee pathspec-order preservation.

Add ForeachJoinOrderTest which uses a non-monotonic 6-element foreach list
(>4 to trigger TaskDataStoreSet) and asserts that inputs[i].index == i and
values match the original array positionally without sorting.
Drop sorted() from BasicForeachTest and WideForeachTest join checks.
Both already use range(N) as the foreach list, so removing sorted()
turns them into direct ordering assertions — no extra test file needed.
Sorting by _foreach_stack[-1].index is the semantically correct fix:
it uses the ground truth stored in the task itself, so ordering is
guaranteed regardless of how runtime.py orders input_paths or how
get_task_datastores() returns results. _foreach_stack is already
prefetched in the foreach join path so the sort is cheap.
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 6, 2026

Greptile Summary

This PR fixes a real ordering bug in Metaflow's foreach join step: when more than 4 branches are present, TaskDataStoreSet iterates an internal Python set, returning datastores in arbitrary order. The fix sorts ds_list by the foreach index stored in each task's _foreach_stack artifact immediately after iteration, scoped only to the >4-branch foreach join path where the bug manifests. This leverages the already-prefetched _foreach_stack artifact, so the sort adds no extra I/O overhead.

Key changes:

  • metaflow/task.py: Adds a sort-by-foreach-index step in the >4-branch foreach join path, with a clear comment explaining the rationale.
  • metaflow/datastore/inputs.py: Removes the long-standing # TODO sort by foreach index comment, now resolved.
  • test/core/tests/basic_foreach.py: Replaces monotonic range(32) with a 32-element non-monotonic permutation and drops sorted() from the join check, creating a strong regression test.
  • test/core/tests/wide_foreach.py: Removes sorted() from the 1200-branch join check, turning it into an ordering assertion (effective because the monotonic array means value equals index for every branch).

Minor concerns:

  • The ForeachFrame namedtuple's index field defaults to None; hitting that default in the sort key would surface as an opaque TypeError rather than a meaningful diagnostic.
  • WideForeachTest lacks a comment explaining why the monotonic array is sufficient as an ordering assertion, making the intent fragile for future maintainers.

Confidence Score: 4/5

  • Safe to merge — the fix is correctly scoped and well-tested, with only minor documentation and defensive-coding gaps.
  • The core fix is sound: sort is applied only in the correct code path (>4 inputs, join_type == "foreach"), the sort key uses the authoritative _foreach_stack[-1].index field, and the prefetch ensures no extra I/O. BasicForeachTest now provides a strong non-monotonic regression test. Two minor concerns prevent a 5: (1) the sort lambda has no guard against a None index (the ForeachFrame default), which would raise a cryptic TypeError on unexpected data; (2) WideForeachTest relies on an undocumented value == index coincidence for its ordering semantics.
  • metaflow/task.py (sort key defensiveness) and test/core/tests/wide_foreach.py (undocumented ordering assumption).

Important Files Changed

Filename Overview
metaflow/task.py Adds a sort-by-_foreach_stack[-1].index after TaskDataStoreSet iteration for foreach joins. Fix is correctly scoped to join_type == "foreach" and the >4 branch, and aligns with the pre-existing prefetch of _foreach_stack. Minor concern: ForeachFrame.index defaults to None so an unexpected None index would cause an opaque TypeError instead of a meaningful error message.
metaflow/datastore/inputs.py Removes the stale # TODO sort by foreach index comment, which is now resolved by the sort added in task.py.
test/core/tests/basic_foreach.py Replaces monotonic range(32) with a 32-element non-monotonic permutation and removes sorted() from the join assertion, making this a strong ordering regression test that triggers the >4 code path via TaskDataStoreSet.
test/core/tests/wide_foreach.py Removes sorted() from the join assertion. The test is an implicit ordering check only because self.arr = range(1200) means value == index for each branch — this invariant is not documented, making the intent fragile for future maintainers.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[_init_data called\nwith input_paths] --> B{len input_paths > 4?}

    B -- Yes --> C[Create TaskDataStoreSet\nwith pathspecs]
    C --> D{join_type == foreach?}
    D -- Yes --> E[Prefetch _foreach_stack\n_iteration_stack\n_foreach_num_splits\n_foreach_var]
    D -- No --> F[No prefetch]
    E --> G[Iterate set → ds_list\norder is ARBITRARY]
    F --> G
    G --> H{join_type == foreach?\nnew in this PR}
    H -- Yes --> I[ds_list.sort by\n_foreach_stack-1.index\nCHEAP: already prefetched]
    H -- No --> J[Keep arbitrary order]
    I --> K[Validate len ds_list\n== len input_paths]
    J --> K

    B -- No, ≤4 --> L[Sequential load\nfor each input_path]
    L --> M[Append to ds_list\nin input_paths order]
    M --> K

    K --> N[Return ordered ds_list\nto join step]
Loading

Last reviewed commit: bbc50a2

Uses a 6-element non-monotonic array [5, 3, 1, 4, 2, 6] (>4 to trigger
the TaskDataStoreSet code path) and asserts both inputs[i].index == i
and that values match the original array positionally. This catches
partial-sort regressions that monotonic tests would miss.
Comment on lines +274 to +280
if join_type == "foreach":
# get_task_datastores() iterates over a set internally, so the
# returned order is arbitrary. Sort by foreach index (the ground
# truth stored in the task itself) so that inputs always arrive
# at join steps in the same order as the original foreach list.
# _foreach_stack is already prefetched above, so this is cheap.
ds_list.sort(key=lambda ds: ds["_foreach_stack"][-1].index)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<=4-branch path left unsorted

The sort is applied only inside the len(input_paths) > 4 branch. The sequential path (≤ 4 inputs, lines 287–308) iterates input_paths in whatever order it is passed and has no corresponding sort-by-index step.

If any caller or metadata backend ever delivers input_paths in a non-foreach-index order for the small case, inputs at the join step would still arrive out of order. The PR description asserts this path is safe today because input_paths arrives in the correct order, but there's no in-code guarantee or comment to that effect for future readers.

Consider either:

  • Adding a parallel sort after line 308 (cheap, always correct), or
  • Adding an explicit comment documenting why the sequential path is guaranteed to be ordered (e.g., "input_paths is always provided in foreach-index order by the runtime scheduler"):
        else:
            # initialize directly in the single input case.
            # input_paths is provided by the runtime in foreach-index order,
            # so no explicit sort is needed here (unlike the TaskDataStoreSet
            # path above which iterates a set internally).
            ds_list = []

…ions

Replace range(32) with a shuffled 32-element list so that any
partial-sort bug in foreach join ordering is immediately visible.
Also removes the separate NonMonotonicForeachTest since this test
now provides the same coverage.
Comment on lines 27 to 29
def join(self, inputs):
got = sorted([inp.my_input for inp in inputs])
got = [inp.my_input for inp in inputs]
assert_equals(list(range(1200)), got)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ordering assertion relies on undocumented value == index coincidence

Removing sorted() here only turns this into a meaningful ordering assertion because self.arr = range(1200) means each branch's self.input == self.index. If a future developer changes self.arr to any other sequence (e.g., a non-monotonic list to match BasicForeachTest), the ordering assertion would break silently unless the join assertion is also updated. Without a comment explaining this, the test looks equivalent to its previous sorted() form.

Consider adding an explanatory comment to make the invariant explicit:

Suggested change
def join(self, inputs):
got = sorted([inp.my_input for inp in inputs])
got = [inp.my_input for inp in inputs]
assert_equals(list(range(1200)), got)
@steps(0, ["foreach-join-small"], required=True)
def join(self, inputs):
# arr is range(1200), so value == index for each branch — this makes
# the assertion an ordering test: any permutation of inputs yields a
# got[i] != i mismatch.
got = [inp.my_input for inp in inputs]
assert_equals(list(range(1200)), got)

Comment on lines +274 to +280
if join_type == "foreach":
# get_task_datastores() iterates over a set internally, so the
# returned order is arbitrary. Sort by foreach index (the ground
# truth stored in the task itself) so that inputs always arrive
# at join steps in the same order as the original foreach list.
# _foreach_stack is already prefetched above, so this is cheap.
ds_list.sort(key=lambda ds: ds["_foreach_stack"][-1].index)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ForeachFrame.index defaults to None — sort may crash on unexpected data

ForeachFrame is defined in tuple_util.py as a namedtuple_with_defaults where all fields, including index, default to None. In Python 3, comparing None values during a sort raises TypeError: '<' not supported between instances of 'NoneType' and 'int'.

In normal foreach execution index is always populated with split_index (an int) at task.py:410, so this is safe in practice. However, if a task's _foreach_stack artifact was written by a task in an unusual retry state, or deserialized from an older Metaflow version where the field was absent, the sort would crash with an unhelpful TypeError instead of a descriptive error.

Consider adding a guard to surface a clear error message in that edge case, for example by checking stack[-1].index is not None before sorting and raising a MetaflowDataMissing if the invariant is violated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant