Skip to content

Commit 84961e6

Browse files
keewisdcherianIllviljanmax-sixty
authored
keep attrs in xarray.where (#4687)
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com> Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
1 parent d3b6aa6 commit 84961e6

4 files changed

Lines changed: 24 additions & 6 deletions

File tree

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ New Features
2323
~~~~~~~~~~~~
2424
- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`).
2525
By `Jimmy Westling <https://github.com/illviljan>`_.
26-
26+
- ``keep_attrs`` support for :py:func:`where` (:issue:`4141`, :issue:`4682`, :pull:`4687`).
27+
By `Justus Magin <https://github.com/keewis>`_.
2728
- Enable the limit option for dask array in the following methods :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` (:issue:`6112`)
2829
By `Joseph Nowak <https://github.com/josephnowak>`_.
2930

xarray/core/computation.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1727,7 +1727,7 @@ def dot(*arrays, dims=None, **kwargs):
17271727
return result.transpose(*all_dims, missing_dims="ignore")
17281728

17291729

1730-
def where(cond, x, y):
1730+
def where(cond, x, y, keep_attrs=None):
17311731
"""Return elements from `x` or `y` depending on `cond`.
17321732
17331733
Performs xarray-like broadcasting across input arguments.
@@ -1743,6 +1743,8 @@ def where(cond, x, y):
17431743
values to choose from where `cond` is True
17441744
y : scalar, array, Variable, DataArray or Dataset
17451745
values to choose from where `cond` is False
1746+
keep_attrs : bool or str or callable, optional
1747+
How to treat attrs. If True, keep the attrs of `x`.
17461748
17471749
Returns
17481750
-------
@@ -1808,6 +1810,14 @@ def where(cond, x, y):
18081810
Dataset.where, DataArray.where :
18091811
equivalent methods
18101812
"""
1813+
if keep_attrs is None:
1814+
keep_attrs = _get_keep_attrs(default=False)
1815+
1816+
if keep_attrs is True:
1817+
# keep the attributes of x, the second parameter, by default to
1818+
# be consistent with the `where` method of `DataArray` and `Dataset`
1819+
keep_attrs = lambda attrs, context: attrs[1]
1820+
18111821
# alignment for three arguments is complicated, so don't support it yet
18121822
return apply_ufunc(
18131823
duck_array_ops.where,
@@ -1817,6 +1827,7 @@ def where(cond, x, y):
18171827
join="exact",
18181828
dataset_join="exact",
18191829
dask="allowed",
1830+
keep_attrs=keep_attrs,
18201831
)
18211832

18221833

xarray/tests/test_computation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1922,6 +1922,15 @@ def test_where() -> None:
19221922
assert_identical(expected, actual)
19231923

19241924

1925+
def test_where_attrs() -> None:
1926+
cond = xr.DataArray([True, False], dims="x", attrs={"attr": "cond"})
1927+
x = xr.DataArray([1, 1], dims="x", attrs={"attr": "x"})
1928+
y = xr.DataArray([0, 0], dims="x", attrs={"attr": "y"})
1929+
actual = xr.where(cond, x, y, keep_attrs=True)
1930+
expected = xr.DataArray([1, 0], dims="x", attrs={"attr": "x"})
1931+
assert_identical(expected, actual)
1932+
1933+
19251934
@pytest.mark.parametrize("use_dask", [True, False])
19261935
@pytest.mark.parametrize("use_datetime", [True, False])
19271936
def test_polyval(use_dask, use_datetime) -> None:

xarray/tests/test_units.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2429,10 +2429,7 @@ def test_binary_operations(self, func, dtype):
24292429
(
24302430
pytest.param(operator.lt, id="less_than"),
24312431
pytest.param(operator.ge, id="greater_equal"),
2432-
pytest.param(
2433-
operator.eq,
2434-
id="equal",
2435-
),
2432+
pytest.param(operator.eq, id="equal"),
24362433
),
24372434
)
24382435
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)