diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 84cd1d97323..f8959f64fc7 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -1241,6 +1241,33 @@ shape of each coordinate array in the ``encoding`` argument: The number of chunks on Tair matches our dask chunks, while there is now only a single chunk in the directory stores of each coordinate. +.. _io.zarr.rectilinear_chunks: + +Variable-sized (rectilinear) chunks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Zarr v3 supports *rectilinear* chunk grids, where chunk sizes vary along one +or more dimensions. This is useful when natural data boundaries (yearly chunks +of a daily time series, per-tile spatial extents) don't align to a regular +grid. Requires ``zarr-python >= 3.2``, and the experimental feature must be +enabled for both reading and writing: + +.. code-block:: python + + import zarr + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds = xr.Dataset({"var": ("x", np.arange(60))}).chunk({"x": (10, 20, 30)}) + ds.to_zarr("rectilinear.zarr", zarr_format=3, mode="w") + + roundtrip = xr.open_zarr("rectilinear.zarr", zarr_format=3) + roundtrip.chunks["x"] # (10, 20, 30) + +Rectilinear chunks can also be specified through the ``encoding`` argument +(one sequence per dimension), e.g. ``encoding={"var": {"chunks": ((10, 20, 30),)}}``. +Writing non-uniform chunks to a zarr v2 store raises ``ValueError`` because the +feature is Zarr Format 3-only. + Groups ~~~~~~ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index effb199f18e..23c1110f616 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,6 +14,11 @@ v2026.05.0 (unreleased) New Features ~~~~~~~~~~~~ +- Support reading and writing Zarr V3 arrays with rectilinear (variable-sized) + chunk grids. Requires zarr-python >= 3.2 with + ``zarr.config.set({"array.rectilinear_chunks": True})``, which must be set + for both reading and writing rectilinear-chunked stores. (:pull:`11279`). + By `Max Jones <https://github.com/maxrjones>`_.
Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/chunks.py b/xarray/backends/chunks.py index c255c7db591..d859032b0a2 100644 --- a/xarray/backends/chunks.py +++ b/xarray/backends/chunks.py @@ -1,3 +1,5 @@ +import itertools + import numpy as np from xarray.core.datatree import Variable @@ -133,9 +135,12 @@ def align_nd_chunks( def build_grid_chunks( size: int, - chunk_size: int, + chunk_size: int | tuple[int, ...], region: slice | None = None, ) -> tuple[int, ...]: + if isinstance(chunk_size, (list, tuple)): + return _build_rectilinear_grid_chunks(chunk_size, region) + if region is None: region = slice(0, size) @@ -153,9 +158,39 @@ def build_grid_chunks( return tuple(chunks_on_region) +def _build_rectilinear_grid_chunks( + chunk_sizes: tuple[int, ...], + region: slice | None = None, +) -> tuple[int, ...]: + """Build grid chunks for a rectilinear dimension within a region.""" + if region is None or region == slice(None): + return tuple(chunk_sizes) + + region_start = region.start or 0 + region_stop = region.stop or sum(chunk_sizes) + + boundaries = [0] + for cs in chunk_sizes: + boundaries.append(boundaries[-1] + cs) + + result = [] + for i in range(len(chunk_sizes)): + chunk_start = boundaries[i] + chunk_end = boundaries[i + 1] + + if chunk_end <= region_start or chunk_start >= region_stop: + continue + + effective_start = max(chunk_start, region_start) + effective_end = min(chunk_end, region_stop) + result.append(effective_end - effective_start) + + return tuple(result) + + def grid_rechunk( v: Variable, - enc_chunks: tuple[int, ...], + enc_chunks: tuple[int, ...] 
| tuple[int | tuple[int, ...], ...], region: tuple[slice, ...], ) -> Variable: nd_v_chunks = v.chunks @@ -181,9 +216,36 @@ def grid_rechunk( return v +def _validate_rectilinear_chunk_alignment( + dask_chunks: tuple[int, ...], + enc_chunks: tuple[int, ...], + axis: int, + name: str, + region: slice = slice(None), +) -> None: + """Validate dask chunks align with rectilinear encoding chunk boundaries.""" + enc_stops = set(itertools.accumulate(enc_chunks)) + region_start = region.start or 0 + dask_stops = {region_start + s for s in itertools.accumulate(dask_chunks)} + # The final stop (total size) always matches — exclude it + total = sum(enc_chunks) + enc_stops.discard(total) + dask_stops.discard(total) + bad = dask_stops - enc_stops + if bad: + raise ValueError( + f"Specified rectilinear encoding chunks {enc_chunks!r} for variable " + f"named {name!r} would overlap multiple Dask chunks on axis {axis}. " + f"Dask chunk boundaries at positions {sorted(bad)} do not align with " + f"encoding chunk boundaries at {sorted(enc_stops)}. " + "Writing this array in parallel with Dask could lead to corrupted data. " + "Consider rechunking using `chunk()` or setting `safe_chunks=False`." + ) + + def validate_grid_chunks_alignment( nd_v_chunks: tuple[tuple[int, ...], ...] 
| None, - enc_chunks: tuple[int, ...], + enc_chunks: tuple[int | tuple[int, ...], ...], backend_shape: tuple[int, ...], region: tuple[slice, ...], allow_partial_chunks: bool, @@ -213,6 +275,18 @@ def validate_grid_chunks_alignment( backend_shape, strict=True, ): + if isinstance(chunk_size, (list, tuple)): + # Rectilinear dimension — use boundary-based validation + _validate_rectilinear_chunk_alignment( + dask_chunks=v_chunks, + enc_chunks=chunk_size, + axis=axis, + name=name, + region=interval, + ) + continue + + # Regular dimension — existing validation logic for i, chunk in enumerate(v_chunks[1:-1]): if chunk % chunk_size: raise ValueError( diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index d9279dc2de9..ac18d9fbcbb 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -46,6 +46,10 @@ from xarray.core.types import ZarrArray, ZarrGroup +def _has_rectilinear_chunks() -> bool: + return module_available("zarr", minversion="3.2") + + def _get_mappers(*, storage_options, store, chunk_store): # expand str and path-like arguments store = _normalize_path(store) @@ -333,7 +337,7 @@ async def async_getitem(self, key): ) -def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): +def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, zarr_format): """ Given encoding chunks (possibly None or []) and variable chunks (possibly None or []). 
@@ -355,18 +359,34 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): # while dask chunks can be variable sized # https://dask.pydata.org/en/latest/array-design.html#chunks if var_chunks and not enc_chunks: + if zarr_format == 3 and _has_rectilinear_chunks(): + # Check if dask chunks are regular (uniform except for last chunk) + has_varying_interior = any( + len(set(chunks[:-1])) > 1 for chunks in var_chunks + ) + has_larger_final = any(chunks[0] < chunks[-1] for chunks in var_chunks) + if has_varying_interior or has_larger_final: + # Truly rectilinear — return dask-style tuples of per-chunk sizes. + # Requires zarr config: array.rectilinear_chunks = True + return tuple(var_chunks) + # Regular chunks — return the first chunk size per dimension + return tuple(chunk[0] for chunk in var_chunks) + if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): raise ValueError( - "Zarr requires uniform chunk sizes except for final chunk. " + "Zarr v2 requires uniform chunk sizes except for the final chunk. " f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. " - "Consider rechunking using `chunk()`." + "Consider rechunking using `chunk()`, or switching to the " + "zarr v3 format with zarr-python>=3.2." ) if any((chunks[0] < chunks[-1]) for chunks in var_chunks): raise ValueError( - "Final chunk of Zarr array must be the same size or smaller " - f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}." - "Consider either rechunking using `chunk()` or instead deleting " - "or modifying `encoding['chunks']`." + "The final chunk of a Zarr v2 array or a Zarr v3 array without the " + "rectilinear chunks extension must be the same size or smaller " + f"than the first. Variable named {name!r} has incompatible Dask " + f"chunks {var_chunks!r}. " + "Consider switching to Zarr v3 with the rectilinear chunks extension, " + "rechunking using `chunk()` or deleting or modifying `encoding['chunks']`." 
) # return the first chunk for each dimension return tuple(chunk[0] for chunk in var_chunks) @@ -389,8 +409,17 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): var_chunks, ndim, name, + zarr_format, ) + # Rectilinear chunks: each element is a sequence of per-chunk edge lengths + if ( + zarr_format == 3 + and _has_rectilinear_chunks() + and any(not isinstance(x, int) for x in enc_chunks_tuple) + ): + return enc_chunks_tuple + for x in enc_chunks_tuple: if not isinstance(x, int): raise TypeError( @@ -532,6 +561,7 @@ def extract_zarr_variable_encoding( var_chunks=variable.chunks, ndim=variable.ndim, name=name, + zarr_format=zarr_format, ) if _zarr_v3() and chunks is None: chunks = "auto" @@ -910,9 +940,16 @@ def open_store_variable(self, name): ) attributes = dict(attributes) + try: + chunks = tuple(zarr_array.chunks) + except NotImplementedError: + # Rectilinear chunk grid (zarr >= 3.2) — chunks vary along the axis + chunks = zarr_array.write_chunk_sizes + preferred_chunks = dict(zip(dimensions, chunks, strict=True)) + encoding = { - "chunks": zarr_array.chunks, - "preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)), + "chunks": chunks, + "preferred_chunks": preferred_chunks, } if _zarr_v3(): diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 70697cf68ce..6907d379984 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -134,6 +134,9 @@ def _importorskip( has_zarr_v3, requires_zarr_v3 = _importorskip("zarr", "3.0.0") has_zarr_v3_dtypes, requires_zarr_v3_dtypes = _importorskip("zarr", "3.1.0") has_zarr_v3_async_oindex, requires_zarr_v3_async_oindex = _importorskip("zarr", "3.1.2") +has_zarr_rectilinear_chunks, requires_zarr_rectilinear_chunks = _importorskip( + "zarr", "3.2.0" +) if has_zarr_v3: import zarr diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e42bfc2cd9f..318b339deb3 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ 
-79,6 +79,7 @@ has_numpy_2, has_scipy, has_zarr, + has_zarr_rectilinear_chunks, has_zarr_v3, has_zarr_v3_async_oindex, has_zarr_v3_dtypes, @@ -102,6 +103,7 @@ requires_scipy, requires_scipy_or_netCDF4, requires_zarr, + requires_zarr_rectilinear_chunks, requires_zarr_v3, ) from xarray.tests.test_coding_times import ( @@ -2973,9 +2975,18 @@ def test_chunk_encoding_with_dask(self) -> None: # should fail if dask_chunks are irregular... ds_chunk_irreg = ds.chunk({"x": (5, 4, 3)}) - with pytest.raises(ValueError, match=r"uniform chunk sizes."): - with self.roundtrip(ds_chunk_irreg) as actual: - pass + if ( + has_zarr_rectilinear_chunks + and zarr.config.config["default_zarr_format"] == 3 + ): + # zarr v3 with unified chunk grid supports rectilinear chunks + with zarr.config.set({"array.rectilinear_chunks": True}): + with self.roundtrip(ds_chunk_irreg) as actual: + pass + else: + with pytest.raises(ValueError, match=r"uniform chunk sizes."): + with self.roundtrip(ds_chunk_irreg) as actual: + pass # should fail if encoding["chunks"] clashes with dask_chunks badenc = ds.chunk({"x": 4}) @@ -7299,6 +7310,257 @@ def test_extract_zarr_variable_encoding() -> None: ) +@requires_zarr_rectilinear_chunks +@requires_dask +def test_rectilinear_chunks_encoding_roundtrip(tmp_path: Path) -> None: + """Rectilinear chunk sizes in encoding are passed through to zarr v3.""" + + import zarr + + chunk_sizes = [10, 20, 30] + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)}) + + store_path = tmp_path / "rectilinear.zarr" + encoding = {"var": {"chunks": [chunk_sizes]}} + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w", encoding=encoding) + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == tuple(chunk_sizes) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_rectilinear_chunks +@requires_dask +def 
test_rectilinear_chunks_no_encoding(tmp_path: Path) -> None: + """Variable dask chunks are written as rectilinear when no encoding is given.""" + import zarr + + chunk_sizes = [15, 25, 20] + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)}) + + store_path = tmp_path / "rectilinear_no_enc.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == tuple(chunk_sizes) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_rectilinear_chunks +@requires_dask +def test_rectilinear_chunks_multidim(tmp_path: Path) -> None: + """Rectilinear chunks on a multi-dimensional array.""" + import zarr + + data = np.arange(120, dtype="float64").reshape(6, 20) + ds = xr.Dataset({"var": xr.Variable(("x", "y"), data)}).chunk( + {"x": (2, 4), "y": (5, 10, 5)} + ) + + store_path = tmp_path / "rectilinear_2d.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == (2, 4) + assert roundtrip.chunks["y"] == (5, 10, 5) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_rectilinear_chunks +@requires_dask +def test_rectilinear_chunks_mixed_dims(tmp_path: Path) -> None: + """One dimension regular, another rectilinear.""" + import zarr + + data = np.arange(60, dtype="float32").reshape(3, 20) + ds = xr.Dataset({"var": xr.Variable(("x", "y"), data)}).chunk( + {"x": 3, "y": (5, 10, 5)} + ) + + store_path = tmp_path / "mixed_chunks.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == (3,) + assert roundtrip.chunks["y"] == (5, 
10, 5) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_rectilinear_chunks +@requires_dask +def test_rectilinear_chunks_interop(tmp_path: Path) -> None: + """Read rectilinear array created directly by zarr.""" + import zarr + + store_path = tmp_path / "zarr_native.zarr" + data = np.arange(60, dtype="float32") + + with zarr.config.set({"array.rectilinear_chunks": True}): + root = zarr.open_group(store_path, mode="w", zarr_format=3) + arr = root.create( + "var", + shape=(60,), + # zarr stubs don't include rectilinear chunk types yet + chunks=((10, 20, 30),), # type: ignore[arg-type] + dtype="float32", + dimension_names=("x",), + ) + arr[:] = data + + roundtrip = xr.open_zarr(store_path, zarr_format=3, consolidated=False) + assert roundtrip.chunks["x"] == (10, 20, 30) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_rectilinear_chunks +@requires_dask +def test_rectilinear_chunks_safe_chunks_fail(tmp_path: Path) -> None: + """Misaligned dask chunks should raise when safe_chunks=True.""" + import zarr + + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": (15, 15, 30)}) + + store_path = tmp_path / "safe_chunks_fail.zarr" + encoding = {"var": {"chunks": [(10, 20, 30)]}} + + with zarr.config.set({"array.rectilinear_chunks": True}): + with pytest.raises(ValueError, match=r"rectilinear.*overlap"): + ds.to_zarr( + store_path, + zarr_format=3, + mode="w", + encoding=encoding, + safe_chunks=True, + ) + + +@requires_zarr_rectilinear_chunks +@requires_dask +def test_rectilinear_chunks_region_write(tmp_path: Path) -> None: + """Write to a region of a rectilinear chunked array.""" + import zarr + + chunk_sizes = (10, 20, 30) + data = np.zeros(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": chunk_sizes}) + + store_path = tmp_path / "region.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, 
zarr_format=3, mode="w") + + # Overwrite just the second chunk (positions 10..30) + update = np.arange(20, dtype="float32") + 100 + ds_update = xr.Dataset({"var": xr.Variable("x", update)}).chunk({"x": (20,)}) + ds_update.to_zarr( + store_path, zarr_format=3, region={"x": slice(10, 30)}, mode="r+" + ) + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + expected = data.copy() + expected[10:30] = update + np.testing.assert_array_equal(roundtrip["var"].values, expected) + assert roundtrip.chunks["x"] == chunk_sizes + + +@requires_zarr_rectilinear_chunks +@requires_dask +def test_rectilinear_chunks_encoding_roundtrip_rewrite(tmp_path: Path) -> None: + """Read a rectilinear array and write it back preserving chunks.""" + import zarr + + chunk_sizes = (10, 20, 30) + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": chunk_sizes}) + + path1 = tmp_path / "source.zarr" + path2 = tmp_path / "dest.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(path1, zarr_format=3, mode="w") + + loaded = xr.open_zarr(path1, zarr_format=3) + loaded.to_zarr(path2, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(path2, zarr_format=3) + assert roundtrip.chunks["x"] == chunk_sizes + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +def test_validate_grid_chunks_alignment_rectilinear_pass() -> None: + """Dask chunks that align with rectilinear zarr boundaries should pass.""" + from xarray.backends.chunks import validate_grid_chunks_alignment + + validate_grid_chunks_alignment( + nd_v_chunks=((10, 20, 30),), + enc_chunks=((10, 20, 30),), + region=(slice(None),), + allow_partial_chunks=True, + name="var", + backend_shape=(60,), + ) + + # Dask chunks are coarser (merging zarr chunks is fine) + validate_grid_chunks_alignment( + nd_v_chunks=((30, 30),), + enc_chunks=((10, 20, 30),), + region=(slice(None),), + allow_partial_chunks=True, + name="var", + backend_shape=(60,), + ) + + +def 
test_validate_grid_chunks_alignment_rectilinear_fail() -> None: + """Dask chunks that split a rectilinear zarr chunk should raise.""" + from xarray.backends.chunks import validate_grid_chunks_alignment + + with pytest.raises(ValueError, match=r"rectilinear.*overlap"): + validate_grid_chunks_alignment( + nd_v_chunks=((15, 15, 30),), + enc_chunks=((10, 20, 30),), + region=(slice(None),), + allow_partial_chunks=True, + name="var", + backend_shape=(60,), + ) + + +def test_build_grid_chunks_rectilinear_full() -> None: + """build_grid_chunks with rectilinear spec and no region returns the spec.""" + from xarray.backends.chunks import build_grid_chunks + + result = build_grid_chunks(size=60, chunk_size=(10, 20, 30)) + assert result == (10, 20, 30) + + +def test_build_grid_chunks_rectilinear_region() -> None: + """build_grid_chunks with rectilinear spec and a region clips to region.""" + from xarray.backends.chunks import build_grid_chunks + + result = build_grid_chunks(size=45, chunk_size=(10, 20, 30), region=slice(15, 60)) + assert result == (15, 30) + + +def test_build_grid_chunks_rectilinear_region_mid() -> None: + """Region that starts and ends mid-chunk.""" + from xarray.backends.chunks import build_grid_chunks + + result = build_grid_chunks(size=40, chunk_size=(10, 20, 30), region=slice(5, 45)) + assert result == (5, 20, 15) + + @requires_zarr @requires_fsspec @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")