Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions metaflow/datastore/flow_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,17 @@ def get_task_datastores(
)
for v in latest_to_fetch
]
if pathspecs:
# The set operations above (latest_started_attempts & done_attempts)
# discard the original pathspecs ordering. When the caller provides
# pathspecs, they expect results back in that same order -- e.g. foreach
# join inputs must arrive in split order. Sort to restore it.
# pathspec format: run_id/step_name/task_id[/attempt]
position = {
(ps.split("/")[1], ps.split("/")[2]): i
for i, ps in enumerate(pathspecs)
}
latest_to_fetch.sort(key=lambda v: position[v[1], v[2]])
return list(itertools.starmap(self.get_task_datastore, latest_to_fetch))

def get_task_datastore(
Expand Down
1 change: 0 additions & 1 deletion metaflow/datastore/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ class Inputs(object):
"""

def __init__(self, flows):
# TODO sort by foreach index
self.flows = list(flows)
for flow in self.flows:
setattr(self, flow._current_step, flow)
Expand Down
76 changes: 73 additions & 3 deletions test/core/tests/basic_foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,41 @@ class BasicForeachTest(MetaflowTest):
@steps(0, ["foreach-split"], required=True)
def split(self):
self.my_index = None
self.arr = range(32)
# Non-monotonic to catch foreach join ordering bugs
self.arr = [
26,
5,
10,
15,
25,
11,
22,
6,
19,
12,
16,
9,
28,
14,
24,
20,
30,
1,
13,
18,
2,
17,
21,
3,
29,
4,
27,
31,
8,
23,
0,
7,
]

@steps(0, ["foreach-inner"], required=True)
def inner(self):
Expand All @@ -30,8 +64,44 @@ def inner(self):

@steps(0, ["foreach-join"], required=True)
def join(self, inputs):
got = sorted([inp.my_input for inp in inputs])
assert_equals(list(range(32)), got)
got = [inp.my_input for inp in inputs]
assert_equals(
[
26,
5,
10,
15,
25,
11,
22,
6,
19,
12,
16,
9,
28,
14,
24,
20,
30,
1,
13,
18,
2,
17,
21,
3,
29,
4,
27,
31,
8,
23,
0,
7,
],
got,
)

@steps(1, ["all"])
def step_all(self):
Expand Down
2 changes: 1 addition & 1 deletion test/core/tests/wide_foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def inner(self):

@steps(0, ["foreach-join-small"], required=True)
def join(self, inputs):
got = sorted([inp.my_input for inp in inputs])
got = [inp.my_input for inp in inputs]
assert_equals(list(range(1200)), got)
Comment on lines 27 to 29
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ordering assertion relies on undocumented value == index coincidence

Removing sorted() here only turns this into a meaningful ordering assertion because self.arr = range(1200) means each branch's self.input == self.index. If a future developer changes self.arr to any other sequence (e.g., a non-monotonic list to match BasicForeachTest), the ordering assertion would break silently unless the join assertion is also updated. Without a comment explaining this, the test looks equivalent to its previous sorted() form.

Consider adding an explanatory comment to make the invariant explicit:

Suggested change
def join(self, inputs):
got = sorted([inp.my_input for inp in inputs])
got = [inp.my_input for inp in inputs]
assert_equals(list(range(1200)), got)
@steps(0, ["foreach-join-small"], required=True)
def join(self, inputs):
# arr is range(1200), so value == index for each branch — this makes
# the assertion an ordering test: any permutation of inputs yields a
# got[i] != i mismatch.
got = [inp.my_input for inp in inputs]
assert_equals(list(range(1200)), got)


@steps(1, ["all"])
Expand Down
Loading