Skip to content

Commit 739573a

Browse files
committed
Re-implement map_over_datasets
The main changes: - It is implemented using zip_subtrees, which means it should properly handle DataTrees where the nodes are defined in a different order. - For simplicity, I removed handling of `**kwargs`, in order to preserve some flexibility for adding keyword arguments. - I removed automatic skipping of empty nodes, because there are almost assuredly cases where that would make sense. This could be restored with an option keyword argument.
1 parent 4480e11 commit 739573a

4 files changed

Lines changed: 124 additions & 190 deletions

File tree

xarray/core/datatree.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,7 +1387,6 @@ def map_over_datasets(
13871387
self,
13881388
func: Callable,
13891389
*args: Iterable[Any],
1390-
**kwargs: Any,
13911390
) -> DataTree | tuple[DataTree, ...]:
13921391
"""
13931392
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
@@ -1406,8 +1405,6 @@ def map_over_datasets(
14061405
Function will not be applied to any nodes without datasets.
14071406
*args : tuple, optional
14081407
Positional arguments passed on to `func`.
1409-
**kwargs : Any
1410-
Keyword arguments passed on to `func`.
14111408
14121409
Returns
14131410
-------
@@ -1417,7 +1414,7 @@ def map_over_datasets(
14171414
# TODO this signature means that func has no way to know which node it is being called upon - change?
14181415

14191416
# TODO fix this typing error
1420-
return map_over_datasets(func)(self, *args, **kwargs)
1417+
return map_over_datasets(func)(self, *args)
14211418

14221419
def pipe(
14231420
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any

xarray/core/datatree_mapping.py

Lines changed: 78 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22

33
import functools
44
import sys
5-
from collections.abc import Callable
6-
from itertools import repeat
7-
from typing import TYPE_CHECKING
5+
from collections.abc import Callable, Mapping
6+
from typing import TYPE_CHECKING, Any, cast
87

9-
from xarray.core.dataarray import DataArray
108
from xarray.core.dataset import Dataset
119
from xarray.core.formatting import diff_treestructure
12-
from xarray.core.treenode import NodePath, TreeNode
10+
from xarray.core.treenode import TreeNode, zip_subtrees
1311

1412
if TYPE_CHECKING:
1513
from xarray.core.datatree import DataTree
@@ -125,110 +123,55 @@ def map_over_datasets(func: Callable) -> Callable:
125123
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?
126124

127125
@functools.wraps(func)
128-
def _map_over_datasets(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
126+
def _map_over_datasets(*args) -> DataTree | tuple[DataTree, ...]:
129127
"""Internal function which maps func over every node in tree, returning a tree of the results."""
130128
from xarray.core.datatree import DataTree
131129

132-
all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
133-
a for a in kwargs.values() if isinstance(a, DataTree)
134-
]
135-
136-
if len(all_tree_inputs) > 0:
137-
first_tree, *other_trees = all_tree_inputs
138-
else:
139-
raise TypeError("Must pass at least one tree object")
140-
141-
for other_tree in other_trees:
142-
# isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
143-
check_isomorphic(
144-
first_tree, other_tree, require_names_equal=False, check_from_root=False
145-
)
146-
147130
# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
148131
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
149132
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
150-
out_data_objects = {}
151-
args_as_tree_length_iterables = [
152-
a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
153-
]
154-
n_args = len(args_as_tree_length_iterables)
155-
kwargs_as_tree_length_iterables = {
156-
k: v.subtree if isinstance(v, DataTree) else repeat(v)
157-
for k, v in kwargs.items()
158-
}
159-
for node_of_first_tree, *all_node_args in zip(
160-
first_tree.subtree,
161-
*args_as_tree_length_iterables,
162-
*list(kwargs_as_tree_length_iterables.values()),
163-
strict=False,
164-
):
165-
node_args_as_datasetviews = [
166-
a.dataset if isinstance(a, DataTree) else a
167-
for a in all_node_args[:n_args]
168-
]
169-
node_kwargs_as_datasetviews = dict(
170-
zip(
171-
[k for k in kwargs_as_tree_length_iterables.keys()],
172-
[
173-
v.dataset if isinstance(v, DataTree) else v
174-
for v in all_node_args[n_args:]
175-
],
176-
strict=True,
177-
)
133+
out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {}
134+
135+
tree_args = [arg for arg in args if isinstance(arg, DataTree)]
136+
subtrees = zip_subtrees(*tree_args)
137+
138+
for node_tree_args in subtrees:
139+
140+
node_dataset_args = [arg.dataset for arg in node_tree_args]
141+
for i, arg in enumerate(args):
142+
if not isinstance(arg, DataTree):
143+
node_dataset_args.insert(i, arg)
144+
145+
path = (
146+
"/"
147+
if node_tree_args[0] is tree_args[0]
148+
else node_tree_args[0].relative_to(tree_args[0])
178149
)
179-
func_with_error_context = _handle_errors_with_path_context(
180-
node_of_first_tree.path
181-
)(func)
182-
183-
if node_of_first_tree.has_data:
184-
# call func on the data in this particular set of corresponding nodes
185-
results = func_with_error_context(
186-
*node_args_as_datasetviews, **node_kwargs_as_datasetviews
187-
)
188-
elif node_of_first_tree.has_attrs:
189-
# propagate attrs
190-
results = node_of_first_tree.dataset
191-
else:
192-
# nothing to propagate so use fastpath to create empty node in new tree
193-
results = None
150+
func_with_error_context = _handle_errors_with_path_context(path)(func)
151+
results = func_with_error_context(*node_dataset_args)
194152

195-
# TODO implement mapping over multiple trees in-place using if conditions from here on?
196-
out_data_objects[node_of_first_tree.path] = results
153+
out_data_objects[path] = results
197154

198-
# Find out how many return values we received
199155
num_return_values = _check_all_return_values(out_data_objects)
200156

201-
# Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
202-
original_root_path = first_tree.path
203-
result_trees = []
204-
for i in range(num_return_values):
205-
out_tree_contents = {}
206-
for n in first_tree.subtree:
207-
p = n.path
208-
if p in out_data_objects.keys():
209-
if isinstance(out_data_objects[p], tuple):
210-
output_node_data = out_data_objects[p][i]
211-
else:
212-
output_node_data = out_data_objects[p]
213-
else:
214-
output_node_data = None
215-
216-
# Discard parentage so that new trees don't include parents of input nodes
217-
relative_path = str(NodePath(p).relative_to(original_root_path))
218-
relative_path = "/" if relative_path == "." else relative_path
219-
out_tree_contents[relative_path] = output_node_data
220-
221-
new_tree = DataTree.from_dict(
222-
out_tree_contents,
223-
name=first_tree.name,
224-
)
225-
result_trees.append(new_tree)
157+
if num_return_values is None:
158+
out_data = cast(Mapping[str, Dataset | None], out_data_objects)
159+
return DataTree.from_dict(out_data, name=tree_args[0].name)
226160

227-
# If only one result then don't wrap it in a tuple
228-
if len(result_trees) == 1:
229-
return result_trees[0]
230-
else:
231-
return tuple(result_trees)
161+
out_data_tuples = cast(
162+
Mapping[str, tuple[Dataset | None, ...]], out_data_objects
163+
)
164+
output_dicts: list[dict[str, Dataset | None]] = [
165+
{} for _ in range(num_return_values)
166+
]
167+
for path, outputs in out_data_tuples.items():
168+
for output_dict, output in zip(output_dicts, outputs, strict=False):
169+
output_dict[path] = output
170+
171+
return tuple(
172+
DataTree.from_dict(output_dict, name=tree_args[0].name)
173+
for output_dict in output_dicts
174+
)
232175

233176
return _map_over_datasets
234177

@@ -260,62 +203,54 @@ def add_note(err: BaseException, msg: str) -> None:
260203
err.add_note(msg)
261204

262205

263-
def _check_single_set_return_values(
264-
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
265-
):
206+
def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None:
266207
"""Check types returned from single evaluation of func, and return number of return values received from func."""
267-
if isinstance(obj, Dataset | DataArray):
268-
return 1
269-
elif isinstance(obj, tuple):
270-
for r in obj:
271-
if not isinstance(r, Dataset | DataArray):
272-
raise TypeError(
273-
f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
274-
f"of type {type(r)}, not Dataset or DataArray."
275-
)
276-
return len(obj)
277-
else:
208+
if isinstance(obj, None | Dataset):
209+
return None # no need to pack results
210+
211+
if not isinstance(obj, tuple) or not all(
212+
isinstance(r, Dataset | None) for r in obj
213+
):
278214
raise TypeError(
279-
f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
280-
f"Dataset or DataArray, nor a tuple of such types."
215+
f"the result of calling func on the node at position is not a Dataset or None "
216+
f"or a tuple of such types: {obj!r}"
281217
)
282218

219+
return len(obj)
283220

284-
def _check_all_return_values(returned_objects):
285-
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
286221

287-
if all(r is None for r in returned_objects.values()):
288-
raise TypeError(
289-
"Called supplied function on all nodes but found a return value of None for"
290-
"all of them."
291-
)
222+
def _check_all_return_values(returned_objects) -> int | None:
223+
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
292224

293225
result_data_objects = [
294-
(path_to_node, r)
295-
for path_to_node, r in returned_objects.items()
296-
if r is not None
226+
(path_to_node, r) for path_to_node, r in returned_objects.items()
297227
]
298228

299-
if len(result_data_objects) == 1:
300-
# Only one node in the tree: no need to check consistency of results between nodes
301-
path_to_node, result = result_data_objects[0]
302-
num_return_values = _check_single_set_return_values(path_to_node, result)
303-
else:
304-
prev_path, _ = result_data_objects[0]
305-
prev_num_return_values, num_return_values = None, None
306-
for path_to_node, obj in result_data_objects[1:]:
307-
num_return_values = _check_single_set_return_values(path_to_node, obj)
308-
309-
if (
310-
num_return_values != prev_num_return_values
311-
and prev_num_return_values is not None
312-
):
229+
first_path, result = result_data_objects[0]
230+
return_values = _check_single_set_return_values(first_path, result)
231+
232+
for path_to_node, obj in result_data_objects[1:]:
233+
cur_return_values = _check_single_set_return_values(path_to_node, obj)
234+
235+
if return_values != cur_return_values:
236+
if return_values is None:
313237
raise TypeError(
314-
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
315-
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
316-
f"{prev_num_return_values} separate return values."
238+
f"Calling func on the nodes at position {path_to_node} returns "
239+
f"a tuple of {cur_return_values} datasets, whereas calling func on the "
240+
f"nodes at position {first_path} instead returns a single dataset."
241+
)
242+
elif cur_return_values is None:
243+
raise TypeError(
244+
f"Calling func on the nodes at position {path_to_node} returns "
245+
f"a single dataset, whereas calling func on the nodes at position "
246+
f"{first_path} instead returns a tuple of {return_values} datasets."
247+
)
248+
else:
249+
raise TypeError(
250+
f"Calling func on the nodes at position {path_to_node} returns "
251+
f"a tuple of {cur_return_values} datasets, whereas calling func on "
252+
f"the nodes at position {first_path} instead returns a tuple of "
253+
f"{return_values} datasets."
317254
)
318255

319-
prev_path, prev_num_return_values = path_to_node, num_return_values
320-
321-
return num_return_values
256+
return return_values

xarray/core/treenode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,9 @@ def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]:
791791
------
792792
Tuples of matching subtrees.
793793
"""
794+
if not trees:
795+
raise TypeError("Must pass at least one tree object")
796+
794797
# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
795798
queue = collections.deque([trees])
796799

0 commit comments

Comments
 (0)