File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed
Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change 2222
2323PRECISION_VALUE = 0.0005
2424
25- # TODO add channel issue
26- # TODO test weighting function
27- # TODO test if numpy weights and torch weights are the same!
25+ # TODO test channel issue
2826
2927@pytest .fixture
3028def rand_targets ():
@@ -42,6 +40,15 @@ def ones_targets():
4240def ones_predics ():
4341 return torch .ones (size = (batch_size , out_time , lat , lon )) + 0.1
4442
43+ def test_weights (ones_targets ):
44+ lat_size = int (ones_targets .shape [- 2 ])
45+ parent_loss = ClimateSetLoss ()
46+ torch_weights = parent_loss .get_latitude_weights (lat_size )
47+ assert torch_weights [0 ] == pytest .approx (0.0044 , abs = 0.0001 )
48+ assert torch_weights [- 1 ] == pytest .approx (0.0044 , abs = 0.0001 )
49+ assert torch_weights [0 ] == torch_weights [- 1 ]
50+ assert torch_weights [int (lat_size / 2 )] == pytest .approx (1 , abs = 0.001 )
51+
4552def expected_loss (loss_obj , expected_loss_value , precision_threshold ):
4653 assert loss_obj .item () == pytest .approx (expected_loss_value , abs = precision_threshold )
4754 assert loss_obj .shape == ()
You can’t perform that action at this time.
0 commit comments