Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: Tests

on:
- push
- pull_request

jobs:
pytest:
Expand Down
19 changes: 16 additions & 3 deletions sacred/config/custom_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,19 @@ def type_changed(old_value, new_value):
def is_different(old_value, new_value):
"""Numpy aware comparison between two values."""
if opt.has_numpy:
return not opt.np.array_equal(old_value, new_value)
else:
return old_value != new_value
# Reproduces np.array_equal from numpy<2
# np.array_equal raises an exception when the arguments are scalar and
# differ in type (e.g. int and str) in numpy>=2.0
try:
old_value = opt.np.asarray(old_value)
new_value = opt.np.asarray(new_value)
except:
return False
else:
result = old_value == new_value
if isinstance(result, bool):
return result
else:
return result.all()

return old_value != new_value
2 changes: 0 additions & 2 deletions tests/test_config/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"uint16",
"uint32",
"uint64",
"float_",
"float16",
"float32",
"float64",
Expand All @@ -49,7 +48,6 @@ def test_normalize_or_die_for_numpy_datatypes(typename):
"uint16",
"uint32",
"uint64",
"float_",
"float16",
"float32",
"float64",
Expand Down
16 changes: 8 additions & 8 deletions tests/test_observers/test_s3_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_file_data(bucket_name, key):
return s3.Object(bucket_name, key).get()["Body"].read()


@moto.mock_s3
@moto.mock_aws
def test_fs_observer_started_event_creates_bucket(observer, sample_run):
_id = observer.started_event(**sample_run)
run_dir = s3_join(BASEDIR, str(_id))
Expand All @@ -102,7 +102,7 @@ def test_fs_observer_started_event_creates_bucket(observer, sample_run):
}


@moto.mock_s3
@moto.mock_aws
def test_fs_observer_started_event_increments_run_id(observer, sample_run):
_id = observer.started_event(**sample_run)
_id2 = observer.started_event(**sample_run)
Expand All @@ -119,15 +119,15 @@ def test_s3_observer_equality():
assert obs_one != different_bucket


@moto.mock_s3
@moto.mock_aws
def test_raises_error_on_duplicate_id_directory(observer, sample_run):
observer.started_event(**sample_run)
sample_run["_id"] = 1
with pytest.raises(FileExistsError):
observer.started_event(**sample_run)


@moto.mock_s3
@moto.mock_aws
def test_completed_event_updates_run_json(observer, sample_run):
observer.started_event(**sample_run)
run = json.loads(
Expand All @@ -145,7 +145,7 @@ def test_completed_event_updates_run_json(observer, sample_run):
assert run["status"] == "COMPLETED"


@moto.mock_s3
@moto.mock_aws
def test_interrupted_event_updates_run_json(observer, sample_run):
observer.started_event(**sample_run)
run = json.loads(
Expand All @@ -163,7 +163,7 @@ def test_interrupted_event_updates_run_json(observer, sample_run):
assert run["status"] == "SERVER_EXPLODED"


@moto.mock_s3
@moto.mock_aws
def test_failed_event_updates_run_json(observer, sample_run):
observer.started_event(**sample_run)
run = json.loads(
Expand All @@ -181,7 +181,7 @@ def test_failed_event_updates_run_json(observer, sample_run):
assert run["status"] == "FAILED"


@moto.mock_s3
@moto.mock_aws
def test_queued_event_updates_run_json(observer, sample_run):
del sample_run["start_time"]
sample_run["queue_time"] = T2
Expand All @@ -194,7 +194,7 @@ def test_queued_event_updates_run_json(observer, sample_run):
assert run["status"] == "QUEUED"


@moto.mock_s3
@moto.mock_aws
def test_artifact_event_works(observer, sample_run, tmpfile):
observer.started_event(**sample_run)
observer.artifact_event("test_artifact.py", tmpfile.name)
Expand Down
52 changes: 9 additions & 43 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# and then run "tox" from this directory.

[tox]
envlist = py{38,39,310,311}, setup, numpy-{120,121,123}, tensorflow-{26,27,28,29,210,211}
envlist = py{38,39,310,311}, setup, numpy-{120,121,123,200}, tensorflow-{212,216}

[testenv]
deps =
Expand Down Expand Up @@ -53,68 +53,34 @@ deps =
commands =
pytest tests/test_config {posargs}

[testenv:tensorflow-115]
[testenv:numpy-200]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=1.15.0
numpy~=2.0.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-26]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.6.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-27]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.7.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}
pytest tests/test_config {posargs}

[testenv:tensorflow-28]
[testenv:tensorflow-212]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.8.0
numpy<2.0.0
tensorflow~=2.12.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-29]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.9.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-210]
[testenv:tensorflow-216]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.10.0
tensorflow~=2.16.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:tensorflow-211]
basepython = python
deps =
-rdev-requirements.txt
tensorflow~=2.11.0
commands =
pytest tests/test_stflow tests/test_optional.py \
{posargs}

[testenv:setup]
basepython = python
Expand Down