Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions sbi/inference/potentials/likelihood_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,40 @@ def likelihood_estimator_based_potential(


class LikelihoodBasedPotential(BasePotential):
r"""Potential function for likelihood-based methods (NLE).

This potential computes the unnormalized posterior log-probability
$\log(p(x_o|\theta)p(\theta))$ using a trained likelihood estimator
together with the prior. It is used internally by NLE methods for posterior
sampling via MCMC, rejection sampling, or VI.

Example:
--------
::

import torch
from sbi.inference import NLE_A
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.utils import BoxUniform

# 1. Train a likelihood estimator
prior = BoxUniform(low=torch.zeros(2), high=torch.ones(2))
theta = prior.sample((100,))
x = theta + 0.1 * torch.randn_like(theta)

trainer = NLE_A(prior=prior)
likelihood_estimator = trainer.append_simulations(theta, x).train()

# 2. Create potential function
x_o = torch.tensor([[0.5, 0.5]])
potential_fn, theta_transform = likelihood_estimator_based_potential(
likelihood_estimator, prior, x_o
)

# 3. Evaluate potential at a point
theta_test = torch.tensor([[0.4, 0.6]])
log_prob = potential_fn(theta_test)
"""

def __init__(
self,
likelihood_estimator: ConditionalDensityEstimator,
Expand Down
34 changes: 34 additions & 0 deletions sbi/inference/potentials/posterior_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,40 @@ def posterior_estimator_based_potential(


class PosteriorBasedPotential(BasePotential):
r"""Potential function for posterior-based methods (NPE).

This potential directly evaluates the learned posterior log-probability
$\log q(\theta|x_o)$ from a trained posterior estimator, returning
$-\infty$ for parameters outside prior support.

Example:
--------
::

import torch
from sbi.inference import NPE_C
from sbi.inference.potentials import posterior_estimator_based_potential
from sbi.utils import BoxUniform

# 1. Train a posterior estimator
prior = BoxUniform(low=torch.zeros(2), high=torch.ones(2))
theta = prior.sample((100,))
x = theta + 0.1 * torch.randn_like(theta)

trainer = NPE_C(prior=prior)
posterior_estimator = trainer.append_simulations(theta, x).train()

# 2. Create potential function
x_o = torch.tensor([[0.5, 0.5]])
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator, prior, x_o
)

# 3. Evaluate potential at a point
theta_test = torch.tensor([[0.4, 0.6]])
log_prob = potential_fn(theta_test)
"""

def __init__(
self,
posterior_estimator: ConditionalDensityEstimator,
Expand Down
34 changes: 34 additions & 0 deletions sbi/inference/potentials/ratio_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,40 @@ def ratio_estimator_based_potential(


class RatioBasedPotential(BasePotential):
r"""Potential function for ratio-based methods (NRE).

This potential computes $\log(r(\theta, x_o) \cdot p(\theta))$ where
$r(\theta, x) = p(x|\theta)/p(x)$ is the likelihood-to-evidence ratio
estimated by a classifier.

Example:
--------
::

import torch
from sbi.inference import NRE_A
from sbi.inference.potentials import ratio_estimator_based_potential
from sbi.utils import BoxUniform

# 1. Train a ratio estimator
prior = BoxUniform(low=torch.zeros(2), high=torch.ones(2))
theta = prior.sample((100,))
x = theta + 0.1 * torch.randn_like(theta)

trainer = NRE_A(prior=prior)
ratio_estimator = trainer.append_simulations(theta, x).train()

# 2. Create potential function
x_o = torch.tensor([[0.5, 0.5]])
potential_fn, theta_transform = ratio_estimator_based_potential(
ratio_estimator, prior, x_o
)

# 3. Evaluate potential at a point
theta_test = torch.tensor([[0.4, 0.6]])
log_prob = potential_fn(theta_test)
"""

def __init__(
self,
ratio_estimator: nn.Module,
Expand Down
Loading