diff --git a/adaptive/learner/average_learner.py b/adaptive/learner/average_learner.py index ba7deb1f1..b571c19d3 100644 --- a/adaptive/learner/average_learner.py +++ b/adaptive/learner/average_learner.py @@ -118,9 +118,11 @@ def loss(self, real=True, *, n=None): if n < self.min_npoints: return np.inf standard_error = self.std / sqrt(n) - return max( - standard_error / self.atol, standard_error / abs(self.mean) / self.rtol - ) + aloss = standard_error / self.atol + rloss = standard_error / self.rtol + if self.mean != 0: + rloss /= abs(self.mean) + return max(aloss, rloss) def _loss_improvement(self, n): loss = self.loss() diff --git a/adaptive/tests/test_average_learner.py b/adaptive/tests/test_average_learner.py index 42d4726d0..f35794a39 100644 --- a/adaptive/tests/test_average_learner.py +++ b/adaptive/tests/test_average_learner.py @@ -59,3 +59,11 @@ def constant_function(seed): ) simple(learner, lambda l: l.loss() < 1) assert learner.npoints >= max(2, min_npoints) + + +def test_zero_mean(): + # see https://github.com/python-adaptive/adaptive/issues/275 + learner = AverageLearner(None, rtol=0.01) + learner.tell(0, -1) + learner.tell(1, 1) + learner.loss()