posterior.to(device) #1527

Merged
janfb merged 45 commits into sbi-dev:main from jorobledo:posterior_device
Mar 27, 2025

Conversation

@jorobledo
Collaborator

This PR is related to issue #1368. It adds a `to` method to all potentials and posteriors, along with the corresponding tests.

They were tested on CPU and GPU on our cluster, and on MPS on Macs. All tests pass on CUDA devices.

Three tests fail on MPS because MultipleIndependent uses torch distributions such as Binomial and Gamma, which currently raise a NotImplementedError on the MPS device. I expect these to pass with newer torch versions.

The posterior.to(device) method is in-place.
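To illustrate what in-place means here, a minimal mock (the class below is a hypothetical stand-in, not the actual sbi implementation):

```python
class MockPosterior:
    """Toy stand-in for a posterior, illustrating in-place .to(device)."""

    def __init__(self, device: str = "cpu"):
        self.device = device

    def to(self, device: str) -> None:
        # Mutates the object instead of returning a moved copy,
        # so no reassignment is needed at the call site.
        self.device = device


posterior = MockPosterior()
posterior.to("cuda:0")  # in place: `posterior` itself is now on cuda:0
```

Because the method returns None, writing `posterior = posterior.to(device)` (the torch.Tensor idiom) would discard the object.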

@codecov

codecov bot commented Mar 21, 2025

Codecov Report

Attention: Patch coverage is 35.32609% with 119 lines in your changes missing coverage. Please review.

Project coverage is 78.94%. Comparing base (8900ca0) to head (f2642b5).
Report is 8 commits behind head on main.

Files with missing lines Patch % Lines
sbi/inference/posteriors/ensemble_posterior.py 37.50% 15 Missing ⚠️
sbi/inference/posteriors/vector_field_posterior.py 21.05% 15 Missing ⚠️
sbi/inference/posteriors/direct_posterior.py 22.22% 14 Missing ⚠️
sbi/inference/posteriors/vi_posterior.py 25.00% 12 Missing ⚠️
sbi/inference/posteriors/mcmc_posterior.py 15.38% 11 Missing ⚠️
sbi/inference/posteriors/importance_posterior.py 28.57% 10 Missing ⚠️
sbi/inference/posteriors/rejection_posterior.py 33.33% 10 Missing ⚠️
sbi/inference/potentials/base_potential.py 25.00% 6 Missing ⚠️
.../inference/potentials/posterior_based_potential.py 25.00% 6 Missing ⚠️
sbi/inference/potentials/vector_field_potential.py 25.00% 6 Missing ⚠️
... and 3 more

❌ Your patch status has failed because the patch coverage (35.32%) is below the target coverage (50.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1527       +/-   ##
===========================================
- Coverage   89.45%   78.94%   -10.52%     
===========================================
  Files         128      133        +5     
  Lines       10170    10368      +198     
===========================================
- Hits         9098     8185      -913     
- Misses       1072     2183     +1111     
Flag Coverage Δ
unittests 78.94% <35.32%> (-10.52%) ⬇️

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
sbi/inference/posteriors/base_posterior.py 85.18% <ø> (-2.47%) ⬇️
sbi/inference/trainers/marginal/marginal_base.py 96.94% <100.00%> (+0.02%) ⬆️
sbi/neural_nets/estimators/categorical_net.py 93.75% <100.00%> (ø)
sbi/samplers/mcmc/init_strategy.py 84.21% <100.00%> (-5.27%) ⬇️
sbi/samplers/score/diffuser.py 86.66% <100.00%> (-5.00%) ⬇️
sbi/samplers/score/predictors.py 97.43% <100.00%> (ø)
sbi/utils/potentialutils.py 100.00% <100.00%> (ø)
sbi/utils/sbiutils.py 78.46% <100.00%> (-9.08%) ⬇️
sbi/utils/user_input_checks_utils.py 89.83% <100.00%> (+7.97%) ⬆️
sbi/inference/potentials/score_fn_iid.py 88.62% <33.33%> (-2.47%) ⬇️
... and 12 more

... and 25 files with indirect coverage changes


@jorobledo
Collaborator Author

Regarding the failing type check: it says that priors are of type Distribution, which doesn't have a `to` method.

I'm not sure which solution we should go for. They are typed as Distribution, but after our previous pull request they should be one of BoxUniform, PytorchReturnTypeWrapper, MultipleIndependent, or CustomPriorWrapper, since these are the classes that have the to(device) method. Any suggestions?
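One possible direction would be a structural Protocol instead of a Union of concrete classes. This is only a sketch; `SupportsTo` and `MockBoxUniform` are hypothetical names, not sbi API:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsTo(Protocol):
    """Structural type: anything exposing an in-place .to(device) method."""

    def to(self, device: str) -> None: ...


class MockBoxUniform:
    # Hypothetical minimal prior; in sbi, the real candidates would be
    # BoxUniform, PytorchReturnTypeWrapper, MultipleIndependent,
    # and CustomPriorWrapper.
    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> None:
        self.device = device


def move_prior(prior: SupportsTo, device: str) -> None:
    # Type-checks because SupportsTo only requires a .to method,
    # unlike torch.distributions.Distribution.
    prior.to(device)


prior = MockBoxUniform()
move_prior(prior, "cuda:0")
```

A Protocol avoids having to extend a Union every time a new prior wrapper gains a `to` method, at the cost of being a purely structural contract.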

@jorobledo jorobledo requested a review from gmoss13 March 23, 2025 18:38
@jorobledo
Collaborator Author

Thanks @StarostinV, super clear! I've now separated it into another test and also included all other iid_methods:

@pytest.mark.gpu
@pytest.mark.parametrize("device", ["cpu", "gpu"])
@pytest.mark.parametrize("iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss"])
@pytest.mark.parametrize("inference_method", [FMPE, NPSE])
def test_VectorFieldPosterior(
    device: str, iid_method: str, inference_method: VectorFieldInference
):
    device = process_device(device)
    prior = BoxUniform(torch.zeros(3), torch.ones(3), device="cpu")
    inference = inference_method(score_estimator="mlp", prior=prior)
    density_estimator = inference.append_simulations(
        torch.randn((100, 3)), torch.randn((100, 2))
    ).train()
    posterior = inference.build_posterior(density_estimator, prior)
    posterior.to(device)
    assert posterior.device == device, f"ScorePosterior is not in device {device}."
    x_o = torch.ones(2).to(device)
    samples = posterior.sample((2,), x=x_o, iid_method=iid_method)
    assert samples.device.type == device.split(":")[0], (
        f"Samples are not on device {device}."
    )

Tests are passing; the thing is that they are marked as gpu, so codecov coverage is low. I think this PR is ready for review, let me know @janfb

Contributor

@michaeldeistler michaeldeistler left a comment


Amazing! Minor typo below, but otherwise good to go from my side!

Contributor

@janfb janfb left a comment


Wow, thanks a lot for all the GPU fixes! 🎉

I found some typos and possible improvements for the typing, all minor things. I tried to put in concrete suggestions which you could just accept and commit in a batch.

When testing locally, I got three failing GPU tests, see comment below.

@janfb
Contributor

janfb commented Mar 27, 2025

About the three failing tests:

  1. pytest tests/inference_on_device_test.py::test_conditioned_posterior_on_gpu: The problem is that here
    https://github.com/jorobledo/sbi/blob/8878a81c795ae332f74d92c867e18c56cc2e5164/sbi/utils/sbiutils.py#L923-L928
    the init_probs come from the potential_fn living on the device, but the inits are not on the device. The fix would probably be to move the inits to the device, because later in the function best_theta_iter (coming from inits) is again used together with the potential_fn.

    There also seems to be another problem I could not figure out until now: during .map, the prior_transform is not on the correct device, and neither is the prior, even though it's instantiated with device=device in the boundary tensors. Using to(device) fixes it, but that's a bit weird.

  2. pytest tests/marginal_estimator_test.py::test_marginal_estimator: Here, we need two fixes:

  3. pytest tests/mnpe_test.py::test_mnpe_on_device: It seems the device handling in MultipleIndependent got mixed up during a merge from main. At the moment, the test is xfail and still raises a NotImplementedError when instantiated with a device that is not None. There seem to be some device handling bugs in MNPE as well. Can you grant me write access to your fork so that I can fix it? @jorobledo
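The fix proposed in point 1 amounts to moving the inits onto the potential's device before evaluating. A rough illustration (the helper name and the toy potential with a `.device` attribute are assumptions for this sketch, not the actual sbiutils.py code):

```python
import torch


class ToyPotential:
    """Stand-in for a potential_fn that lives on a fixed device."""

    device = "cpu"

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        # Simple negative-quadratic potential over the last dimension.
        return -(theta**2).sum(dim=-1)


def score_on_potential_device(potential_fn, inits: torch.Tensor):
    # Move the candidate inits to the potential's device first, so that
    # init_probs and any later lookup of best_theta_iter (taken from inits)
    # operate on tensors on the same device as potential_fn.
    inits = inits.to(potential_fn.device)
    init_probs = potential_fn(inits)
    return init_probs, inits


init_probs, inits = score_on_potential_device(ToyPotential(), torch.zeros(5, 3))
```

On a CUDA box the toy potential's `device` would be e.g. "cuda:0", and the single `.to(...)` call keeps both the scoring and the subsequent indexing consistent.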

jorobledo and others added 2 commits March 27, 2025 13:14
Co-authored-by: Jan <janfb@users.noreply.github.com>
@janfb
Contributor

janfb commented Mar 27, 2025

@jorobledo I pushed the fixes to https://github.com/sbi-dev/sbi/tree/posterior_device_fixes in separate commits. Feel free to cherry-pick from that branch or just merge them over.

@jorobledo
Collaborator Author

jorobledo commented Mar 27, 2025

Way to go! Looking good. Anything left uncovered? I didn't see the comment about granting access! Sorry about that.

@jorobledo
Collaborator Author

Getting 238 passed, 6 skipped, 2988 deselected, 62 xfailed, 71 warnings in 683.19s (0:11:23) when running pytest tests -m "gpu" on a Tesla V100-PCIE-16GB

@janfb
Contributor

janfb commented Mar 27, 2025

Getting 238 passed, 6 skipped, 2988 deselected, 62 xfailed, 71 warnings in 683.19s (0:11:23) when running pytest tests -m "gpu" on a Tesla V100-PCIE-16GB

GPU tests are all passing for me as well 🎉

@janfb
Contributor

janfb commented Mar 27, 2025

Happy to fix the last CI test that keeps failing @jorobledo.

Contributor

@janfb janfb left a comment


Thanks again @jorobledo ! It's amazing to have this feature 🎉

Looks all good now and all tests are passing, so let's :shipit:

@janfb janfb merged commit fb5124e into sbi-dev:main Mar 27, 2025
7 of 8 checks passed