Merged

Changes from 37 commits
45 commits
80e7e82
add to(device) to BasePotential
jorobledo Mar 21, 2025
7b1d301
add to(device) to LikelihoodBasedPotential
jorobledo Mar 21, 2025
eb28095
add to(device) to PosteriorBasedPotential
jorobledo Mar 21, 2025
71d4d3b
remove Optional for prior
jorobledo Mar 21, 2025
cf58816
add to(device) to RatioBasedPotential
jorobledo Mar 21, 2025
cb5955e
add to(device) to PosteriorScoreBasedPotential
jorobledo Mar 21, 2025
66e5abe
add to(device) to IIDScoreFunction
jorobledo Mar 21, 2025
3219983
add to(device) to DirectPosterior
jorobledo Mar 21, 2025
aeddfe7
add to(device) to EnsemblePotential and EnsemblePosterior
jorobledo Mar 21, 2025
291d2ec
add to(device) to ImportanceSamplingPosterior
jorobledo Mar 21, 2025
cc40018
add potentials and posterior testing of to(device) method
jorobledo Mar 21, 2025
32d6fb8
add to(device) method to MCMCPosterior
jorobledo Mar 21, 2025
4e62f42
add to(device) method to RejectionPosterior
jorobledo Mar 21, 2025
12f41a4
add to(device) to ScorePosterior and corresponding test
jorobledo Mar 21, 2025
2f9730b
add to(device) to VIPosterior and FIX bug, _prepare_potential is not …
jorobledo Mar 21, 2025
b992525
Merge branch 'sbi-dev:main' into posterior_device
jorobledo Mar 21, 2025
af60607
fix if conditions on attributes
jorobledo Mar 21, 2025
8a6a23a
fix if condition on posteriors that checks if default_x(x_o) has been…
jorobledo Mar 21, 2025
56ce814
overcome different types of devices on substraction in transformed_po…
jorobledo Mar 21, 2025
98234e3
correct idxs that must be on cpu
jorobledo Mar 21, 2025
38cd395
add sample and log_probs test
jorobledo Mar 21, 2025
5f5f5a0
Merge branch 'posterior_device' of github.com:jorobledo/sbi into post…
jorobledo Mar 21, 2025
fa09e15
put potentials and inits in cpu to calculate map
jorobledo Mar 21, 2025
9451635
remove comments
jorobledo Mar 21, 2025
7ee3ece
revert changes on map
jorobledo Mar 21, 2025
92e1293
fix typing for all modified functions and classes
jorobledo Mar 26, 2025
cae1580
Merge branch 'main' of github.com:sbi-dev/sbi into posterior_device
jorobledo Mar 26, 2025
a6c913e
fix naming from score to vector_field
jorobledo Mar 26, 2025
e120a92
fix type checks and score to vector_field naming
jorobledo Mar 26, 2025
4c4d3d9
Make sure all tensors are on the same device on VectorFieldPosterior
jorobledo Mar 27, 2025
72bf7f2
include test for all VectorFieldInference samples also on device
jorobledo Mar 27, 2025
33926de
fix moving time tensor to device, and change of inference_method name
jorobledo Mar 27, 2025
412a417
modify assert message
jorobledo Mar 27, 2025
a53711b
Merge branch 'sbi-dev:main' into posterior_device
jorobledo Mar 27, 2025
f8e4497
add device attribute to VectorFieldPosterior in initalization
jorobledo Mar 27, 2025
d693840
Merge branch 'posterior_device' of github.com:jorobledo/sbi into post…
jorobledo Mar 27, 2025
8878a81
fix typo
jorobledo Mar 27, 2025
4c3988d
Apply suggestions from code review
jorobledo Mar 27, 2025
fa9ee81
fix typing
jorobledo Mar 27, 2025
b7b9ed3
fix mnpe device handling
janfb Mar 27, 2025
c3ad3df
fix marginal trainer device handling
janfb Mar 27, 2025
a15d58b
fix gradient ascent device handling; refactor tests
janfb Mar 27, 2025
d84293d
add missinf to(device) to OneDimPriorWrapper
jorobledo Mar 27, 2025
5848e5a
fix failing test on notImplementedError CustomPriorWrapper
jorobledo Mar 27, 2025
f2642b5
fix user utils tests
janfb Mar 27, 2025
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/base_posterior.py
@@ -33,7 +33,7 @@ def __init__(
self,
potential_fn: Union[BasePotential, CustomPotential],
theta_transform: Optional[TorchTransform] = None,
device: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
):
"""
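The only change in this file widens the type accepted by `device`. A minimal sketch of the equivalence this relies on (plain PyTorch, nothing sbi-specific):

import torch

# A device given as a string and one given as a torch.device compare
# equal once normalized, so Optional[Union[str, torch.device]] lets
# callers pass either form and downstream code can treat them uniformly.
assert torch.device("cuda:0") == torch.device("cuda", 0)
assert torch.device(torch.device("cpu")) == torch.device("cpu")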
44 changes: 43 additions & 1 deletion sbi/inference/posteriors/direct_posterior.py
@@ -46,7 +46,7 @@
posterior_estimator: ConditionalDensityEstimator,
prior: Distribution,
max_sampling_batch_size: int = 10_000,
device: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
enable_transform: bool = True,
):
@@ -67,6 +67,8 @@
# builds it itself. The `potential_fn` and `theta_transform` are used only for
# obtaining the MAP.
check_prior(prior)
self.enable_transform = enable_transform
self.x_shape = x_shape
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator,
prior,
@@ -81,6 +83,7 @@
x_shape=x_shape,
)

self.device = device
self.prior = prior
self.posterior_estimator = posterior_estimator

@@ -90,6 +93,45 @@
self._purpose = """It samples the posterior network and rejects samples that
lie outside of the prior bounds."""

def to(self, device: Union[str, torch.device]) -> None:
"""Move posterior_estimator, prior and x_o to device.

Changes the device attribute, reinstantiates the posterior,
and resets the default x.

Args:
device: Device to move the posterior to.
"""
self.device = device
if hasattr(self.prior, "to"):
self.prior.to(device) # type: ignore
else:
raise ValueError("""Prior has no attribute to(device).""")
if hasattr(self.posterior_estimator, "to"):
self.posterior_estimator.to(device)
else:
raise ValueError("""Posterior estimator has no attribute to(device).""")

Check warning on line 113 in sbi/inference/posteriors/direct_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/direct_posterior.py#L113

Added line #L113 was not covered by tests

potential_fn, theta_transform = posterior_estimator_based_potential(
self.posterior_estimator,
self.prior,
x_o=None,
enable_transform=self.enable_transform,
)
x_o = None
if hasattr(self, "_x") and (self._x is not None):
x_o = self._x.to(device)

super().__init__(
potential_fn=potential_fn,
theta_transform=theta_transform,
device=device,
x_shape=self.x_shape,
)
# super().__init__ erases self._x, so we need to set it again
if x_o is not None:
self.set_default_x(x_o)

def sample(
self,
sample_shape: Shape = torch.Size(),
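A hypothetical end-to-end sketch of the new method on a `DirectPosterior`. The toy simulator, the epoch budget, and the observation are made up for illustration, and it assumes the standard `NPE` training API plus a prior wrapper that exposes `to()` as required by the guarded checks above:

import torch
from sbi.inference import NPE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))
theta = prior.sample((500,))
x = theta + 0.1 * torch.randn_like(theta)  # toy simulator

inference = NPE(prior=prior)
inference.append_simulations(theta, x).train(max_num_epochs=5)
posterior = inference.build_posterior()  # a DirectPosterior

posterior.set_default_x(x[:1])
device = "cuda:0" if torch.cuda.is_available() else "cpu"
posterior.to(device)  # moves estimator, prior, and the default x
samples = posterior.sample((10,))
assert samples.device == torch.device(device)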
54 changes: 48 additions & 6 deletions sbi/inference/posteriors/ensemble_posterior.py
@@ -11,7 +11,7 @@
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils.sbiutils import gradient_ascent
from sbi.utils.sbiutils import gradient_ascent, mcmc_transform
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.user_input_checks import process_x

@@ -53,13 +53,15 @@
theta_transform: If passed, this transformation will be applied during the
optimization performed when obtaining the MAP. It does not affect the
`.sample()` and `.log_prob()` methods.
device: Device to host the posterior distribution.
"""

def __init__(
self,
posteriors: List,
weights: Optional[Union[List[float], Tensor]] = None,
theta_transform: Optional[TorchTransform] = None,
device: Optional[Union[str, torch.device]] = None,
):
r"""
Args:
@@ -77,11 +79,34 @@
self.theta_transform = theta_transform
# Take first prior as reference
self.prior = posteriors[0].potential_fn.prior
self.device = device

device = self.ensure_same_device(posteriors)
if self.device is None:
self.device = self.ensure_same_device(posteriors)

self._build_potential_fns()

def to(self, device: Union[str, torch.device]) -> None:
"""Moves each posterior to device.

Prior and weights are also moved to
the specified device.

Args:
device: The device to move the ensemble posterior to.
"""
self.device = device
self._device = device
for i in range(len(self.posteriors)):
self.posteriors[i].to(device)
self.prior.to(device)
self.theta_transform = mcmc_transform(self.prior, device=device)
self._weights = self._weights.to(device)  # Tensor.to() is not in-place
self._build_potential_fns()

def _build_potential_fns(self):
potential_fns = []
for posterior in posteriors:
for posterior in self.posteriors:
potential = posterior.potential_fn
potential_fns.append(potential)
# make sure all prior are the same
@@ -94,8 +119,8 @@

super().__init__(
potential_fn=potential_fn,
theta_transform=theta_transform,
device=device,
theta_transform=self.theta_transform,
device=self.device,
)

def ensure_same_device(self, posteriors: List) -> str:
@@ -405,7 +430,7 @@
weights: Tensor,
prior: Distribution,
x_o: Optional[Tensor],
device: str = "cpu",
device: Union[str, torch.device] = "cpu",
):
r"""
Args:
@@ -419,6 +444,23 @@
self.potential_fns = potential_fns
super().__init__(prior, x_o, device)

def to(self, device: Union[str, torch.device]) -> None:
"""
Moves the ensemble potentials, the prior, the weights and x_o to

the specified device.

Args:
device: The device to move the ensemble potential to.
"""
self.device = device
for i in range(len(self.potential_fns)):
self.potential_fns[i].to(device)
self._weights = self._weights.to(device)
self.prior.to(device) # type: ignore
if self._x_o is not None:
self._x_o = self._x_o.to(device)

def allow_iid_x(self) -> bool:
# in case there is different kinds of posteriors, this will produce an error
# in `set_x()`
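One subtlety both `to()` implementations in this file must respect: `Tensor.to()` is out-of-place, which is why the weights tensor has to be reassigned, while `nn.Module.to()` mutates in place. A plain-PyTorch sketch:

import torch
from torch import nn

w = torch.ones(3)
w.to(torch.float64)  # returns a new tensor; w itself is unchanged
assert w.dtype == torch.float32
w = w.to(torch.float64)  # reassignment is required for plain tensors
assert w.dtype == torch.float64

net = nn.Linear(2, 2)
assert net.to("cpu") is net  # modules move in place and return self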
38 changes: 34 additions & 4 deletions sbi/inference/posteriors/importance_posterior.py
@@ -11,6 +11,7 @@
from sbi.samplers.importance.importance_sampling import importance_sample
from sbi.samplers.importance.sir import sampling_importance_resampling
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils.sbiutils import mcmc_transform
from sbi.utils.torchutils import ensure_theta_batched


@@ -32,7 +33,7 @@
method: str = "sir",
oversampling_factor: int = 32,
max_sampling_batch_size: int = 10_000,
device: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
):
"""
Expand Down Expand Up @@ -65,6 +66,7 @@
self.proposal = proposal
self._normalization_constant = None
self.method = method
self.theta_transform = theta_transform

self.oversampling_factor = oversampling_factor
self.max_sampling_batch_size = max_sampling_batch_size
@@ -74,6 +76,34 @@
"posterior and can evaluate the _unnormalized_ posterior density with "
".log_prob()."
)
self.x_shape = x_shape

def to(self, device: Union[str, torch.device]) -> None:
"""
Move the potential, the proposal and x_o to a new device.

It also reinstantiates the posterior on the new device.

Args:
device: Device to move the posterior to.
"""
self.device = device
self.potential_fn.to(device) # type: ignore
self.proposal.to(device)
x_o = None
if hasattr(self, "_x") and (self._x is not None):
x_o = self._x.to(device)

self.theta_transform = mcmc_transform(self.proposal, device=device)
super().__init__(
self.potential_fn,
theta_transform=self.theta_transform,
device=device,
x_shape=self.x_shape,
)
# super().__init__ erases self._x, so we need to set it again
if x_o is not None:
self.set_default_x(x_o)

def log_prob(
self,
@@ -211,9 +241,9 @@
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for ImportanceSamplingPosterior. "
"Alternatively you can use `sample` in a loop "
"[posterior.sample(theta, x_o) for x_o in x]."
"Batched sampling is not implemented for ImportanceSamplingPosterior. \
Alternatively you can use `sample` in a loop \
[posterior.sample(theta, x_o) for x_o in x]."
)

def _importance_sample(
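Several of these `to()` methods share one shape: stash the default `x_o`, re-run `__init__` on the new device (which erases `self._x`), then restore it. A stripped-down sketch of that pattern with stand-in classes (no sbi internals involved):

import torch

class Base:
    def __init__(self, device):
        self.device = device
        self._x = None  # __init__ erases any previously set default x

    def set_default_x(self, x):
        self._x = x.to(self.device)

class Posterior(Base):
    def to(self, device):
        # Stash the default x before re-initializing, which would erase it.
        x_o = self._x.to(device) if self._x is not None else None
        super().__init__(device)
        if x_o is not None:
            self.set_default_x(x_o)

p = Posterior("cpu")
p.set_default_x(torch.zeros(1, 3))
p.to("cpu")
assert p._x is not None  # the default x survives the move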
37 changes: 33 additions & 4 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -59,7 +59,7 @@
init_strategy_num_candidates: Optional[int] = None,
num_workers: int = 1,
mp_context: str = "spawn",
device: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
):
"""
@@ -136,6 +136,7 @@
self._posterior_sampler = None
# Hardcode parameter name to reduce clutter kwargs.
self.param_name = "theta"
self.x_shape = x_shape

if init_strategy_num_candidates is not None:
warn(
@@ -155,6 +156,34 @@
"can evaluate the _unnormalized_ posterior density with .log_prob()."
)

def to(self, device: Union[str, torch.device]) -> None:
"""Moves potential_fn, proposal, x_o and theta_transform to the

speficied device. Reinstanciates the posterior and resets the default x_o.
Comment thread
jorobledo marked this conversation as resolved.
Outdated

Args:
device: Device to move the posterior to.
"""
self.device = device
self.potential_fn.to(device) # type: ignore
self.proposal.to(device)
x_o = None
if hasattr(self, "_x") and (self._x is not None):
x_o = self._x.to(device)

self.theta_transform = mcmc_transform(self.proposal, device=device)

super().__init__(
self.potential_fn,
theta_transform=self.theta_transform,
device=device,
x_shape=self.x_shape,
)
# super().__init__ erases self._x, so we need to set it again
if x_o is not None:
self.set_default_x(x_o)
self.potential_ = self._prepare_potential(self.method)

@property
def mcmc_method(self) -> str:
"""Returns MCMC method."""
@@ -266,9 +295,9 @@
)
if init_strategy_num_candidates is not None:
warn(
"Passing `init_strategy_num_candidates` is deprecated as of sbi "
"v0.19.0. Instead, use e.g., "
f"`init_strategy_parameters={'num_candidate_samples': 1000}`",
f"Passing `init_strategy_num_candidates` is deprecated as of sbi \
v0.19.0. Instead, use e.g., \
`init_strategy_parameters={'num_candidate_samples': 1000}`",
stacklevel=2,
)
self.init_strategy_parameters["num_candidate_samples"] = (
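Note that `theta_transform` is rebuilt with `mcmc_transform` on the target device rather than moved: the transform holds prior-dependent tensors and has no `to()` of its own. A minimal sketch, assuming the `mcmc_transform` and `BoxUniform` helpers imported in the diffs above:

import torch
from sbi.utils import BoxUniform, mcmc_transform

prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))
# Rebuild the transform on the target device instead of moving the old one:
theta_transform = mcmc_transform(prior, device="cpu")

theta = prior.sample((5,))
unconstrained = theta_transform(theta)  # map to unbounded space for MCMC
roundtrip = theta_transform.inv(unconstrained)
assert torch.allclose(theta, roundtrip, atol=1e-5)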
33 changes: 31 additions & 2 deletions sbi/inference/posteriors/rejection_posterior.py
@@ -12,6 +12,7 @@
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.rejection.rejection import rejection_sample
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.torchutils import ensure_theta_batched


@@ -24,14 +25,14 @@

def __init__(
self,
potential_fn: Union[Callable, BasePotential],
potential_fn: Union[Callable, BasePotential], # type: ignore
proposal: Any,
theta_transform: Optional[TorchTransform] = None,
max_sampling_batch_size: int = 10_000,
num_samples_to_find_max: int = 10_000,
num_iter_to_find_max: int = 100,
m: float = 1.2,
device: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
):
"""
Expand Down Expand Up @@ -64,12 +65,40 @@
self.num_samples_to_find_max = num_samples_to_find_max
self.num_iter_to_find_max = num_iter_to_find_max
self.m = m
self.x_shape = x_shape

self._purpose = (
"It provides rejection sampling to .sample() from the posterior and "
"can evaluate the _unnormalized_ posterior density with .log_prob()."
)

def to(self, device: Union[str, torch.device]) -> None:
"""
Move potential function, proposal and x_o to the device.

This method reinstantiates the posterior and resets the default x_o.

Args:
device: The device to move the posterior to.
"""
self.device = device
self.potential_fn.to(device) # type: ignore
self.proposal.to(device)
x_o = None
if hasattr(self, "_x") and (self._x is not None):
x_o = self._x.to(device)

self.theta_transform = mcmc_transform(self.proposal, device=device)
super().__init__(
self.potential_fn,
theta_transform=self.theta_transform,
device=device,
x_shape=self.x_shape,
)
# super().__init__ erases self._x, so we need to set it again
if x_o is not None:
self.set_default_x(x_o)

def log_prob(
self, theta: Tensor, x: Optional[Tensor] = None, track_gradients: bool = False
) -> Tensor:
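The PR also adds tests for these moves (see the commit "add potentials and posterior testing of to(device) method"). A hypothetical sketch of the kind of round-trip check involved; `make_posterior()` is a made-up stand-in for any trained posterior from the classes patched above:

import torch

def check_to_device(posterior, device) -> None:
    # Hypothetical round-trip: move, sample, verify placement, move back.
    posterior.to(device)
    samples = posterior.sample((2,))
    assert samples.device == torch.device(device)
    posterior.to("cpu")  # the default x should survive both moves

# Usage sketch: check_to_device(make_posterior(), "cuda:0")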