From 1ddb28050fbbb360a2c5f991bf0cae097d2e72e5 Mon Sep 17 00:00:00 2001 From: Jose Robledo Date: Thu, 27 Mar 2025 19:55:23 +0100 Subject: [PATCH 1/4] improve test for VectorFieldPosterior and add docstrings to GPU tests --- tests/inference_on_device_test.py | 60 ++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index f40b5d3b9..bcb6b6ec2 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -21,7 +21,6 @@ NPE, NPE_A, NPE_C, - NPSE, NRE_A, NRE_B, NRE_C, @@ -38,7 +37,6 @@ from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential from sbi.inference.potentials.ratio_based_potential import RatioBasedPotential -from sbi.inference.trainers.npse.vector_field_inference import VectorFieldInference from sbi.neural_nets.embedding_nets import FCEmbedding from sbi.neural_nets.factory import ( classifier_nn, @@ -625,7 +623,12 @@ def test_direct_posterior_on_gpu(device: str, device_inference: str): ], ) def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotential]): - """Test to method on potential""" + """Test to method on potential. + + Args: + device: device where to move the model. + potential: potential to train the model on. + """ device = process_device(device) prior = BoxUniform(torch.tensor([1.0]), torch.tensor([1.0])) @@ -665,7 +668,12 @@ def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotentia "sampling_method", ["rejection", "importance", "mcmc", "direct"] ) def test_to_method_on_posteriors(device: str, sampling_method: str): - """Test that the .to() method works on posteriors.""" + """Test that the .to() method works on posteriors. + + Args: + device: device to train and sample the model on. + sampling_method: method to sample from the posterior. + """ device = process_device(device) prior = BoxUniform(torch.zeros(3), torch.ones(3)) inference = NPE() @@ -708,25 +716,43 @@ def test_to_method_on_posteriors(device: str, sampling_method: str): @pytest.mark.gpu @pytest.mark.parametrize("device", ["cpu", "gpu"]) -@pytest.mark.parametrize("iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss"]) -@pytest.mark.parametrize("inference_method", [FMPE, NPSE]) -def test_VectorFieldPosterior( - device: str, iid_method: str, inference_method: VectorFieldInference -): +@pytest.mark.parametrize("device_inference", ["cpu", "gpu"]) +@pytest.mark.parametrize( + "iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None] +) +def test_VectorFieldPosterior(device: str, device_inference: str, iid_method: str): + """Test that VectorFieldPosterior works on different devices, + both for training as well as on inference. + + Args: + device: device to train the model on. + device_inference: device to run the inference on. + iid_method: method to sample from the posterior. + """ device = process_device(device) - prior = BoxUniform(torch.zeros(3), torch.ones(3), device="cpu") - inference = inference_method(score_estimator="mlp", prior=prior) + device_inference = process_device(device_inference) + prior = BoxUniform(torch.zeros(3), torch.ones(3), device=device) + inference = FMPE(score_estimator="mlp", prior=prior, device=device) density_estimator = inference.append_simulations( torch.randn((100, 3)), torch.randn((100, 2)) ).train(max_num_epochs=1) posterior = inference.build_posterior(density_estimator, prior) - posterior.to(device) - assert posterior.device == device, ( - f"VectorFieldPosterior is not in device {device}." + + # faster but inaccurate log_prob computation + posterior.potential_fn.neural_ode.update_params(exact=False, atol=1e-4, rtol=1e-4) + + posterior.to(device_inference) + assert posterior.device == device_inference, ( + f"VectorFieldPosterior is not in device {device_inference}." ) - x_o = torch.ones(2).to(device) + x_o = torch.ones(2).to(device_inference) samples = posterior.sample((2,), x=x_o, iid_method=iid_method) - assert samples.device.type == device.split(":")[0], ( - f"Samples are not on device {device}." + assert samples.device.type == device_inference.split(":")[0], ( + f"Samples are not on device {device_inference}." + ) + + log_probs = posterior.log_prob(samples, x=x_o) + assert log_probs.device.type == device_inference.split(":")[0], ( + f"log_prob was not correctly moved to {device_inference}." ) From 225a6dfece51a1f1089ae0f2a321dfaf443698e6 Mon Sep 17 00:00:00 2001 From: Jose Robledo Date: Thu, 27 Mar 2025 20:07:02 +0100 Subject: [PATCH 2/4] fix _loss setting all tensors to the correct device in VectorFieldInference --- sbi/inference/trainers/npse/vector_field_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sbi/inference/trainers/npse/vector_field_inference.py b/sbi/inference/trainers/npse/vector_field_inference.py index 64cad5a02..fe0156802 100644 --- a/sbi/inference/trainers/npse/vector_field_inference.py +++ b/sbi/inference/trainers/npse/vector_field_inference.py @@ -585,6 +585,10 @@ def _loss( Returns: Calibration kernel-weighted loss implemented by the vector field estimator. """ + + if times is not None: + times = times.to(self._device) + cls_name = self.__class__.__name__ if self._round == 0 or force_first_round_loss: # First round loss. From b022f0217b378764d6099da0af0462ea8c3756cf Mon Sep 17 00:00:00 2001 From: Jose Robledo <46170369+jorobledo@users.noreply.github.com> Date: Fri, 28 Mar 2025 11:12:20 +0100 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Jan --- tests/inference_on_device_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index bcb6b6ec2..4455bb8b4 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -623,7 +623,7 @@ def test_direct_posterior_on_gpu(device: str, device_inference: str): ], ) def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotential]): - """Test to method on potential. + """Test .to() method on potential. Args: device: device where to move the model. @@ -668,7 +668,7 @@ def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotentia "sampling_method", ["rejection", "importance", "mcmc", "direct"] ) def test_to_method_on_posteriors(device: str, sampling_method: str): - """Test that the .to() method works on posteriors. + """Test .to() method on posteriors. Args: device: device to train and sample the model on. @@ -720,9 +720,8 @@ def test_to_method_on_posteriors(device: str, sampling_method: str): @pytest.mark.parametrize( "iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None] ) -def test_VectorFieldPosterior(device: str, device_inference: str, iid_method: str): - """Test that VectorFieldPosterior works on different devices, - both for training as well as on inference. +def test_VectorFieldPosterior_device_handling(device: str, device_inference: str, iid_method: str): + """Test VectorFieldPosterior on different devices training and inference devices. Args: device: device to train the model on. From d46cfbe0e9c77c5d90170bc1f644eabec20fea43 Mon Sep 17 00:00:00 2001 From: Jose Robledo Date: Fri, 28 Mar 2025 11:15:22 +0100 Subject: [PATCH 4/4] fix linting --- tests/inference_on_device_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index 4455bb8b4..cf969f20c 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -720,7 +720,9 @@ def test_to_method_on_posteriors(device: str, sampling_method: str): @pytest.mark.parametrize( "iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None] ) -def test_VectorFieldPosterior_device_handling(device: str, device_inference: str, iid_method: str): +def test_VectorFieldPosterior_device_handling( + device: str, device_inference: str, iid_method: str +): """Test VectorFieldPosterior on different devices training and inference devices. Args: