Conversation
…v/sbi into 1446-add-api-for-guidance
- add pytest-split plugin to dev dependencies - use split in ci workflow based on test durations in .test_durations
…v/sbi into 1446-add-api-for-guidance
Codecov Report ❌ Patch coverage details and impacted files:

@@            Coverage Diff            @@
##              main    #1482      +/-   ##
==========================================
- Coverage    88.54%   82.59%    -5.95%
==========================================
  Files          137      139        +2
  Lines        11515    12544     +1029
==========================================
+ Hits         10196    10361      +165
- Misses        1319     2183      +864
…her i.e. score -> vector field
Great work @manuelgloeckler 👏 Great to have PriorGuide in here as well 🚀
added mostly minor formatting and edge case handling comments.
More high-level, I am concerned about how long the signature of sample(...) has become. We should think about introducing config classes for the different methods (sde config, iid, guidance), or method chaining (see comment below). But this would be a larger refactoring - let's discuss.
On the test side, maybe add another test on the combination of guidance and iid settings, given that this is implemented as an option?
Documentation: great to have a tutorial as part of advanced tutorial 20 already. As a follow-up, I suggest a refactoring of tutorials 19 and 20, and 1-2 how-to-guide for the VF and guidance methods.
guidance_method: Optional[str] = None,
guidance_params: Optional[Dict] = None,
the sample method now has >20 kwargs. we could think about introducing config classes instead, e.g., sth like
from sbi.inference.posteriors import GuidanceConfig, PriorGuideCfg
guidance = GuidanceConfig(
    method="prior_guide",
    params=PriorGuideCfg(train_prior=..., test_prior=..., K=5),
)
samples = posterior.sample((1000,), x=x_o, guidance=guidance)
but this would be a larger, possibly follow-up refactoring. what do you think?
or even sth like
samples = posterior.with_guidance("prior_guide", ...).sample((1000,), x=x_o)
Mhh, I think config classes would be fine (I internally use config classes anyway). But at least at the time of implementation most user API was still configured via dicts so I made this the external API.
I do actually like your suggestions with with_guidance and with_iid to make it more explicit. But yes I would make this a follow-up refactoring.
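To make the follow-up idea concrete, here is a minimal self-contained sketch of what a config-class plus `with_guidance()` chaining API could look like. All names here (`GuidanceConfig`, `with_guidance`, the `Posterior` stub) are hypothetical illustrations of the proposal, not the current sbi API:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class GuidanceConfig:
    """Bundles guidance settings instead of adding more kwargs to sample()."""
    method: str
    params: dict = field(default_factory=dict)


class Posterior:
    """Stub posterior illustrating the proposed with_guidance() chaining."""

    def __init__(self, guidance: Optional[GuidanceConfig] = None):
        self._guidance = guidance

    def with_guidance(self, method: str, **params) -> "Posterior":
        # Return a new posterior carrying the guidance config, leaving the
        # original untouched, so chaining stays side-effect free.
        return Posterior(guidance=GuidanceConfig(method, params))

    def sample(self, shape, x=None):
        # A real implementation would dispatch on self._guidance here;
        # this stub just echoes the configuration back.
        return {"shape": shape, "guidance": self._guidance}


posterior = Posterior()
samples = posterior.with_guidance("prior_guide", K=5).sample((1000,))
print(samples["guidance"].method)  # prior_guide
```

The frozen dataclass plus the copy-on-chain `with_guidance` keeps each posterior instance immutable with respect to its guidance settings, which avoids surprises when the same posterior is sampled with and without guidance.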
class _HashableById:
    __slots__ = ("obj", "_id")

    def __init__(self, obj: Distribution):
        """Wraps a non-hashable Distribution to make it cache-key compatible."""
        self.obj = obj
        self._id = id(obj)

    def __hash__(self) -> int:
        return self._id

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _HashableById) and self._id == other._id
overall great to have this here! the main purpose is to not refit the GMM on every call from the potential, right?
just wondering, this caches by object ID, not by distribution parameters. So if someone passes the same dist with the same params but as a new instance, the cache won't have an effect. Conversely, if someone changes the distribution params in place, the change will not be noticed by the cache.
But I think both are edge cases so it's fine
Yeah, there are definitely some edge cases where this will fail. This is why it is only used if required.
I'm generally a bit unhappy that the adaptor/score wrapper objects are reinitialized in each call of the potential (which requires hashing as much as possible). I think it would be good to restructure the VF potential a bit so that such quantities can be computed once at the beginning, i.e., via an initialize_aux which needs to be called once before sample.
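A rough sketch of that restructuring idea, where call-invariant quantities are built once up front instead of on every potential evaluation. `initialize_aux`, `VectorFieldPotential`, and the attribute names are all hypothetical; this is not sbi's actual potential class:

```python
class VectorFieldPotential:
    """Sketch: precompute auxiliary quantities once, reuse per call."""

    def __init__(self, prior: str):
        self.prior = prior
        self._aux = None  # filled by initialize_aux()

    def initialize_aux(self) -> None:
        # Done once before sampling: build adaptors / score wrappers,
        # fit auxiliary models (e.g. a GMM on the prior), etc.
        self._aux = {"gmm": f"fitted-on-{self.prior}"}

    def __call__(self, theta) -> str:
        if self._aux is None:
            raise RuntimeError("call initialize_aux() before sampling")
        # Use the precomputed quantities; no per-call refitting or
        # identity-based caching needed.
        return f"potential({theta}) using {self._aux['gmm']}"


pot = VectorFieldPotential(prior="uniform")
pot.initialize_aux()
print(pot(0.5))  # potential(0.5) using fitted-on-uniform
```

With this shape, the identity-based cache above becomes unnecessary for the common path, since nothing expensive happens inside `__call__` anymore.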
I will make an issue for that.
We could add it to the GSoC project on the potential_fn refactoring as well.
janfb left a comment
Looks all good, again, great work @manuelgloeckler ! 👏
if corresponding slow tests are passing we can merge this!
This adds an API for post-hoc modification of the trained score estimator, allowing modification of the likelihood or prior, the support, or other additional constraints. This addresses issue #1446.
To dos: