Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def _check_writable(self) -> None:
if not self.writeable:
raise ValueError("store mode does not support writing")

@abstractmethod
def __eq__(self, value: object) -> bool:
    """Return whether ``value`` is equal to this store.

    Declared abstract, so subclasses must provide their own comparison.
    """

@abstractmethod
async def get(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/zarr/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ def __add__(self, other: Buffer) -> Self:
np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array)))
)

def __eq__(self, other: object) -> bool:
    """Compare two buffers by their byte content.

    Returns True only when ``other`` is an instance of this buffer's type
    and both objects produce equal ``to_bytes()`` output.
    """
    # NOTE: needed so that MemoryStore instances holding Buffer values can be
    # compared. If/when buffers are no longer put in memory stores, this
    # method can be removed.
    if not isinstance(other, type(self)):
        return False
    return self.to_bytes() == other.to_bytes()
Comment thread
jhamman marked this conversation as resolved.
Outdated


class NDBuffer:
"""An n-dimensional memory block
Expand Down
13 changes: 13 additions & 0 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return ``MemoryStore(...)`` wrapping the repr of ``str(self)``."""
    description = str(self)
    return f"MemoryStore({description!r})"

def __eq__(self, other: object) -> bool:
    """Compare two memory stores.

    Stores are equal when ``other`` is an instance of this store's type and
    both the backing ``_store_dict`` and the ``mode`` compare equal.
    """
    if not isinstance(other, type(self)):
        return False
    if not (self._store_dict == other._store_dict):
        return False
    return self.mode == other.mode

def __setstate__(self, state: tuple[MutableMapping[str, Buffer], OpenMode]) -> None:
    """Refuse to restore pickled state.

    Raises
    ------
    NotImplementedError
        Always; ``state`` is never inspected.
    """
    message = f"{type(self)} cannot be pickled"
    raise NotImplementedError(message)

def __getstate__(self) -> tuple[MutableMapping[str, Buffer], OpenMode]:
    """Refuse to produce picklable state.

    Raises
    ------
    NotImplementedError
        Always; no state is ever returned.
    """
    message = f"{type(self)} cannot be pickled"
    raise NotImplementedError(message)

Comment thread
jhamman marked this conversation as resolved.
async def get(
self,
key: str,
Expand Down
10 changes: 10 additions & 0 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
"""

super().__init__(mode=mode)
self._storage_options = storage_options
if isinstance(url, str):
self._url = url.rstrip("/")
self._fs, _path = fsspec.url_to_fs(url, **storage_options)
Expand Down Expand Up @@ -81,6 +82,15 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return ``<RemoteStore(FsTypeName, path)>`` for this store."""
    fs_type_name = type(self._fs).__name__
    return f"<RemoteStore({fs_type_name}, {self.path})>"

def __eq__(self, other: object) -> bool:
    """Compare two remote stores.

    Stores are equal when ``other`` is an instance of this store's type and
    ``path``, ``mode`` and ``_url`` all compare equal. Storage options are
    not part of the comparison (see the FIXME below).
    """
    if not isinstance(other, type(self)):
        return False
    if not (self.path == other.path):
        return False
    if not (self.mode == other.mode):
        return False
    # FIXME: `self._storage_options == other._storage_options` should also be
    # checked here, but that comparison isn't working for some reason.
    return self._url == other._url

async def get(
self,
key: str,
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from typing import Any, Generic, TypeVar

import pytest
Expand Down Expand Up @@ -42,6 +43,19 @@ def test_store_type(self, store: S) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None:
    """Check that a store equals itself and a store built from the same kwargs."""
    # A store must compare equal to itself.
    assert store == store

    # A second store constructed from the same inputs must compare equal too.
    # Asserting this is important for being able to compare (de)serialized
    # stores.
    rebuilt = self.store_cls(**store_kwargs)
    assert store == rebuilt

def test_serizalizable_store(self, store: S) -> None:
    """Check that a store compares equal to its pickle round-trip."""
    payload = pickle.dumps(store)
    restored = pickle.loads(payload)
    assert restored == store

def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None:
assert store.mode == "w", store.mode
assert store.writeable
Expand Down
38 changes: 37 additions & 1 deletion tests/v3/test_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pickle

import numpy as np
import pytest

from zarr.array import Array
from zarr.array import Array, AsyncArray
from zarr.common import ZarrFormat
from zarr.group import Group
from zarr.store import LocalStore, MemoryStore
Expand Down Expand Up @@ -34,3 +37,36 @@ def test_array_name_properties_with_group(
assert spam.path == "bar/spam"
assert spam.name == "/bar/spam"
assert spam.basename == "spam"


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serizalizable_async_array(
    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
    """Check that an ``AsyncArray`` compares equal to its pickle round-trip."""
    expected = await AsyncArray.create(
        store=store,
        shape=(100,),
        chunks=(10,),
        dtype="i4",
        zarr_format=zarr_format,
    )
    # await expected.setitems(list(range(100)))

    # Serialize and immediately deserialize the array.
    roundtripped = pickle.loads(pickle.dumps(expected))

    assert roundtripped == expected
    # np.testing.assert_array_equal(
    #     await roundtripped.getitem(slice(None)), await expected.getitem(slice(None))
    # )
    # TODO: uncomment the parts of this test that will be impacted by the
    # config/prototype changes in flight


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serizalizable_sync_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """Check that an ``Array`` and its data survive a pickle round-trip."""
    expected = Array.create(
        store=store,
        shape=(100,),
        chunks=(10,),
        dtype="i4",
        zarr_format=zarr_format,
    )
    expected[:] = list(range(100))

    # Serialize and immediately deserialize the array.
    roundtripped = pickle.loads(pickle.dumps(expected))

    # Both the array object and the values read back from it must match.
    assert roundtripped == expected
    np.testing.assert_array_equal(roundtripped[:], expected[:])
24 changes: 24 additions & 0 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pickle
from typing import TYPE_CHECKING, Any

from zarr.array import AsyncArray
Expand Down Expand Up @@ -391,3 +392,26 @@ def test_group_name_properties(store: LocalStore | MemoryStore, zarr_format: Zar
assert bar.path == "foo/bar"
assert bar.name == "/foo/bar"
assert bar.basename == "bar"


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serizalizable_async_group(
    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
    """Check that an ``AsyncGroup`` compares equal to its pickle round-trip."""
    expected = await AsyncGroup.create(
        store=store,
        zarr_format=zarr_format,
        attributes={"foo": 999},
    )

    # Serialize and immediately deserialize the group.
    roundtripped = pickle.loads(pickle.dumps(expected))

    assert roundtripped == expected


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serizalizable_sync_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """Check that a ``Group`` compares equal to its pickle round-trip."""
    expected = Group.create(
        store=store,
        zarr_format=zarr_format,
        attributes={"foo": 999},
    )

    # Serialize and immediately deserialize the group.
    roundtripped = pickle.loads(pickle.dumps(expected))

    assert roundtripped == expected
12 changes: 12 additions & 0 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import pytest

from zarr.buffer import Buffer
Expand Down Expand Up @@ -38,3 +40,13 @@ def test_store_supports_partial_writes(self, store: MemoryStore) -> None:

def test_list_prefix(self, store: MemoryStore) -> None:
    """Placeholder that always passes; ``store`` is not exercised."""
    # NOTE(review): presumably overrides an inherited test for MemoryStore;
    # confirm against the base test class.
    assert True

def test_serizalizable_store(self, store: MemoryStore) -> None:
    """Check that every pickling entry point raises ``NotImplementedError``."""
    # Each callable is invoked inside its own ``pytest.raises`` context, in
    # the same order as before: direct __getstate__, direct __setstate__,
    # then a full pickle.dumps.
    pickling_attempts = (
        lambda: store.__getstate__(),
        lambda: store.__setstate__({}),
        lambda: pickle.dumps(store),
    )
    for attempt in pickling_attempts:
        with pytest.raises(NotImplementedError):
            attempt()
2 changes: 1 addition & 1 deletion tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
anon = False
mode = "w"
if request.param == "use_upath":
return {"mode": mode, "url": UPath(url, endpoint_url=endpoint_url, anon=anon)}
return {"url": UPath(url, endpoint_url=endpoint_url, anon=anon), "mode": mode}
elif request.param == "use_str":
return {"url": url, "mode": mode, "anon": anon, "endpoint_url": endpoint_url}

Expand Down