Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sbi/inference/trainers/npse/vector_field_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,10 @@ def _loss(
Returns:
Calibration kernel-weighted loss implemented by the vector field estimator.
"""

if times is not None:
times = times.to(self._device)

cls_name = self.__class__.__name__
if self._round == 0 or force_first_round_loss:
# First round loss.
Expand Down
59 changes: 43 additions & 16 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
NPE,
NPE_A,
NPE_C,
NPSE,
NRE_A,
NRE_B,
NRE_C,
Expand All @@ -38,7 +37,6 @@
from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
from sbi.inference.potentials.ratio_based_potential import RatioBasedPotential
from sbi.inference.trainers.npse.vector_field_inference import VectorFieldInference
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.neural_nets.factory import (
classifier_nn,
Expand Down Expand Up @@ -625,7 +623,12 @@ def test_direct_posterior_on_gpu(device: str, device_inference: str):
],
)
def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotential]):
"""Test to method on potential"""
"""Test .to() method on potential.

Args:
device: device where to move the model.
potential: potential to train the model on.
"""

device = process_device(device)
prior = BoxUniform(torch.tensor([1.0]), torch.tensor([1.0]))
Expand Down Expand Up @@ -665,7 +668,12 @@ def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotentia
"sampling_method", ["rejection", "importance", "mcmc", "direct"]
)
def test_to_method_on_posteriors(device: str, sampling_method: str):
"""Test that the .to() method works on posteriors."""
"""Test .to() method on posteriors.

Args:
device: device to train and sample the model on.
sampling_method: method to sample from the posterior.
"""
device = process_device(device)
prior = BoxUniform(torch.zeros(3), torch.ones(3))
inference = NPE()
Expand Down Expand Up @@ -708,25 +716,44 @@ def test_to_method_on_posteriors(device: str, sampling_method: str):

@pytest.mark.gpu
@pytest.mark.parametrize("device", ["cpu", "gpu"])
@pytest.mark.parametrize("iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss"])
@pytest.mark.parametrize("inference_method", [FMPE, NPSE])
def test_VectorFieldPosterior(
device: str, iid_method: str, inference_method: VectorFieldInference
@pytest.mark.parametrize("device_inference", ["cpu", "gpu"])
@pytest.mark.parametrize(
"iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None]
)
def test_VectorFieldPosterior_device_handling(
device: str, device_inference: str, iid_method: str
):
"""Test VectorFieldPosterior on different devices training and inference devices.

Args:
device: device to train the model on.
device_inference: device to run the inference on.
iid_method: method to sample from the posterior.
"""
device = process_device(device)
prior = BoxUniform(torch.zeros(3), torch.ones(3), device="cpu")
inference = inference_method(score_estimator="mlp", prior=prior)
device_inference = process_device(device_inference)
prior = BoxUniform(torch.zeros(3), torch.ones(3), device=device)
inference = FMPE(score_estimator="mlp", prior=prior, device=device)
density_estimator = inference.append_simulations(
torch.randn((100, 3)), torch.randn((100, 2))
).train(max_num_epochs=1)
posterior = inference.build_posterior(density_estimator, prior)
posterior.to(device)
assert posterior.device == device, (
f"VectorFieldPosterior is not in device {device}."

# faster but inaccurate log_prob computation
posterior.potential_fn.neural_ode.update_params(exact=False, atol=1e-4, rtol=1e-4)

posterior.to(device_inference)
assert posterior.device == device_inference, (
f"VectorFieldPosterior is not in device {device_inference}."
)

x_o = torch.ones(2).to(device)
x_o = torch.ones(2).to(device_inference)
samples = posterior.sample((2,), x=x_o, iid_method=iid_method)
assert samples.device.type == device.split(":")[0], (
f"Samples are not on device {device}."
assert samples.device.type == device_inference.split(":")[0], (
f"Samples are not on device {device_inference}."
)

log_probs = posterior.log_prob(samples, x=x_o)
assert log_probs.device.type == device_inference.split(":")[0], (
f"log_prob was not correctly moved to {device_inference}."
)