1- from paramax import AbstractUnwrappable , Parameterize
2- from flowjax .bijections .bijection import AbstractBijection
3- from jax import Array
1+ """Orthogonal transformations."""
2+
43import jax .numpy as jnp
5- import jax . nn as jnn
4+ from jax import Array
65from jax .scipy import fft
6+ from jaxtyping import ArrayLike
7+ from paramax import AbstractUnwrappable , Parameterize
8+
9+ from flowjax .bijections .bijection import AbstractBijection
10+ from flowjax .utils import arraylike_to_array
711
812
913class Householder (AbstractBijection ):
10- """A Householder reflection bijection .
14+ """A Householder reflection.
1115
12- This bijection implements a Householder reflection, which is a linear
13- transformation that reflects vectors across a hyperplane defined by a normal
16+ A linear transformation reflecting vectors across a hyperplane defined by a normal
1417 vector (params). The transformation is its own inverse and volume-preserving
15- (determinant = ±1).
18+ (determinant = -1). Given a unit vector :math:`v`, the transformation is
19+ :math:`y = x - 2(x^T v)v`.
20+
21+ It is often desirable to stack multiple such transforms (e.g. up to the
22+ dimensionality of the data):
1623
17- Given a unit vector v, the transformation is:
18- x → x - 2(x·v)v
24+ .. doctest::
1925
20- Attributes:
21- shape: Shape of the input/output vectors
22- cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
26+ >>> from flowjax.bijections import Householder, Scan
27+ >>> import jax.random as jr
28+ >>> import equinox as eqx
29+ >>> import jax.numpy as jnp
30+
31+ >>> dim = 5
32+ >>> keys = jr.split(jr.key(0), dim)
33+ >>> householder_stack = Scan(
34+ ... eqx.filter_vmap(lambda key: Householder(jr.normal(key, dim)))(keys)
35+ ... )
36+
37+ Args:
2338 params: Normal vector defining the reflection hyperplane. The vector is
2439 normalized in the transformation, so scaling params will have no effect
2540 on the bijection.
2641 """
42+
2743 shape : tuple [int , ...]
2844 unit_vec : Array | AbstractUnwrappable
2945 cond_shape = None
3046
31- def __init__ (self , params : Array ):
32- self .shape = (params .shape [- 1 ],)
47+ def __init__ (self , params : ArrayLike ):
48+ params = arraylike_to_array (params )
49+ if params .ndim != 1 :
50+ raise ValueError ("params must be a vector." )
51+ self .shape = params .shape
3352 self .unit_vec = Parameterize (lambda x : x / jnp .linalg .norm (x ), params )
3453
3554 def _householder (self , x : Array ) -> Array :
@@ -47,11 +66,9 @@ class DiscreteCosine(AbstractBijection):
4766
4867 This bijection applies the DCT or its inverse along a specified axis.
4968
50- Attributes:
51- shape: Shape of the input/output arrays
52- cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
53- axis: Axis along which to apply the DCT
54- norm: Normalization method, fixed to 'ortho' to ensure bijectivity
69+ Args:
70+ shape: Shape of the input/output arrays.
71+ axis: Axis along which to apply the DCT.
5572 """
5673
5774 shape : tuple [int , ...]
0 commit comments