diff --git a/sbi/inference/trainers/npse/vector_field_inference.py b/sbi/inference/trainers/npse/vector_field_inference.py index 64cad5a02..fe0156802 100644 --- a/sbi/inference/trainers/npse/vector_field_inference.py +++ b/sbi/inference/trainers/npse/vector_field_inference.py @@ -585,6 +585,10 @@ def _loss( Returns: Calibration kernel-weighted loss implemented by the vector field estimator. """ + + if times is not None: + times = times.to(self._device) + cls_name = self.__class__.__name__ if self._round == 0 or force_first_round_loss: # First round loss. diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index f40b5d3b9..cf969f20c 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -21,7 +21,6 @@ NPE, NPE_A, NPE_C, - NPSE, NRE_A, NRE_B, NRE_C, @@ -38,7 +37,6 @@ from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential from sbi.inference.potentials.ratio_based_potential import RatioBasedPotential -from sbi.inference.trainers.npse.vector_field_inference import VectorFieldInference from sbi.neural_nets.embedding_nets import FCEmbedding from sbi.neural_nets.factory import ( classifier_nn, @@ -625,7 +623,12 @@ def test_direct_posterior_on_gpu(device: str, device_inference: str): ], ) def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotential]): - """Test to method on potential""" + """Test .to() method on potential. + + Args: + device: device where to move the model. + potential: potential to train the model on. + """ device = process_device(device) prior = BoxUniform(torch.tensor([1.0]), torch.tensor([1.0])) @@ -665,7 +668,12 @@ def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotentia "sampling_method", ["rejection", "importance", "mcmc", "direct"] ) def test_to_method_on_posteriors(device: str, sampling_method: str): - """Test that the .to() method works on posteriors.""" + """Test .to() method on posteriors. + + Args: + device: device to train and sample the model on. + sampling_method: method to sample from the posterior. + """ device = process_device(device) prior = BoxUniform(torch.zeros(3), torch.ones(3)) inference = NPE() @@ -708,25 +716,44 @@ def test_to_method_on_posteriors(device: str, sampling_method: str): @pytest.mark.gpu @pytest.mark.parametrize("device", ["cpu", "gpu"]) -@pytest.mark.parametrize("iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss"]) -@pytest.mark.parametrize("inference_method", [FMPE, NPSE]) -def test_VectorFieldPosterior( - device: str, iid_method: str, inference_method: VectorFieldInference +@pytest.mark.parametrize("device_inference", ["cpu", "gpu"]) +@pytest.mark.parametrize( + "iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None] +) +def test_VectorFieldPosterior_device_handling( + device: str, device_inference: str, iid_method: str ): + """Test VectorFieldPosterior on different devices training and inference devices. + + Args: + device: device to train the model on. + device_inference: device to run the inference on. + iid_method: method to sample from the posterior. + """ device = process_device(device) - prior = BoxUniform(torch.zeros(3), torch.ones(3), device="cpu") - inference = inference_method(score_estimator="mlp", prior=prior) + device_inference = process_device(device_inference) + prior = BoxUniform(torch.zeros(3), torch.ones(3), device=device) + inference = FMPE(score_estimator="mlp", prior=prior, device=device) density_estimator = inference.append_simulations( torch.randn((100, 3)), torch.randn((100, 2)) ).train(max_num_epochs=1) posterior = inference.build_posterior(density_estimator, prior) - posterior.to(device) - assert posterior.device == device, ( - f"VectorFieldPosterior is not in device {device}." + + # faster but inaccurate log_prob computation + posterior.potential_fn.neural_ode.update_params(exact=False, atol=1e-4, rtol=1e-4) + + posterior.to(device_inference) + assert posterior.device == device_inference, ( + f"VectorFieldPosterior is not in device {device_inference}." ) - x_o = torch.ones(2).to(device) + x_o = torch.ones(2).to(device_inference) samples = posterior.sample((2,), x=x_o, iid_method=iid_method) - assert samples.device.type == device.split(":")[0], ( - f"Samples are not on device {device}." + assert samples.device.type == device_inference.split(":")[0], ( + f"Samples are not on device {device_inference}." + ) + + log_probs = posterior.log_prob(samples, x=x_o) + assert log_probs.device.type == device_inference.split(":")[0], ( + f"log_prob was not correctly moved to {device_inference}." )