|
2 | 2 |
|
3 | 3 | import functools |
4 | 4 | import sys |
5 | | -from collections.abc import Callable |
6 | | -from itertools import repeat |
7 | | -from typing import TYPE_CHECKING |
| 5 | +from collections.abc import Callable, Mapping |
| 6 | +from typing import TYPE_CHECKING, Any, cast |
8 | 7 |
|
9 | | -from xarray.core.dataarray import DataArray |
10 | 8 | from xarray.core.dataset import Dataset |
11 | 9 | from xarray.core.formatting import diff_treestructure |
12 | | -from xarray.core.treenode import NodePath, TreeNode |
| 10 | +from xarray.core.treenode import TreeNode, zip_subtrees |
13 | 11 |
|
14 | 12 | if TYPE_CHECKING: |
15 | 13 | from xarray.core.datatree import DataTree |
@@ -125,110 +123,55 @@ def map_over_datasets(func: Callable) -> Callable: |
125 | 123 | # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? |
126 | 124 |
|
127 | 125 | @functools.wraps(func) |
128 | | - def _map_over_datasets(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: |
| 126 | + def _map_over_datasets(*args) -> DataTree | tuple[DataTree, ...]: |
129 | 127 | """Internal function which maps func over every node in tree, returning a tree of the results.""" |
130 | 128 | from xarray.core.datatree import DataTree |
131 | 129 |
|
132 | | - all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ |
133 | | - a for a in kwargs.values() if isinstance(a, DataTree) |
134 | | - ] |
135 | | - |
136 | | - if len(all_tree_inputs) > 0: |
137 | | - first_tree, *other_trees = all_tree_inputs |
138 | | - else: |
139 | | - raise TypeError("Must pass at least one tree object") |
140 | | - |
141 | | - for other_tree in other_trees: |
142 | | - # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic |
143 | | - check_isomorphic( |
144 | | - first_tree, other_tree, require_names_equal=False, check_from_root=False |
145 | | - ) |
146 | | - |
147 | 130 | # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees |
148 | 131 | # We don't know which arguments are DataTrees so we zip all arguments together as iterables |
149 | 132 | # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return |
150 | | - out_data_objects = {} |
151 | | - args_as_tree_length_iterables = [ |
152 | | - a.subtree if isinstance(a, DataTree) else repeat(a) for a in args |
153 | | - ] |
154 | | - n_args = len(args_as_tree_length_iterables) |
155 | | - kwargs_as_tree_length_iterables = { |
156 | | - k: v.subtree if isinstance(v, DataTree) else repeat(v) |
157 | | - for k, v in kwargs.items() |
158 | | - } |
159 | | - for node_of_first_tree, *all_node_args in zip( |
160 | | - first_tree.subtree, |
161 | | - *args_as_tree_length_iterables, |
162 | | - *list(kwargs_as_tree_length_iterables.values()), |
163 | | - strict=False, |
164 | | - ): |
165 | | - node_args_as_datasetviews = [ |
166 | | - a.dataset if isinstance(a, DataTree) else a |
167 | | - for a in all_node_args[:n_args] |
168 | | - ] |
169 | | - node_kwargs_as_datasetviews = dict( |
170 | | - zip( |
171 | | - [k for k in kwargs_as_tree_length_iterables.keys()], |
172 | | - [ |
173 | | - v.dataset if isinstance(v, DataTree) else v |
174 | | - for v in all_node_args[n_args:] |
175 | | - ], |
176 | | - strict=True, |
177 | | - ) |
| 133 | + out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {} |
| 134 | + |
| 135 | + tree_args = [arg for arg in args if isinstance(arg, DataTree)] |
| 136 | + subtrees = zip_subtrees(*tree_args) |
| 137 | + |
| 138 | + for node_tree_args in subtrees: |
| 139 | + |
| 140 | + node_dataset_args = [arg.dataset for arg in node_tree_args] |
| 141 | + for i, arg in enumerate(args): |
| 142 | + if not isinstance(arg, DataTree): |
| 143 | + node_dataset_args.insert(i, arg) |
| 144 | + |
| 145 | + path = ( |
| 146 | + "/" |
| 147 | + if node_tree_args[0] is tree_args[0] |
| 148 | + else node_tree_args[0].relative_to(tree_args[0]) |
178 | 149 | ) |
179 | | - func_with_error_context = _handle_errors_with_path_context( |
180 | | - node_of_first_tree.path |
181 | | - )(func) |
182 | | - |
183 | | - if node_of_first_tree.has_data: |
184 | | - # call func on the data in this particular set of corresponding nodes |
185 | | - results = func_with_error_context( |
186 | | - *node_args_as_datasetviews, **node_kwargs_as_datasetviews |
187 | | - ) |
188 | | - elif node_of_first_tree.has_attrs: |
189 | | - # propagate attrs |
190 | | - results = node_of_first_tree.dataset |
191 | | - else: |
192 | | - # nothing to propagate so use fastpath to create empty node in new tree |
193 | | - results = None |
| 150 | + func_with_error_context = _handle_errors_with_path_context(path)(func) |
| 151 | + results = func_with_error_context(*node_dataset_args) |
194 | 152 |
|
195 | | - # TODO implement mapping over multiple trees in-place using if conditions from here on? |
196 | | - out_data_objects[node_of_first_tree.path] = results |
| 153 | + out_data_objects[path] = results |
197 | 154 |
|
198 | | - # Find out how many return values we received |
199 | 155 | num_return_values = _check_all_return_values(out_data_objects) |
200 | 156 |
|
201 | | - # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees |
202 | | - original_root_path = first_tree.path |
203 | | - result_trees = [] |
204 | | - for i in range(num_return_values): |
205 | | - out_tree_contents = {} |
206 | | - for n in first_tree.subtree: |
207 | | - p = n.path |
208 | | - if p in out_data_objects.keys(): |
209 | | - if isinstance(out_data_objects[p], tuple): |
210 | | - output_node_data = out_data_objects[p][i] |
211 | | - else: |
212 | | - output_node_data = out_data_objects[p] |
213 | | - else: |
214 | | - output_node_data = None |
215 | | - |
216 | | - # Discard parentage so that new trees don't include parents of input nodes |
217 | | - relative_path = str(NodePath(p).relative_to(original_root_path)) |
218 | | - relative_path = "/" if relative_path == "." else relative_path |
219 | | - out_tree_contents[relative_path] = output_node_data |
220 | | - |
221 | | - new_tree = DataTree.from_dict( |
222 | | - out_tree_contents, |
223 | | - name=first_tree.name, |
224 | | - ) |
225 | | - result_trees.append(new_tree) |
| 157 | + if num_return_values is None: |
| 158 | + out_data = cast(Mapping[str, Dataset | None], out_data_objects) |
| 159 | + return DataTree.from_dict(out_data, name=tree_args[0].name) |
226 | 160 |
|
227 | | - # If only one result then don't wrap it in a tuple |
228 | | - if len(result_trees) == 1: |
229 | | - return result_trees[0] |
230 | | - else: |
231 | | - return tuple(result_trees) |
| 161 | + out_data_tuples = cast( |
| 162 | + Mapping[str, tuple[Dataset | None, ...]], out_data_objects |
| 163 | + ) |
| 164 | + output_dicts: list[dict[str, Dataset | None]] = [ |
| 165 | + {} for _ in range(num_return_values) |
| 166 | + ] |
| 167 | + for path, outputs in out_data_tuples.items(): |
| 168 | + for output_dict, output in zip(output_dicts, outputs, strict=False): |
| 169 | + output_dict[path] = output |
| 170 | + |
| 171 | + return tuple( |
| 172 | + DataTree.from_dict(output_dict, name=tree_args[0].name) |
| 173 | + for output_dict in output_dicts |
| 174 | + ) |
232 | 175 |
|
233 | 176 | return _map_over_datasets |
234 | 177 |
|
@@ -260,62 +203,54 @@ def add_note(err: BaseException, msg: str) -> None: |
260 | 203 | err.add_note(msg) |
261 | 204 |
|
262 | 205 |
|
263 | | -def _check_single_set_return_values( |
264 | | - path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray] |
265 | | -): |
| 206 | +def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None: |
266 | 207 | """Check types returned from single evaluation of func, and return number of return values received from func.""" |
267 | | - if isinstance(obj, Dataset | DataArray): |
268 | | - return 1 |
269 | | - elif isinstance(obj, tuple): |
270 | | - for r in obj: |
271 | | - if not isinstance(r, Dataset | DataArray): |
272 | | - raise TypeError( |
273 | | - f"One of the results of calling func on datasets on the nodes at position {path_to_node} is " |
274 | | - f"of type {type(r)}, not Dataset or DataArray." |
275 | | - ) |
276 | | - return len(obj) |
277 | | - else: |
| 208 | + if isinstance(obj, None | Dataset): |
| 209 | + return None # no need to pack results |
| 210 | + |
| 211 | + if not isinstance(obj, tuple) or not all( |
| 212 | + isinstance(r, Dataset | None) for r in obj |
| 213 | + ): |
278 | 214 | raise TypeError( |
279 | | - f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not " |
280 | | - f"Dataset or DataArray, nor a tuple of such types." |
| 215 | + f"the result of calling func on the node at position is not a Dataset or None " |
| 216 | + f"or a tuple of such types: {obj!r}" |
281 | 217 | ) |
282 | 218 |
|
| 219 | + return len(obj) |
283 | 220 |
|
284 | | -def _check_all_return_values(returned_objects): |
285 | | - """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" |
286 | 221 |
|
287 | | - if all(r is None for r in returned_objects.values()): |
288 | | - raise TypeError( |
289 | | - "Called supplied function on all nodes but found a return value of None for" |
290 | | - "all of them." |
291 | | - ) |
| 222 | +def _check_all_return_values(returned_objects) -> int | None: |
| 223 | + """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" |
292 | 224 |
|
293 | 225 | result_data_objects = [ |
294 | | - (path_to_node, r) |
295 | | - for path_to_node, r in returned_objects.items() |
296 | | - if r is not None |
| 226 | + (path_to_node, r) for path_to_node, r in returned_objects.items() |
297 | 227 | ] |
298 | 228 |
|
299 | | - if len(result_data_objects) == 1: |
300 | | - # Only one node in the tree: no need to check consistency of results between nodes |
301 | | - path_to_node, result = result_data_objects[0] |
302 | | - num_return_values = _check_single_set_return_values(path_to_node, result) |
303 | | - else: |
304 | | - prev_path, _ = result_data_objects[0] |
305 | | - prev_num_return_values, num_return_values = None, None |
306 | | - for path_to_node, obj in result_data_objects[1:]: |
307 | | - num_return_values = _check_single_set_return_values(path_to_node, obj) |
308 | | - |
309 | | - if ( |
310 | | - num_return_values != prev_num_return_values |
311 | | - and prev_num_return_values is not None |
312 | | - ): |
| 229 | + first_path, result = result_data_objects[0] |
| 230 | + return_values = _check_single_set_return_values(first_path, result) |
| 231 | + |
| 232 | + for path_to_node, obj in result_data_objects[1:]: |
| 233 | + cur_return_values = _check_single_set_return_values(path_to_node, obj) |
| 234 | + |
| 235 | + if return_values != cur_return_values: |
| 236 | + if return_values is None: |
313 | 237 | raise TypeError( |
314 | | - f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return " |
315 | | - f"values, whereas calling func on the nodes at position {prev_path} instead returns " |
316 | | - f"{prev_num_return_values} separate return values." |
| 238 | + f"Calling func on the nodes at position {path_to_node} returns " |
| 239 | + f"a tuple of {cur_return_values} datasets, whereas calling func on the " |
| 240 | + f"nodes at position {first_path} instead returns a single dataset." |
| 241 | + ) |
| 242 | + elif cur_return_values is None: |
| 243 | + raise TypeError( |
| 244 | + f"Calling func on the nodes at position {path_to_node} returns " |
| 245 | + f"a single dataset, whereas calling func on the nodes at position " |
| 246 | + f"{first_path} instead returns a tuple of {return_values} datasets." |
| 247 | + ) |
| 248 | + else: |
| 249 | + raise TypeError( |
| 250 | + f"Calling func on the nodes at position {path_to_node} returns " |
| 251 | + f"a tuple of {cur_return_values} datasets, whereas calling func on " |
| 252 | + f"the nodes at position {first_path} instead returns a tuple of " |
| 253 | + f"{return_values} datasets." |
317 | 254 | ) |
318 | 255 |
|
319 | | - prev_path, prev_num_return_values = path_to_node, num_return_values |
320 | | - |
321 | | - return num_return_values |
| 256 | + return return_values |
0 commit comments