1446 add api for guidance #1482

Merged
manuelgloeckler merged 50 commits into main from
1446-add-api-for-guidance
Mar 2, 2026

@manuelgloeckler
Contributor

@manuelgloeckler manuelgloeckler commented Mar 18, 2025

This adds an API for post-hoc modifications of the trained score estimator, allowing modifications of the likelihood or prior, the support, or other additional constraints. This addresses issue #1446.

To-dos:

  • Adds a general API in line with the current iid_method
  • Refactor the parameterizations from stringly typed to strongly typed (also for iid!)
  • Adds some useful explanatory guidance approaches:
    • Classifier free guidance
    • Universal guidance - interval truncations.
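For readers unfamiliar with the two guidance approaches listed above, here is a minimal sketch of the score arithmetic behind them. It is not the code from this PR: the function names, the Gaussian stand-in scores, and the sigmoid-based soft interval constraint are all assumptions chosen for illustration.

```python
import numpy as np


def gaussian_score(theta, mean, std):
    """Score (gradient of the log-density) of an isotropic Gaussian."""
    return -(theta - mean) / std**2


def classifier_free_guidance(score_cond, score_uncond, weight):
    """Blend a conditional and an unconditional score.

    weight = 0 gives the unconditional score, weight = 1 the conditional
    one, and weight > 1 extrapolates beyond the conditional score.
    """
    return score_uncond + weight * (score_cond - score_uncond)


def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def interval_guidance(theta, low, high, sharpness=10.0):
    """Gradient of log[sigmoid(k (theta - low)) * sigmoid(k (high - theta))].

    This soft constraint (an assumption, not the PR's choice) pushes theta
    back into [low, high]; the term is added to the learned score.
    """
    push_up = sharpness * (1.0 - _sigmoid(sharpness * (theta - low)))
    push_down = sharpness * (1.0 - _sigmoid(sharpness * (high - theta)))
    return push_up - push_down


theta = np.array([0.5, -1.0])
s_uncond = gaussian_score(theta, mean=0.0, std=2.0)  # stand-in for a prior score
s_cond = gaussian_score(theta, mean=1.0, std=0.5)    # stand-in for a posterior score

guided = classifier_free_guidance(s_cond, s_uncond, weight=1.5)
truncated = s_cond + interval_guidance(theta, low=0.0, high=1.0)
```

In both cases the trained estimator stays untouched; only the score passed to the sampler changes, which is what makes the modification post hoc.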

@manuelgloeckler manuelgloeckler linked an issue Mar 18, 2025 that may be closed by this pull request
5 tasks
@manuelgloeckler manuelgloeckler marked this pull request as draft March 18, 2025 18:48
@manuelgloeckler manuelgloeckler self-assigned this Mar 18, 2025
@manuelgloeckler manuelgloeckler marked this pull request as ready for review March 20, 2025 07:05
@codecov

codecov bot commented Mar 20, 2025

Codecov Report

❌ Patch coverage is 72.76479% with 198 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.59%. Comparing base (937efc2) to head (b0e90f1).
⚠️ Report is 24 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines | Patch % | Lines
sbi/utils/vector_field_utils.py | 65.78% | 103 Missing ⚠️
sbi/inference/potentials/vector_field_adaptor.py | 77.48% | 93 Missing ⚠️
sbi/inference/potentials/vector_field_potential.py | 84.61% | 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1482      +/-   ##
==========================================
- Coverage   88.54%   82.59%   -5.95%     
==========================================
  Files         137      139       +2     
  Lines       11515    12544    +1029     
==========================================
+ Hits        10196    10361     +165     
- Misses       1319     2183     +864     
Flag | Coverage Δ
fast | 82.59% <72.76%> (?)
full | ?

Files with missing lines | Coverage Δ
sbi/inference/posteriors/vector_field_posterior.py | 68.70% <ø> (-8.48%) ⬇️
...i/neural_nets/estimators/flowmatching_estimator.py | 81.66% <ø> (-15.00%) ⬇️
sbi/inference/potentials/vector_field_potential.py | 75.40% <84.61%> (-15.67%) ⬇️
sbi/inference/potentials/vector_field_adaptor.py | 77.48% <77.48%> (ø)
sbi/utils/vector_field_utils.py | 66.12% <65.78%> (-17.21%) ⬇️

... and 70 files with indirect coverage changes

@manuelgloeckler manuelgloeckler added the "blocked" label (Something is in the way of fixing this. Refer to it in the issue) Mar 24, 2025
@manuelgloeckler manuelgloeckler removed the "blocked" label (Something is in the way of fixing this. Refer to it in the issue) Sep 5, 2025
@manuelgloeckler manuelgloeckler requested a review from janfb February 2, 2026 16:49
Contributor

@janfb janfb left a comment

Great work @manuelgloeckler 👏 Great to have PriorGuide in here as well 🚀

added mostly minor formatting and edge case handling comments.

More high-level, I am concerned about the long signature sample(...) has by now. We should think about introducing config classes for the different methods (sde config, iid, guidance), or method chaining (see comment below). But this would be a larger refactoring - let's discuss.

On the test side, maybe add another test on the combination of guidance and iid settings, given that this is implemented as an option?

Documentation: great to have a tutorial as part of advanced tutorial 20 already. As a follow-up, I suggest refactoring tutorials 19 and 20 and adding 1-2 how-to guides for the VF and guidance methods.

Comment on lines +162 to +163
guidance_method: Optional[str] = None,
guidance_params: Optional[Dict] = None,
Contributor

the sample method now has >20 kwargs. we could think about introducing config classes instead, e.g., sth like

from sbi.inference.posteriors import GuidanceConfig, PriorGuideCfg

guidance = GuidanceConfig(
    method="prior_guide",
    params=PriorGuideCfg(train_prior=..., test_prior=..., K=5)
)
samples = posterior.sample((1000,), x=x_o, guidance=guidance)

but this would be a larger, possibly follow-up refactoring. what do you think?
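To make the config-class idea concrete, here is a minimal sketch of how such classes could wrap the dict-based arguments from the diff. The class names follow the suggestion above, but the field types, the validation, and the `as_kwargs` helper are hypothetical and not part of sbi.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class PriorGuideCfg:
    """Hypothetical parameter container for the "prior_guide" method."""

    train_prior: Any
    test_prior: Any
    K: int = 5

    def __post_init__(self):
        # Fail early instead of deep inside the sampler.
        if self.K < 1:
            raise ValueError("K must be a positive integer")


@dataclass(frozen=True)
class GuidanceConfig:
    """Hypothetical top-level guidance configuration."""

    method: str
    params: Optional[PriorGuideCfg] = None

    def as_kwargs(self) -> dict:
        """Translate to the guidance_method / guidance_params arguments."""
        params = {} if self.params is None else dict(vars(self.params))
        return {"guidance_method": self.method, "guidance_params": params}


guidance = GuidanceConfig(
    method="prior_guide",
    params=PriorGuideCfg(train_prior="train", test_prior="test", K=5),
)
kwargs = guidance.as_kwargs()
```

A typed container like this catches misspelled or invalid parameters at construction time, while `as_kwargs` would let the existing dict-based `sample` signature keep working during a transition.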

Contributor

or even sth like

samples = posterior.with_guidance("prior_guide", ...).sample((1000,), x=x_o)

Contributor Author

Mhh, I think config classes would be fine (I internally use config classes anyway). But at least at the time of implementation most user API was still configured via dicts so I made this the external API.

I do actually like your suggestions with with_guidance and with_iid to make it more explicit. But yes I would make this a follow-up refactoring.
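The method-chaining variant discussed here could look roughly like the sketch below. `ToyPosterior` is a hypothetical stand-in, not the sbi posterior class, and its `sample` only reports the collected settings instead of drawing samples; the iid method name is a placeholder.

```python
import copy
from typing import Any, Dict, Optional, Tuple


class ToyPosterior:
    """Hypothetical stand-in used only to illustrate the chaining pattern."""

    def __init__(self) -> None:
        self._guidance: Optional[Tuple[str, Dict[str, Any]]] = None
        self._iid_method: Optional[str] = None

    def with_guidance(self, method: str, **params: Any) -> "ToyPosterior":
        # Return a configured copy so the original posterior stays unchanged.
        new = copy.copy(self)
        new._guidance = (method, params)
        return new

    def with_iid(self, method: str) -> "ToyPosterior":
        new = copy.copy(self)
        new._iid_method = method
        return new

    def sample(self, sample_shape: tuple, x: Any = None) -> Dict[str, Any]:
        # A real implementation would run the sampler; this returns the
        # settings it would use.
        return {
            "shape": sample_shape,
            "guidance": self._guidance,
            "iid_method": self._iid_method,
        }


base = ToyPosterior()
configured = base.with_guidance("prior_guide", K=5).with_iid("some_iid_method")
settings = configured.sample((1000,))
```

Returning a copy from each `with_*` call keeps `sample` short and avoids hidden state on the original posterior, at the cost of one shallow copy per call.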

Comment on lines +179 to +191
class _HashableById:
    __slots__ = ("obj", "_id")

    def __init__(self, obj: Distribution):
        """Wraps a non-hashable Distribution to make it cache-key compatible."""
        self.obj = obj
        self._id = id(obj)

    def __hash__(self) -> int:
        return self._id

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _HashableById) and self._id == other._id
Contributor

overall great to have this here! the main purpose is to not refit the GMM on every call from the potential, right?

just wondering, this caches by object ID not distribution parameters, so if someone passes the same dist with same params but a new instance, this won't have an effect. Accordingly, when someone changes the distribution params in place this will not be noticed in the cache.
But I think both are edge cases so it's fine
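The two edge cases can be reproduced with a small self-contained sketch. It assumes the wrapper is used as a key for `functools.lru_cache`, which the snippet above does not confirm; `ToyDist` and `expensive_fit` are hypothetical stand-ins for a distribution and the GMM fit.

```python
from functools import lru_cache


class ToyDist:
    """Hypothetical stand-in for a non-hashable distribution."""

    def __init__(self, loc: float) -> None:
        self.loc = loc


class HashableById:
    """Same idea as the wrapper above: hash and compare by object identity."""

    __slots__ = ("obj", "_id")

    def __init__(self, obj: ToyDist) -> None:
        self.obj = obj
        self._id = id(obj)

    def __hash__(self) -> int:
        return self._id

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashableById) and self._id == other._id


fit_calls = []


@lru_cache(maxsize=8)
def expensive_fit(key: HashableById) -> float:
    # Records each real computation so cache hits and misses are visible.
    fit_calls.append(key.obj.loc)
    return key.obj.loc * 2


d1 = ToyDist(1.0)
first = expensive_fit(HashableById(d1))   # computed
second = expensive_fit(HashableById(d1))  # cache hit: same object

d2 = ToyDist(1.0)
third = expensive_fit(HashableById(d2))   # miss: equal parameters, new instance

d1.loc = 5.0
stale = expensive_fit(HashableById(d1))   # hit: in-place change goes unnoticed
```

Only two real computations happen here, and the last call returns the value cached for the old `loc`. Keying on parameter values instead would fix both cases, but requires a reliable way to extract and hash the parameters of arbitrary distributions.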

Contributor Author

@manuelgloeckler manuelgloeckler Feb 25, 2026

Yeah, there are definitely some edge cases where this will fail. This is why it is only used if required.

I am generally a bit unhappy that the adaptor/score wrapper objects are reinitialized in each call of the potential (which requires hashing as much as possible). I think it would be good to restructure the VF potential a bit so that such quantities can be computed once at the beginning, e.g. via an initialize_aux that needs to be called once before sample.

Contributor Author

I will make an issue for that.

Contributor

We could add it to the GSoC project on the potential_fn refactoring as well.

Contributor

@janfb janfb left a comment

Looks all good, again, great work @manuelgloeckler ! 👏

if corresponding slow tests are passing we can merge this!

@manuelgloeckler manuelgloeckler merged commit 0f7e6ea into main Mar 2, 2026
9 checks passed

Development

Successfully merging this pull request may close these issues.

Add API for guidance

4 participants