
Add NPE-PFN #1778

Open
jsvetter wants to merge 27 commits into sbi-dev:main from jsvetter:npe_pfn_dev

Conversation

Contributor

@jsvetter jsvetter commented Feb 24, 2026

The goal of this PR is to add NPE-PFN to SBI, as discussed in #1682.

The implementation is realized mostly by three new components, which I will briefly describe in the following.
Happy to discuss all of this, as the existing assumptions encoded through base classes like NeuralInference or ConditionalDensityEstimator sometimes make more and sometimes less sense for NPE-PFN.

There are three key files that implement the method:

1.) tabpfn_flow.py implements the in-context ConditionalDensityEstimator based on the autoregressive use of TabPFN. It behaves exactly like other estimators and, given some context dataset, provides sampling and log-prob functionality.

2.) npe_pfn.py implements the NPE_PFN class, which inherits from NeuralInference and implements the basic logic used across the package (append_simulations, train, build_posterior etc.). Since NPE-PFN is training-free, the train method is a stub, and most functionality is handled directly by build_posterior. This allows users to call train without breaking any previous workflow, but they can also "forget" about it, as a training-free method would suggest.
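As a toy sketch of this pattern (not the PR's actual NPE_PFN class; names and signatures here are purely illustrative): `train` is a no-op kept for API compatibility, and all setup happens in `build_posterior`.

```python
class TrainingFreeInference:
    """Illustrative sketch only: train() is a stub, build_posterior does the work."""

    def __init__(self):
        self._theta = None
        self._x = None

    def append_simulations(self, theta, x):
        # Store the simulated (theta, x) pairs; they become the context dataset.
        self._theta, self._x = theta, x
        return self

    def train(self):
        # Stub: NPE-PFN is training-free, so existing workflows that call
        # train() keep working, but the call can also simply be omitted.
        return None

    def build_posterior(self):
        # The context dataset is attached to the posterior at build time.
        if self._theta is None:
            raise ValueError("call append_simulations first")
        return {"context_theta": self._theta, "context_x": self._x}
```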

Since the TabPFN-based flow behaves like any other flow, NPE-PFN supports many different types of posteriors out of the box (Direct, Rejection, IS; more could be added, but inference is too slow for MCMC). However, a crucial feature of NPE-PFN is filtering, where the context dataset is selected based on a given observation.

To support this functionality, a new posterior class is required.

3.) filtered_direct_posterior.py implements this posterior (inheriting from DirectPosterior), which allows filtering based on different filters (usually KNN, but users can also provide a custom callable).
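To make the filtering idea concrete, here is a toy KNN filter: keep the k (theta, x) pairs whose x lies closest to the observation x_o. This is a sketch of the idea only; the PR's actual FilteredDirectPosterior operates on torch tensors and may differ in detail.

```python
import math

def knn_filter(context_x, context_theta, x_o, k):
    """Toy KNN filter: select the k context pairs closest to x_o (illustrative)."""
    # Rank context points by Euclidean distance to the observation.
    order = sorted(range(len(context_x)), key=lambda i: math.dist(context_x[i], x_o))
    keep = order[:k]
    return [context_theta[i] for i in keep], [context_x[i] for i in keep]
```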

There are many other smaller changes (builders, dataclasses, etc.), and so far no tests.
Also, this PR contains only the core functionality for amortized inference. More advanced features like sequential inference, or even support for finetuning (which we did not even do in the paper), are not added.

It probably makes sense to discuss this approach first, before I add fine-grained tests or possibly more functionality.

Here are results for the mini benchmark:
[screenshot: mini benchmark results]


codecov bot commented Feb 24, 2026

Codecov Report

❌ Patch coverage is 31.13772% with 230 lines in your changes missing coverage. Please review.
✅ Project coverage is 86.16%. Comparing base (d41efa6) to head (690df4a).

Files with missing lines Patch % Lines
sbi/neural_nets/estimators/tabpfn_flow.py 20.00% 104 Missing ⚠️
sbi/inference/trainers/npe/npe_pfn.py 36.66% 57 Missing ⚠️
.../inference/posteriors/filtered_direct_posterior.py 34.32% 44 Missing ⚠️
sbi/neural_nets/net_builders/flow.py 20.00% 8 Missing ⚠️
sbi/inference/posteriors/posterior_parameters.py 53.84% 6 Missing ⚠️
sbi/utils/torchutils.py 40.00% 6 Missing ⚠️
sbi/utils/user_input_checks.py 0.00% 3 Missing ⚠️
sbi/inference/trainers/base.py 60.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1778      +/-   ##
==========================================
- Coverage   87.65%   86.16%   -1.50%     
==========================================
  Files         140      143       +3     
  Lines       12178    12503     +325     
==========================================
+ Hits        10675    10773      +98     
- Misses       1503     1730     +227     
Flag Coverage Δ
fast 81.90% <31.13%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/inference/__init__.py 100.00% <100.00%> (ø)
sbi/inference/posteriors/__init__.py 100.00% <100.00%> (ø)
.../inference/potentials/posterior_based_potential.py 89.36% <100.00%> (ø)
sbi/inference/trainers/npe/__init__.py 100.00% <100.00%> (ø)
sbi/neural_nets/factory.py 81.92% <ø> (ø)
sbi/neural_nets/net_builders/__init__.py 100.00% <ø> (ø)
sbi/inference/trainers/base.py 93.03% <60.00%> (-0.51%) ⬇️
sbi/utils/user_input_checks.py 76.68% <0.00%> (ø)
sbi/inference/posteriors/posterior_parameters.py 80.57% <53.84%> (-2.76%) ⬇️
sbi/utils/torchutils.py 67.77% <40.00%> (-1.64%) ⬇️
... and 4 more

Contributor

@manuelgloeckler manuelgloeckler left a comment


Looks great overall already! I added comments on a few minor issues.

What is currently still missing is adding some tests, e.g.:

  • standard test suite for DensityEstimators here
  • add tests on device here

Keeping the context size small should allow these tests to run relatively fast, no? Otherwise we might mark most of them as slow.

)


def infer_module_device(module: torch.nn.Module, fallback: str) -> str:

This seems to be an unrelated but general replacement for

# device = str(next(estimator.parameters()).device) to be replaced
device = infer_module_device(estimator, "cpu")

This is an improvement, and it adds an informative warning. Please add a small docstring, though.
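For illustration, a helper along these lines could look as follows. This is a hypothetical sketch of `infer_module_device`, not the PR's actual implementation; the warning message and fallback handling are assumptions.

```python
import warnings

import torch

def infer_module_device(module: torch.nn.Module, fallback: str) -> str:
    """Return the device of the module's first parameter, or `fallback`.

    Sketch only: the PR's actual helper may differ. Falls back (with a
    warning) when the module has no parameters at all.
    """
    try:
        return str(next(module.parameters()).device)
    except StopIteration:
        warnings.warn(
            f"Module has no parameters; falling back to device '{fallback}'."
        )
        return fallback
```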

"torch>=1.13.0",
"tqdm",
"zuko>=1.2.0",
"tabpfn",

Maybe @janfb if you have any opinion on adding it as a fixed dependency. The only extra dependencies this adds are pydantic, eval-type-backport, tabpfn-common-utils[telemetry-interactive] and filelock, all of which have light dependencies. But one could also think about making it an optional dependency, only installed when needed. I am fine with adding it.

f"Expected estimator to be TabPFNFlow, got {type(estimator).__name__}."
)

if sample_with == "filtered_direct":

If sample_with is anything other than "filtered_direct", this branch will be skipped, i.e. no context will be set.

The standard way to build a posterior would still work, I think, but we would still want to set the context (i.e. the training dataset) on the estimator here. Not sure if this would be covered by tests.

Alternatively we might raise a NotImplementedError for all other methods for now.
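The suggested explicit rejection could be sketched like this (names are illustrative, not the PR's actual code): unsupported sampling methods fail loudly instead of silently skipping the context setup.

```python
def build_posterior_dispatch(sample_with: str):
    """Illustrative sketch of the reviewer's suggestion, not the PR's code."""
    if sample_with == "filtered_direct":
        # Context setup would happen here before returning the posterior.
        return "filtered_direct_posterior"
    # Fail loudly for anything else, rather than building a posterior
    # whose context was never set.
    raise NotImplementedError(
        f"sample_with='{sample_with}' is not yet supported for NPE-PFN."
    )
```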

return self._posterior

if full_data_size > estimator.max_context_size:
warn(

Not sure if this warning is necessary here, as TabPFN will throw an informative error anyway.

)

if sample_with == "filtered_direct":
full_context_input, full_context_condition, _ = self.get_simulations(

This is already fetched in line 260, no?


self.max_context_size = int(max_context_size)
self._input_numel = int(torch.Size(input_shape).numel())
self.register_buffer("_context_input", None, persistent=False)

In the current implementation these will always have to sit on the CPU side. So I am not sure it makes sense to register them as buffers, which would move them to GPU when the estimator is moved to GPU; they would then always have to be transferred back to CPU first anyway.
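The trade-off can be illustrated as follows (names are illustrative, not the PR's): a registered buffer is tracked by the module and follows `.to(device)`, while a plain tensor attribute stays where it was created. If TabPFN's numpy API needs the context on CPU anyway, a plain attribute avoids a pointless GPU round trip.

```python
import torch

class ContextHolder(torch.nn.Module):
    """Illustrative sketch: buffer vs. plain attribute for the context dataset."""

    def __init__(self):
        super().__init__()
        # Tracked by the module: moves with .to(device); persistent=False
        # keeps it out of the state_dict.
        self.register_buffer("buffered_context", torch.zeros(3), persistent=False)
        # Not tracked: stays on its original device, ignored by .to().
        self.plain_context = torch.zeros(3)

m = ContextHolder()
tracked = {name for name, _ in m.named_buffers()}
```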

r"""Public wrapper for preparing embedded, flattened conditions."""
return self._embed_condition(condition)

def set_context_flat(

This essentially checks for TabPFN's required shapes, right? Then I would document that in the docstring. It would also make sense to mention that we require the context on the CPU side (since TabPFN's numpy API requires that, right?).


def _autoregressive_sample(
self, condition_flat: Tensor, with_log_prob: bool = False, eps: float = 1e-15
) -> tuple[Tensor, Optional[Tensor]]:

The second tensor is not optional, no?


dim_log_prob = torch.where(
dim_log_prob == float("-inf"),
torch.log(torch.tensor(eps)),

Small thing: this can be computed outside the loop.
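The hoisting could look like this (a sketch; variable names mirror the quoted snippet): `log(eps)` is a constant, so it is computed once instead of on every loop iteration.

```python
import torch

eps = 1e-15
# Constant: compute once, outside the per-dimension loop.
log_eps = torch.log(torch.tensor(eps))

def clamp_neg_inf(dim_log_prob: torch.Tensor) -> torch.Tensor:
    """Replace -inf log-probs with log(eps), reusing the precomputed constant."""
    return torch.where(dim_log_prob == float("-inf"), log_eps, dim_log_prob)
```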


dim_log_prob = torch.where(
dim_log_prob == float("-inf"),
torch.log(torch.tensor(eps)),

Same here: compute it once, outside the loop.
