Skip to content

Commit 262d7a9

Browse files
authored
Merge pull request #212 from danielward27/mostly_docs
Docs, version bump, and flip Sandwich args
2 parents d597485 + 3e8c33a commit 262d7a9

11 files changed

Lines changed: 70 additions & 67 deletions

File tree

flowjax/bijections/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .exp import Exp
1010
from .jax_transforms import Scan, Vmap
1111
from .masked_autoregressive import MaskedAutoregressive
12+
from .orthogonal import DiscreteCosine, Householder
1213
from .planar import Planar
1314
from .power import Power
1415
from .rational_quadratic_spline import RationalQuadraticSpline
@@ -26,8 +27,6 @@
2627
Reshape,
2728
Sandwich,
2829
)
29-
from .utils import EmbedCondition, Flip, Identity, Invert, Permute, Reshape, Sandwich
30-
from .orthogonal import Householder, DiscreteCosine
3130

3231
__all__ = [
3332
"AdditiveCondition",

flowjax/bijections/bijection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from equinox import AbstractVar
1616
from jaxtyping import Array, ArrayLike
1717
from paramax import unwrap
18-
import jax
1918

2019
from flowjax.utils import _get_ufunc_signature, arraylike_to_array
2120

flowjax/bijections/chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from collections.abc import Sequence
44

5-
from paramax import AbstractUnwrappable, unwrap
65
import jax.numpy as jnp
76
from jax import Array
7+
from paramax import AbstractUnwrappable, unwrap
88

99
from flowjax.bijections.bijection import AbstractBijection
1010
from flowjax.utils import check_shapes_match, merge_cond_shapes

flowjax/bijections/coupling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import equinox as eqx
99
import jax.nn as jnn
1010
import jax.numpy as jnp
11-
import jax
1211
import paramax
1312
from jaxtyping import PRNGKeyArray
1413

flowjax/bijections/jax_transforms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections.abc import Callable
44

55
import equinox as eqx
6-
from jax import Array
76
import jax.numpy as jnp
87
from jax.lax import scan
98
from jax.tree_util import tree_leaves, tree_map

flowjax/bijections/orthogonal.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,54 @@
1-
from paramax import AbstractUnwrappable, Parameterize
2-
from flowjax.bijections.bijection import AbstractBijection
3-
from jax import Array
1+
"""Orthogonal transformations."""
2+
43
import jax.numpy as jnp
5-
import jax.nn as jnn
4+
from jax import Array
65
from jax.scipy import fft
6+
from jaxtyping import ArrayLike
7+
from paramax import AbstractUnwrappable, Parameterize
8+
9+
from flowjax.bijections.bijection import AbstractBijection
10+
from flowjax.utils import arraylike_to_array
711

812

913
class Householder(AbstractBijection):
10-
"""A Householder reflection bijection.
14+
"""A Householder reflection.
1115
12-
This bijection implements a Householder reflection, which is a linear
13-
transformation that reflects vectors across a hyperplane defined by a normal
16+
A linear transformation reflecting vectors across a hyperplane defined by a normal
1417
vector (params). The transformation is its own inverse and volume-preserving
15-
(determinant = ±1).
18+
(determinant = -1). Given a unit vector :math:`v`, the transformation is
19+
:math:`y = x - 2(x^T v)v`.
20+
21+
It is often desirable to stack multiple such transforms (e.g. up to the
22+
dimensionality of the data):
1623
17-
Given a unit vector v, the transformation is:
18-
x → x - 2(x·v)v
24+
.. doctest::
1925
20-
Attributes:
21-
shape: Shape of the input/output vectors
22-
cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
26+
>>> from flowjax.bijections import Householder, Scan
27+
>>> import jax.random as jr
28+
>>> import equinox as eqx
29+
>>> import jax.numpy as jnp
30+
31+
>>> dim = 5
32+
>>> keys = jr.split(jr.key(0), dim)
33+
>>> householder_stack = Scan(
34+
... eqx.filter_vmap(lambda key: Householder(jr.normal(key, dim)))(keys)
35+
... )
36+
37+
Args:
2338
params: Normal vector defining the reflection hyperplane. The vector is
2439
normalized in the transformation, so scaling params will have no effect
2540
on the bijection.
2641
"""
42+
2743
shape: tuple[int, ...]
2844
unit_vec: Array | AbstractUnwrappable
2945
cond_shape = None
3046

31-
def __init__(self, params: Array):
32-
self.shape = (params.shape[-1],)
47+
def __init__(self, params: ArrayLike):
48+
params = arraylike_to_array(params)
49+
if params.ndim != 1:
50+
raise ValueError("params must be a vector.")
51+
self.shape = params.shape
3352
self.unit_vec = Parameterize(lambda x: x / jnp.linalg.norm(x), params)
3453

3554
def _householder(self, x: Array) -> Array:
@@ -47,11 +66,9 @@ class DiscreteCosine(AbstractBijection):
4766
4867
This bijection applies the DCT or its inverse along a specified axis.
4968
50-
Attributes:
51-
shape: Shape of the input/output arrays
52-
cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
53-
axis: Axis along which to apply the DCT
54-
norm: Normalization method, fixed to 'ortho' to ensure bijectivity
69+
Args:
70+
shape: Shape of the input/output arrays.
71+
axis: Axis along which to apply the DCT.
5572
"""
5673

5774
shape: tuple[int, ...]

flowjax/bijections/softplus.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@
22

33
from typing import ClassVar
44

5-
import jax
65
import jax.numpy as jnp
7-
from jax.nn import softplus, soft_sign
8-
from jaxtyping import Array, ArrayLike
9-
from paramax import AbstractUnwrappable, Parameterize, unwrap
10-
from paramax.utils import inv_softplus
6+
from jax.nn import softplus
117

128
from flowjax.bijections.bijection import AbstractBijection
13-
from flowjax.utils import arraylike_to_array
9+
1410

1511
class SoftPlus(AbstractBijection):
1612
r"""Transforms to positive domain using softplus :math:`y = \log(1 + \exp(x))`."""

flowjax/bijections/utils.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def __init__(self, permutation: Int[Array | np.ndarray, "..."]):
8080
)
8181

8282
def transform_and_log_det(self, x, condition=None):
83-
return x[self.permutation], jnp.array(0.0)
83+
return x[self.permutation], jnp.zeros(())
8484

8585
def inverse_and_log_det(self, y, condition=None):
86-
return y[self.inverse_permutation], jnp.array(0.0)
86+
return y[self.inverse_permutation], jnp.zeros(())
8787

8888

8989
class Flip(AbstractBijection):
@@ -308,36 +308,34 @@ def inverse_and_log_det(self, y, condition=None):
308308

309309

310310
class Sandwich(AbstractBijection):
311-
"""A bijection that composes bijections in a nested structure: g⁻¹ ∘ f ∘ g.
312-
313-
The Sandwich bijection creates a new transformation by "sandwiching" one
314-
bijection between the forward and inverse applications of another. Given
315-
bijections f and g, it computes:
316-
Forward: x → g⁻¹(f(g(x)))
317-
Inverse: y → g⁻¹(f⁻¹(g(y)))
318-
319-
This composition pattern is useful for:
320-
- Creating symmetries in the transformation
321-
- Applying a transformation in a different coordinate system
322-
- Building more complex bijections from simpler ones
323-
324-
Attributes:
325-
shape: Shape of the input/output arrays
326-
cond_shape: Shape of conditional inputs
327-
outer: Transformation g applied first and inverted last
328-
inner: Transformation f applied in the middle
311+
r"""Composes bijections in a nested structure: :math:`g^{-1} \circ f \circ g`.
312+
313+
Creates a new transformation by "sandwiching" one bijection between the forward and
314+
inverse applications of another. Given bijections :math:`f` and :math:`g`, it
315+
computes
316+
317+
- Forward: :math:`y = g^{-1}(f(g(x)))`
318+
- Inverse: :math:`x = g^{-1}(f^{-1}(g(y)))`
319+
320+
This can be used for e.g. creating symmetries in the transformation or to apply a
321+
transformation in a different coordinate system.
322+
323+
Args:
324+
inner: The inner transform.
325+
outer: The outer transform.
329326
"""
327+
330328
shape: tuple[int, ...]
331329
cond_shape: tuple[int, ...] | None
332-
outer: AbstractBijection
333330
inner: AbstractBijection
331+
outer: AbstractBijection
334332

335-
def __init__(self, outer: AbstractBijection, inner: AbstractBijection):
333+
def __init__(self, inner: AbstractBijection, outer: AbstractBijection):
336334
check_shapes_match([outer.shape, inner.shape])
337335
self.cond_shape = merge_cond_shapes([outer.cond_shape, inner.cond_shape])
338336
self.shape = inner.shape
339-
self.outer = outer
340337
self.inner = inner
338+
self.outer = outer
341339

342340
def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]:
343341
chain = Chain([self.outer, self.inner, Invert(self.outer)])

flowjax/flows.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Permute,
3333
Planar,
3434
RationalQuadraticSpline,
35+
Sandwich,
3536
Scan,
3637
TriangularAffine,
3738
Vmap,
@@ -290,10 +291,9 @@ def triangular_spline_flow(
290291
) -> Transformed:
291292
"""Triangular spline flow.
292293
293-
A single layer consists where each layer consists of a triangular affine
294-
transformation with weight normalisation, and an elementwise rational quadratic
295-
spline. Tanh is used to constrain to the input to [-1, 1] before spline
296-
transformations.
294+
Each layer consists of a triangular affine transformation with weight normalisation,
295+
and an elementwise rational quadratic spline. Tanh is used to constrain to the input
296+
to [-1, 1] before spline transformations.
297297
298298
Args:
299299
key: Jax random key.
@@ -325,9 +325,7 @@ def make_layer(key):
325325
lambda t: t.triangular, tri_aff, replace_fn=WeightNormalization
326326
)
327327
bijections = [
328-
LeakyTanh(tanh_max_val, (dim,)),
329-
get_splines(),
330-
Invert(LeakyTanh(tanh_max_val, (dim,))),
328+
Sandwich(get_splines(), LeakyTanh(tanh_max_val, (dim,))),
331329
tri_aff,
332330
]
333331

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ license = { file = "LICENSE" }
2323
name = "flowjax"
2424
readme = "README.md"
2525
requires-python = ">=3.10"
26-
version = "17.0.2"
26+
version = "17.1.0"
2727

2828
[project.urls]
2929
repository = "https://github.com/danielward27/flowjax"
@@ -34,7 +34,7 @@ dev = [
3434
"pytest",
3535
"beartype",
3636
"ruff",
37-
"sphinx",
37+
"sphinx <8.2", # TODO due to https://github.com/tox-dev/sphinx-autodoc-typehints/issues/523
3838
"sphinx-book-theme",
3939
"sphinx-copybutton",
4040
"sphinx-autodoc-typehints",

0 commit comments

Comments
 (0)