Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions xcdat/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Literal

import cf_xarray as cfxr # noqa: F401
import numpy as np
import xarray as xr

Expand Down Expand Up @@ -73,7 +74,7 @@ def get_dim_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> str | list[


def get_dim_coords(
obj: xr.Dataset | xr.DataArray, axis: CFAxisKey
obj: xr.Dataset | xr.DataArray, axis: CFAxisKey, multidim: bool = False
) -> xr.Dataset | xr.DataArray:
"""Gets the dimension coordinates for an axis.

Expand Down Expand Up @@ -117,10 +118,14 @@ def get_dim_coords(
----------
.. [1] https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axes-and-coordinates
"""
# Get the object's index keys, with each being a dimension.
# NOTE: xarray does not include multidimensional coordinates as index keys.
# Example: ["lat", "lon", "time"]
index_keys = obj.indexes.keys()
if multidim:
# multidimensional coordinates cannot be indexes, use all coords
index_keys = list([y for x in obj.cf.coordinates.values() for y in x])
else:
# Get the object's index keys, with each being a dimension.
# NOTE: xarray does not include multidimensional coordinates as index keys.
# Example: ["lat", "lon", "time"]
index_keys = list(obj.indexes.keys())

# Attempt to map the axis it all of its coordinate variable(s) using the
# axis and coordinate names in the object attributes (if they are set).
Expand Down
43 changes: 28 additions & 15 deletions xcdat/regridder/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from xcdat.axis import CFAxisKey, get_coords_by_name, get_dim_coords
from xcdat.bounds import create_bounds
from xcdat.regridder import regrid2, xesmf, xgcm
from xcdat.regridder.base import BaseRegridder
from xcdat.regridder.grid import _validate_grid_has_single_axis_dim

HorizontalRegridTools = Literal["xesmf", "regrid2"]
HORIZONTAL_REGRID_TOOLS = {
HORIZONTAL_REGRID_TOOLS: dict[str, type[BaseRegridder]] = {
"regrid2": regrid2.Regrid2Regridder,
"xesmf": xesmf.XESMFRegridder,
}

VerticalRegridTools = Literal["xgcm"]
VERTICAL_REGRID_TOOLS = {"xgcm": xgcm.XGCMRegridder}
VERTICAL_REGRID_TOOLS: dict[str, type[BaseRegridder]] = {"xgcm": xgcm.XGCMRegridder}


@xr.register_dataset_accessor(name="regridder")
Expand Down Expand Up @@ -166,7 +167,9 @@ def horizontal(
f"Tool {e!s} does not exist, valid choices {list(HORIZONTAL_REGRID_TOOLS)}"
) from e

input_grid = _get_input_grid(self._ds, data_var, ["X", "Y"])
input_grid = _get_input_grid(
self._ds, data_var, ["X", "Y"], multidim=regrid_tool.can_handle_multidim()
)
regridder = regrid_tool(input_grid, output_grid, **options)
output_ds = regridder.horizontal(data_var, self._ds)

Expand Down Expand Up @@ -236,20 +239,17 @@ def vertical(
f"Tool {e!s} does not exist, valid choices "
f"{list(VERTICAL_REGRID_TOOLS)}"
) from e
input_grid = _get_input_grid(
self._ds,
data_var,
[
"Z",
],
)

input_grid = _get_input_grid(self._ds, data_var, ["Z"])
regridder = regrid_tool(input_grid, output_grid, **options)
output_ds = regridder.vertical(data_var, self._ds)

return output_ds


def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset:
def _obj_to_grid_ds(
obj: xr.Dataset | xr.DataArray, multidim: bool = False
) -> xr.Dataset:
"""
Convert an xarray object to a new Dataset containing axis coordinates and
bounds.
Expand Down Expand Up @@ -304,14 +304,20 @@ def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset:
attrs=obj.attrs,
)

# Multidimensional coordinates bounds generation is not supported
if multidim:
return output_ds

# Add bounds only for axes that do not already have them. This
# prevents multiple sets of bounds being added for the same axis.
# For example, curvilinear grids can have multiple coordinates for the
# same axis (e.g., (nlat, lat) for X and (nlon, lon) for Y). We only
# need lat_bnds and lon_bnds for the X and Y axes, respectively, and not
# nlat_bnds and nlon_bnds.

for axis, has_bounds in axis_has_bounds.items():
if not has_bounds:
# FIXME: Line 313 --Can't add bounds for multidimensional coordinates
output_ds = output_ds.bounds.add_bounds(axis=axis)

return output_ds
Expand Down Expand Up @@ -347,7 +353,12 @@ def _get_axis_coord_and_bounds(
return coord_var, bounds_var


def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKey]):
def _get_input_grid(
ds: xr.Dataset,
data_var: str,
dup_check_dims: list[CFAxisKey],
multidim: bool = False,
):
"""
Extract the grid from ``ds``.

Expand All @@ -374,10 +385,12 @@ def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKe
all_coords = set(ds.coords.keys())

for dimension in dup_check_dims:
coords = get_dim_coords(ds, dimension)
coords = get_dim_coords(ds, dimension, multidim=multidim)

if isinstance(coords, xr.Dataset):
coord = set([get_dim_coords(ds[data_var], dimension).name])
coord = set(
[get_dim_coords(ds[data_var], dimension, multidim=multidim).name]
)

dimension_coords = set(ds.cf[[dimension]].coords.keys())

Expand All @@ -387,7 +400,7 @@ def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKe
input_grid = ds.drop_dims(to_drop)

# drops extra dimensions on input grid
grid = input_grid.regridder.grid
grid = _obj_to_grid_ds(input_grid, multidim=multidim)

# preserve mask on grid
if "mask" in ds:
Expand Down
6 changes: 6 additions & 0 deletions xcdat/regridder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def _drop_axis(ds: xr.Dataset, axis: list[CFAxisKey]) -> xr.Dataset:
class BaseRegridder(abc.ABC):
"""BaseRegridder."""

supports_multidim: bool

@classmethod
def can_handle_multidim(cls) -> bool:
return cls.supports_multidim

def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: Any):
self._input_grid = input_grid
self._output_grid = output_grid
Expand Down
2 changes: 2 additions & 0 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@


class Regrid2Regridder(BaseRegridder):
supports_multidim = False

def __init__(
self,
input_grid: xr.Dataset,
Expand Down
2 changes: 2 additions & 0 deletions xcdat/regridder/xesmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@


class XESMFRegridder(BaseRegridder):
supports_multidim = True

def __init__(
self,
input_grid: xr.Dataset,
Expand Down
2 changes: 2 additions & 0 deletions xcdat/regridder/xgcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


class XGCMRegridder(BaseRegridder):
supports_multidim = False

def __init__(
self,
input_grid: xr.Dataset,
Expand Down
Loading