Codecov Report
❌ Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1778      +/-   ##
==========================================
- Coverage   87.65%   86.16%   -1.50%
==========================================
  Files         140      143       +3
  Lines       12178    12503     +325
==========================================
+ Hits        10675    10773      +98
- Misses       1503     1730     +227
Looks overall great already! I added comments on a few minor issues.
What is currently still missing is adding some tests, e.g.:
Keeping the context size small should allow these tests to run relatively fast, no? Otherwise we might mark most of them as slow.
> `def infer_module_device(module: torch.nn.Module, fallback: str) -> str:`
This seems to be an unrelated but general replacement, i.e.

```python
# device = str(next(estimator.parameters()).device)  # to be replaced by:
device = infer_module_device(estimator, "cpu")
```

This is an improvement and adds an informative warning, but please add a small docstring.
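For illustration, a minimal sketch of such a helper (hypothetical; the actual implementation in the PR may differ):

```python
import warnings

import torch


def infer_module_device(module: torch.nn.Module, fallback: str = "cpu") -> str:
    """Infer the device of a module from its first parameter.

    Falls back to ``fallback`` (with a warning) when the module has no
    parameters. Hypothetical sketch of the helper discussed above.
    """
    try:
        return str(next(module.parameters()).device)
    except StopIteration:
        warnings.warn(
            f"Module has no parameters; falling back to device '{fallback}'."
        )
        return fallback
```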
```toml
"torch>=1.13.0",
"tqdm",
"zuko>=1.2.0",
"tabpfn",
```
Maybe @janfb, do you have any opinion on adding it as a fixed dependency? The only extra dependencies this adds are pydantic, eval-type-backport, tabpfn-common-utils[telemetry-interactive], and filelock, all of which have light dependencies. One could also consider making it an optional dependency that is installed only when needed. I am fine with adding it.
> `if sample_with == "filtered_direct":`
If `sample_with` is anything other than `"filtered_direct"`, this branch will be skipped, i.e. no context will be set.
The standard way to build a posterior would still work, I think, but I believe we would still want to set the context (i.e. the train dataset) on the estimator here. Not sure if this is covered by tests.
Alternatively, we might raise a `NotImplementedError` for all other methods for now.
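The suggested guard could look like this hypothetical sketch (the `sample_with` name and the `"filtered_direct"` value follow the PR; the function name is made up):

```python
def check_sample_with(sample_with: str) -> None:
    """Guard against silently skipping the context-setting branch.

    Hypothetical sketch of the reviewer's suggestion: raise for any
    sampling method other than "filtered_direct" until the others are
    explicitly supported.
    """
    if sample_with != "filtered_direct":
        raise NotImplementedError(
            f"sample_with='{sample_with}' is not supported yet; "
            "only 'filtered_direct' currently sets the context."
        )
```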
```python
if full_data_size > estimator.max_context_size:
    warn(
```
Not sure if this warning here is necessary, as TabPFN will throw an informative error anyway.
```python
if sample_with == "filtered_direct":
    full_context_input, full_context_condition, _ = self.get_simulations(
```
This is already fetched in line 260, no?
```python
self.max_context_size = int(max_context_size)
self._input_numel = int(torch.Size(input_shape).numel())
self.register_buffer("_context_input", None, persistent=False)
```
In the current implementation these will always have to sit on the CPU side. So I am not sure it makes sense to register them as buffers, which would move them to the GPU when we move the estimator to the GPU; they would then always be transferred back to the CPU first anyway.
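One alternative, as a hypothetical sketch: store the context as a plain attribute instead of a registered buffer, so `module.to(device)` leaves it untouched on the CPU (where TabPFN's numpy API consumes it anyway). The class and attribute names here are made up:

```python
from typing import Optional

import torch


class ContextHolder(torch.nn.Module):
    """Sketch: keep context tensors out of the module's buffers.

    A plain attribute is not moved by ``module.to(device)``, so the
    context stays on the CPU. Hypothetical names; the PR's estimator
    may differ.
    """

    def __init__(self) -> None:
        super().__init__()
        self._context_input: Optional[torch.Tensor] = None

    def set_context(self, context_input: torch.Tensor) -> None:
        # Detach and force CPU regardless of the module's device.
        self._context_input = context_input.detach().cpu()
```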
```python
r"""Public wrapper for preparing embedded, flattened conditions."""
return self._embed_condition(condition)
```
> `def set_context_flat(`
This essentially checks for the shapes TabPFN requires, right? Then I would document this in the docstring. It would also make sense to mention that we require the context on the CPU side (since TabPFN's numpy API requires that, right?).
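A hypothetical standalone sketch of the checks (and docstring) this comment asks for, assuming TabPFN expects 2D CPU tensors; the function name is made up and the PR's `set_context_flat` may differ:

```python
import torch


def validate_context_flat(
    context_input: torch.Tensor, context_condition: torch.Tensor
) -> None:
    """Validate a flattened context dataset before handing it to TabPFN.

    Both tensors must be 2D with shape (n_samples, n_features) and must
    live on the CPU, since TabPFN's numpy API consumes CPU arrays.
    Hypothetical sketch of the checks discussed in the review.
    """
    pairs = [
        ("context_input", context_input),
        ("context_condition", context_condition),
    ]
    for name, tensor in pairs:
        if tensor.ndim != 2:
            raise ValueError(f"{name} must be 2D (n_samples, n_features).")
        if tensor.device.type != "cpu":
            raise ValueError(f"{name} must be on the CPU for TabPFN's numpy API.")
    if context_input.shape[0] != context_condition.shape[0]:
        raise ValueError("Context tensors must have the same number of samples.")
```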
```python
def _autoregressive_sample(
    self, condition_flat: Tensor, with_log_prob: bool = False, eps: float = 1e-15
) -> tuple[Tensor, Optional[Tensor]]:
```
The second tensor is not optional, no?
```python
dim_log_prob = torch.where(
    dim_log_prob == float("-inf"),
    torch.log(torch.tensor(eps)),
```
Small thing: this can be computed once, outside the loop.
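For illustration, the constant can be hoisted before the autoregressive loop (a sketch with made-up placeholder values for `dim_log_prob`):

```python
import torch

eps = 1e-15
# Compute once, outside the loop: the clamp value never changes.
log_eps = torch.log(torch.tensor(eps))

for _ in range(3):  # stand-in for the autoregressive loop over dimensions
    dim_log_prob = torch.tensor([0.0, float("-inf"), -2.0])
    # Replace -inf entries by log(eps) without rebuilding the clamp
    # tensor inside torch.where on every iteration.
    dim_log_prob = torch.where(dim_log_prob == float("-inf"), log_eps, dim_log_prob)
```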
```python
dim_log_prob = torch.where(
    dim_log_prob == float("-inf"),
    torch.log(torch.tensor(eps)),
```
Same here: compute it once, outside the loop.
The goal of this PR is to add NPE-PFN to SBI, as discussed in #1682.
The implementation is realized mostly by three new components, which I will briefly describe in the following.
Happy to discuss all of this, as the existing assumptions encoded through base classes like `NeuralInference` or `ConditionalDensityEstimator` sometimes make more and sometimes less sense for NPE-PFN. There are three key files that implement the method:
1.) `tabpfn_flow.py` implements the in-context `ConditionalDensityEstimator` based on the autoregressive use of TabPFN. It behaves exactly like other estimators and, given some context dataset, provides sampling and log-prob functionality.

2.) `npe_pfn.py` implements the `NPE_PFN` class, which inherits from `NeuralInference` and implements the basic logic used across the package (`append_simulations`, `train`, `build_posterior`, etc.). Since NPE-PFN is training-free, the `train` method is a stub, and most functionality is handled directly by `build_posterior`. This allows users to call `train` without breaking any previous workflow, but they can also "forget" about it, as would be suggested by a training-free method.

Since the TabPFN-based flow behaves like any other flow, NPE-PFN supports many different types of posteriors out of the box (Direct, Rejection, IS; more could be added, but inference is too slow for MCMC). However, a crucial feature of NPE-PFN is filtering, where the context dataset is selected based on a given observation.
To support this functionality, a new posterior class is required.
3.) `filtered_direct_posterior.py` implements this posterior (inheriting from `DirectPosterior`), which allows filtering based on different filters (usually KNN, but users can also provide a custom callable).

There are many other smaller changes (builders, dataclasses, etc.) and so far no tests.
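A minimal sketch of the kind of KNN filter described above (the function name and signature are hypothetical; the actual filter API in `filtered_direct_posterior.py` may differ):

```python
import torch


def knn_filter(
    context_condition: torch.Tensor, observation: torch.Tensor, k: int
) -> torch.Tensor:
    """Return indices of the k context rows closest to the observation.

    context_condition: (n_samples, n_features) simulated observations.
    observation: (n_features,) observation to condition on.
    Hypothetical sketch of the default KNN filtering described in the PR.
    """
    dists = torch.cdist(observation.reshape(1, -1), context_condition).squeeze(0)
    return torch.topk(dists, k=k, largest=False).indices
```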
Also, this PR contains only the core functionality for amortized inference. More advanced features like sequential inference, or even support for finetuning, etc. (which we did not even do in the paper) are not added.
It probably makes sense to discuss this approach first, before I add fine-grained tests or possibly more functionality.
Here are results for the mini benchmark:
