Skip to content

Commit 5bcad8f

Browse files
committed
add tests for weighting function
1 parent 119d0af commit 5bcad8f

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tests/test_scoring/test_losses.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222

2323
PRECISION_VALUE = 0.0005
2424

25-
# TODO add channel issue
26-
# TODO test weighting function
27-
# TODO test if numpy weights and torch weights are the same!
25+
# TODO test channel issue
2826

2927
@pytest.fixture
3028
def rand_targets():
@@ -42,6 +40,15 @@ def ones_targets():
4240
def ones_predics():
4341
return torch.ones(size=(batch_size, out_time, lat, lon)) + 0.1
4442

43+
def test_weights(ones_targets):
44+
lat_size = int(ones_targets.shape[-2])
45+
parent_loss = ClimateSetLoss()
46+
torch_weights = parent_loss.get_latitude_weights(lat_size)
47+
assert torch_weights[0] == pytest.approx(0.0044, abs=0.0001)
48+
assert torch_weights[-1] == pytest.approx(0.0044, abs=0.0001)
49+
assert torch_weights[0] == torch_weights[-1]
50+
assert torch_weights[int(lat_size/2)] == pytest.approx(1, abs=0.001)
51+
4552
def expected_loss(loss_obj, expected_loss_value, precision_threshold):
4653
assert loss_obj.item() == pytest.approx(expected_loss_value, abs=precision_threshold)
4754
assert loss_obj.shape == ()

0 commit comments

Comments
 (0)