Skip to content

Commit fb5124e

Browse files
jorobledo and janfb
authored
feat: implement posterior.to(device) (#1527)
* add to(device) to BasePotential * add to(device) to LikelihoodBasedPotential * add to(device) to PosteriorBasedPotential * remove Optional for prior * add to(device) to RatioBasedPotential * add to(device) to PosteriorScoreBasedPotential * add to(device) to IIDScoreFunction * add to(device) to DirectPosterior * add to(device) to EnsemblePotential and EnsemblePosterior * add to(device) to ImportanceSamplingPosterior * add potentials and posterior testing of to(device) method * add to(device) method to MCMCPosterior * add to(device) method to RejectionPosterior * add to(device) to ScorePosterior and corresponding test * add to(device) to VIPosterior and FIX bug, _prepare_potential is not in this class. * fix if conditions on attributes * fix if condition on posteriors that checks if default_x(x_o) has been stored, to keep it * overcome different types of devices on subtraction in transformed_potential * correct idxs that must be on cpu * add sample and log_probs test * put potentials and inits in cpu to calculate map * remove comments * revert changes on map * fix typing for all modified functions and classes * fix naming from score to vector_field * fix type checks and score to vector_field naming * Make sure all tensors are on the same device on VectorFieldPosterior * include test for all VectorFieldInference samples also on device * fix moving time tensor to device, and change of inference_method name * modify assert message * add device attribute to VectorFieldPosterior in initialization * fix typo * Apply suggestions from code review Co-authored-by: Jan <janfb@users.noreply.github.com> * fix typing * fix mnpe device handling * fix marginal trainer device handling * fix gradient ascent device handling; refactor tests * add missing to(device) to OneDimPriorWrapper * fix failing test on notImplementedError CustomPriorWrapper * fix user utils tests --------- Co-authored-by: Jan <janfb@users.noreply.github.com> Co-authored-by: Jan Boelts <jan.boelts@mailbox.org>
1 parent 83ffb3d commit fb5124e

26 files changed

Lines changed: 668 additions & 121 deletions

sbi/inference/posteriors/base_posterior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
self,
3434
potential_fn: Union[BasePotential, CustomPotential],
3535
theta_transform: Optional[TorchTransform] = None,
36-
device: Optional[str] = None,
36+
device: Optional[Union[str, torch.device]] = None,
3737
x_shape: Optional[torch.Size] = None,
3838
):
3939
"""

sbi/inference/posteriors/direct_posterior.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
posterior_estimator: ConditionalDensityEstimator,
4747
prior: Distribution,
4848
max_sampling_batch_size: int = 10_000,
49-
device: Optional[str] = None,
49+
device: Optional[Union[str, torch.device]] = None,
5050
x_shape: Optional[torch.Size] = None,
5151
enable_transform: bool = True,
5252
):
@@ -67,6 +67,8 @@ def __init__(
6767
# builds it itself. The `potential_fn` and `theta_transform` are used only for
6868
# obtaining the MAP.
6969
check_prior(prior)
70+
self.enable_transform = enable_transform
71+
self.x_shape = x_shape
7072
potential_fn, theta_transform = posterior_estimator_based_potential(
7173
posterior_estimator,
7274
prior,
@@ -81,6 +83,7 @@ def __init__(
8183
x_shape=x_shape,
8284
)
8385

86+
self.device = device
8487
self.prior = prior
8588
self.posterior_estimator = posterior_estimator
8689

@@ -90,6 +93,45 @@ def __init__(
9093
self._purpose = """It samples the posterior network and rejects samples that
9194
lie outside of the prior bounds."""
9295

96+
def to(self, device: Union[str, torch.device]) -> None:
97+
"""Move posterior_estimator, prior and x_o to device.
98+
99+
Changes the device attribute, reinstantiates the
100+
posterior, and resets the default x.
101+
102+
Args:
103+
device: device where to move the posterior to.
104+
"""
105+
self.device = device
106+
if hasattr(self.prior, "to"):
107+
self.prior.to(device) # type: ignore
108+
else:
109+
raise ValueError("""Prior has no attribute to(device).""")
110+
if hasattr(self.posterior_estimator, "to"):
111+
self.posterior_estimator.to(device)
112+
else:
113+
raise ValueError("""Posterior estimator has no attribute to(device).""")
114+
115+
potential_fn, theta_transform = posterior_estimator_based_potential(
116+
self.posterior_estimator,
117+
self.prior,
118+
x_o=None,
119+
enable_transform=self.enable_transform,
120+
)
121+
x_o = None
122+
if hasattr(self, "_x") and (self._x is not None):
123+
x_o = self._x.to(device)
124+
125+
super().__init__(
126+
potential_fn=potential_fn,
127+
theta_transform=theta_transform,
128+
device=device,
129+
x_shape=self.x_shape,
130+
)
131+
# super().__init__ erases the self._x, so we need to set it again
132+
if x_o is not None:
133+
self.set_default_x(x_o)
134+
93135
def sample(
94136
self,
95137
sample_shape: Shape = torch.Size(),

sbi/inference/posteriors/ensemble_posterior.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sbi.inference.potentials.base_potential import BasePotential
1212
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
1313
from sbi.sbi_types import Shape, TorchTransform
14-
from sbi.utils.sbiutils import gradient_ascent
14+
from sbi.utils.sbiutils import gradient_ascent, mcmc_transform
1515
from sbi.utils.torchutils import ensure_theta_batched
1616
from sbi.utils.user_input_checks import process_x
1717

@@ -53,13 +53,15 @@ class EnsemblePosterior(NeuralPosterior):
5353
theta_transform: If passed, this transformation will be applied during the
5454
optimization performed when obtaining the map. It does not affect the
5555
`.sample()` and `.log_prob()` methods.
56+
device: device to host the posterior distribution.
5657
"""
5758

5859
def __init__(
5960
self,
6061
posteriors: List,
6162
weights: Optional[Union[List[float], Tensor]] = None,
6263
theta_transform: Optional[TorchTransform] = None,
64+
device: Optional[Union[str, torch.device]] = None,
6365
):
6466
r"""
6567
Args:
@@ -77,11 +79,34 @@ def __init__(
7779
self.theta_transform = theta_transform
7880
# Take first prior as reference
7981
self.prior = posteriors[0].potential_fn.prior
82+
self.device = device
8083

81-
device = self.ensure_same_device(posteriors)
84+
if self.device is None:
85+
self.device = self.ensure_same_device(posteriors)
8286

87+
self._build_potential_fns()
88+
89+
def to(self, device: Union[str, torch.device]) -> None:
90+
"""Moves each posterior to device.
91+
92+
Prior and weights are also moved to
93+
the specified device.
94+
95+
Args:
96+
device: The device to move the ensemble posterior to.
97+
"""
98+
self.device = device
99+
self._device = device
100+
for i in range(len(self.posteriors)):
101+
self.posteriors[i].to(device)
102+
self.prior.to(device)
103+
self.theta_transform = mcmc_transform(self.prior, device=device)
104+
self._weights.to(device)
105+
self._build_potential_fns()
106+
107+
def _build_potential_fns(self):
83108
potential_fns = []
84-
for posterior in posteriors:
109+
for posterior in self.posteriors:
85110
potential = posterior.potential_fn
86111
potential_fns.append(potential)
87112
# make sure all prior are the same
@@ -94,8 +119,8 @@ def __init__(
94119

95120
super().__init__(
96121
potential_fn=potential_fn,
97-
theta_transform=theta_transform,
98-
device=device,
122+
theta_transform=self.theta_transform,
123+
device=self.device,
99124
)
100125

101126
def ensure_same_device(self, posteriors: List) -> str:
@@ -405,7 +430,7 @@ def __init__(
405430
weights: Tensor,
406431
prior: Distribution,
407432
x_o: Optional[Tensor],
408-
device: str = "cpu",
433+
device: Union[str, torch.device] = "cpu",
409434
):
410435
r"""
411436
Args:
@@ -419,6 +444,23 @@ def __init__(
419444
self.potential_fns = potential_fns
420445
super().__init__(prior, x_o, device)
421446

447+
def to(self, device: Union[str, torch.device]) -> None:
448+
"""
449+
Moves the ensemble potentials, the prior, the weights and x_o to
450+
451+
the specified device.
452+
453+
Args:
454+
device: The device to move the ensemble potential to.
455+
"""
456+
self.device = device
457+
for i in range(len(self.potential_fns)):
458+
self.potential_fns[i].to(device)
459+
self._weights = self._weights.to(device)
460+
self.prior.to(device) # type: ignore
461+
if self._x_o is not None:
462+
self._x_o = self._x_o.to(device)
463+
422464
def allow_iid_x(self) -> bool:
423465
# in case there is different kinds of posteriors, this will produce an error
424466
# in `set_x()`

sbi/inference/posteriors/importance_posterior.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sbi.samplers.importance.importance_sampling import importance_sample
1212
from sbi.samplers.importance.sir import sampling_importance_resampling
1313
from sbi.sbi_types import Shape, TorchTransform
14+
from sbi.utils.sbiutils import mcmc_transform
1415
from sbi.utils.torchutils import ensure_theta_batched
1516

1617

@@ -32,7 +33,7 @@ def __init__(
3233
method: str = "sir",
3334
oversampling_factor: int = 32,
3435
max_sampling_batch_size: int = 10_000,
35-
device: Optional[str] = None,
36+
device: Optional[Union[str, torch.device]] = None,
3637
x_shape: Optional[torch.Size] = None,
3738
):
3839
"""
@@ -65,6 +66,7 @@ def __init__(
6566
self.proposal = proposal
6667
self._normalization_constant = None
6768
self.method = method
69+
self.theta_transform = theta_transform
6870

6971
self.oversampling_factor = oversampling_factor
7072
self.max_sampling_batch_size = max_sampling_batch_size
@@ -74,6 +76,34 @@ def __init__(
7476
"posterior and can evaluate the _unnormalized_ posterior density with "
7577
".log_prob()."
7678
)
79+
self.x_shape = x_shape
80+
81+
def to(self, device: Union[str, torch.device]) -> None:
82+
"""
83+
Move the potential, the proposal and x_o to a new device.
84+
85+
It also reinstantiates the posterior with the new device.
86+
87+
Args:
88+
device: Device on which to move the posterior to.
89+
"""
90+
self.device = device
91+
self.potential_fn.to(device) # type: ignore
92+
self.proposal.to(device)
93+
x_o = None
94+
if hasattr(self, "_x") and (self._x is not None):
95+
x_o = self._x.to(device)
96+
97+
self.theta_transform = mcmc_transform(self.proposal, device=device)
98+
super().__init__(
99+
self.potential_fn,
100+
theta_transform=self.theta_transform,
101+
device=device,
102+
x_shape=self.x_shape,
103+
)
104+
# super().__init__ erases the self._x, so we need to set it again
105+
if x_o is not None:
106+
self.set_default_x(x_o)
77107

78108
def log_prob(
79109
self,
@@ -211,9 +241,9 @@ def sample_batched(
211241
show_progress_bars: bool = True,
212242
) -> Tensor:
213243
raise NotImplementedError(
214-
"Batched sampling is not implemented for ImportanceSamplingPosterior. "
215-
"Alternatively you can use `sample` in a loop "
216-
"[posterior.sample(theta, x_o) for x_o in x]."
244+
"Batched sampling is not implemented for ImportanceSamplingPosterior. \
245+
Alternatively you can use `sample` in a loop \
246+
[posterior.sample(theta, x_o) for x_o in x]."
217247
)
218248

219249
def _importance_sample(

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959
init_strategy_num_candidates: Optional[int] = None,
6060
num_workers: int = 1,
6161
mp_context: str = "spawn",
62-
device: Optional[str] = None,
62+
device: Optional[Union[str, torch.device]] = None,
6363
x_shape: Optional[torch.Size] = None,
6464
):
6565
"""
@@ -136,6 +136,7 @@ def __init__(
136136
self._posterior_sampler = None
137137
# Hardcode parameter name to reduce clutter kwargs.
138138
self.param_name = "theta"
139+
self.x_shape = x_shape
139140

140141
if init_strategy_num_candidates is not None:
141142
warn(
@@ -155,6 +156,34 @@ def __init__(
155156
"can evaluate the _unnormalized_ posterior density with .log_prob()."
156157
)
157158

159+
def to(self, device: Union[str, torch.device]) -> None:
160+
"""Moves potential_fn, proposal, x_o and theta_transform to the
161+
162+
specified device. Reinstantiates the posterior and resets the default x_o.
163+
164+
Args:
165+
device: Device to move the posterior to.
166+
"""
167+
self.device = device
168+
self.potential_fn.to(device) # type: ignore
169+
self.proposal.to(device)
170+
x_o = None
171+
if hasattr(self, "_x") and (self._x is not None):
172+
x_o = self._x.to(device)
173+
174+
self.theta_transform = mcmc_transform(self.proposal, device=device)
175+
176+
super().__init__(
177+
self.potential_fn,
178+
theta_transform=self.theta_transform,
179+
device=device,
180+
x_shape=self.x_shape,
181+
)
182+
# super().__init__ erases the self._x, so we need to set it again
183+
if x_o is not None:
184+
self.set_default_x(x_o)
185+
self.potential_ = self._prepare_potential(self.method)
186+
158187
@property
159188
def mcmc_method(self) -> str:
160189
"""Returns MCMC method."""
@@ -266,9 +295,9 @@ def sample(
266295
)
267296
if init_strategy_num_candidates is not None:
268297
warn(
269-
"Passing `init_strategy_num_candidates` is deprecated as of sbi "
270-
"v0.19.0. Instead, use e.g., "
271-
f"`init_strategy_parameters={'num_candidate_samples': 1000}`",
298+
f"Passing `init_strategy_num_candidates` is deprecated as of sbi \
299+
v0.19.0. Instead, use e.g., \
300+
`init_strategy_parameters={'num_candidate_samples': 1000}`",
272301
stacklevel=2,
273302
)
274303
self.init_strategy_parameters["num_candidate_samples"] = (

0 commit comments

Comments
 (0)