Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@
from pytensor.ifelse import IfElse, ifelse
from pytensor.scalar import Switch
from pytensor.scalar import switch as scalar_switch
from pytensor.scalar.basic import GE, GT, LE, LT, Mul
from pytensor.tensor.basic import Join, MakeVector, switch
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
local_rv_size_lift,
Expand Down Expand Up @@ -80,7 +83,9 @@
measurable_ir_rewrites_db,
subtensor_ops,
)
from pymc.logprob.transforms import MeasurableTransform
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
filter_measurable_variables,
get_related_valued_nodes,
Expand Down Expand Up @@ -407,6 +412,80 @@ class MeasurableSwitchMixture(MeasurableElemwise):
measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch)


class MeasurableLeakyReLUSwitch(MeasurableElemwise):
Comment thread
ricardoV94 marked this conversation as resolved.
Outdated
"""A placeholder for leaky-ReLU graphs built via `switch(x > 0, x, a * x)`.

this is an invertible, piecewise-linear transform of a single continuous measurable variable.
"""

valid_scalar_types = (Switch,)


measurable_leaky_relu_switch = MeasurableLeakyReLUSwitch(scalar_switch)


def _is_x_positive_condition(cond: TensorVariable, x: TensorVariable) -> bool:
if cond.owner is None:
return False
if not isinstance(cond.owner.op, Elemwise):
return False
scalar_op = cond.owner.op.scalar_op
if not isinstance(scalar_op, GT | GE | LT | LE):
return False

left, right = cond.owner.inputs

def _is_zero(v: TensorVariable) -> bool:
try:
return pt.get_underlying_scalar_constant_value(v) == 0
except NotScalarConstantError:
return False

# x > 0 or x >= 0
if left is x and _is_zero(right) and isinstance(scalar_op, GT | GE):
return True
# 0 < x or 0 <= x
if right is x and _is_zero(left) and isinstance(scalar_op, LT | LE):
return True
return False


def _extract_leaky_relu_slope(
neg_branch: TensorVariable, x: TensorVariable
) -> TensorVariable | None:
"""Extract slope `a` from `neg_branch` assuming it represents `a * x`.

supports both plain `Elemwise(Mul)` and `MeasurableTransform` scale rewrites.
"""
if neg_branch is x:
return pt.constant(1.0)

if neg_branch.owner is None:
return None

# handle case where `a * x` was already rewritten into a measurable scale transform
if isinstance(neg_branch.owner.op, MeasurableTransform):
Comment thread
ricardoV94 marked this conversation as resolved.
Outdated
op = neg_branch.owner.op
if not isinstance(op.scalar_op, Mul):
return None
# MeasurableTransform takes (measurable_input, scale)
if len(neg_branch.owner.inputs) != 2:
return None
if neg_branch.owner.inputs[op.measurable_input_idx] is not x:
return None
scale = neg_branch.owner.inputs[1 - op.measurable_input_idx]
return cast(TensorVariable, scale)

# plain multiplication
if isinstance(neg_branch.owner.op, Elemwise) and isinstance(neg_branch.owner.op.scalar_op, Mul):
left, right = neg_branch.owner.inputs
if left is x:
return cast(TensorVariable, right)
if right is x:
return cast(TensorVariable, left)
return None


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
if isinstance(node.op, MeasurableOp):
Expand All @@ -431,6 +510,51 @@ def find_measurable_switch_mixture(fgraph, node):
return [measurable_switch_mixture(switch_cond, *components)]


@node_rewriter([switch])
def find_measurable_leaky_relu_switch(fgraph, node):
"""Detect `switch(x > 0, x, a * x)` and replace it by a measurable op.

This enables a change-of-variables logprob derivation instead of treating it as a mixture.
Comment thread
ricardoV94 marked this conversation as resolved.
Outdated
"""
if isinstance(node.op, MeasurableOp):
return None

cond, pos_branch, neg_branch = node.inputs

# we only mark the switch measurable once both branches are already measurable.
# so, the switch logprob can simply gate between branch logps (delegating inversion/Jacobian details to each branch).
if set(filter_measurable_variables([pos_branch, neg_branch])) != {pos_branch, neg_branch}:
return None

if not filter_measurable_variables([pos_branch]):
Comment thread
ricardoV94 marked this conversation as resolved.
Outdated
return None
x = cast(TensorVariable, pos_branch)

if x.type.dtype.startswith("int"):
return None

if x.type.broadcastable != node.outputs[0].type.broadcastable:
return None

if not _is_x_positive_condition(cast(TensorVariable, cond), x):
return None

a = _extract_leaky_relu_slope(cast(TensorVariable, neg_branch), x)
if a is None:
return None

if check_potential_measurability([a]):
return None
Comment thread
ricardoV94 marked this conversation as resolved.
Outdated

return [
measurable_leaky_relu_switch(
cast(TensorVariable, cond),
x,
cast(TensorVariable, neg_branch),
)
]


@_logprob.register(MeasurableSwitchMixture)
def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs):
[value] = values
Expand All @@ -442,6 +566,30 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa
)


@_logprob.register(MeasurableLeakyReLUSwitch)
def logprob_leaky_relu_switch(op, values, cond, x, neg_branch, **kwargs):
(value,) = values

a = _extract_leaky_relu_slope(cast(TensorVariable, neg_branch), cast(TensorVariable, x))
if a is None:
raise NotImplementedError("Could not extract leaky-ReLU slope")

# enforce `a > 0` at runtime to ensure invertibility and to make the branch selection predicate depend only on the observed value.
a_is_positive = pt.all(pt.gt(a, 0))

# for `a > 0`, `switch(x > 0, x, a * x)` maps to disjoint regions in `value`: true branch -> value > 0, false branch -> value <= 0.
value_implies_true_branch = pt.gt(value, 0)

logp_expr = pt.switch(
value_implies_true_branch,
_logprob_helper(x, value, **kwargs),
_logprob_helper(neg_branch, value, **kwargs),
)

# attach the parameter check to the returned expression so it can't be optimized away.
return CheckParameterValue("leaky_relu slope > 0")(logp_expr, a_is_positive)


measurable_ir_rewrites_db.register(
"find_measurable_index_mixture",
find_measurable_index_mixture,
Expand All @@ -456,6 +604,13 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa
"mixture",
)

measurable_ir_rewrites_db.register(
"find_measurable_leaky_relu_switch",
find_measurable_leaky_relu_switch,
"basic",
"transform",
)


class MeasurableIfElse(MeasurableOp, IfElse):
"""Measurable subclass of IfElse operator."""
Expand Down
54 changes: 54 additions & 0 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

from pytensor.graph.basic import equal_computations

import pymc as pm

from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.distributions.discrete import Bernoulli
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
Expand Down Expand Up @@ -219,6 +221,7 @@ def test_exp_transform_rv():
logp_fn(y_val),
sp.stats.lognorm(s=1).logpdf(y_val),
)

np.testing.assert_almost_equal(
logcdf_fn(y_val),
sp.stats.lognorm(s=1).logcdf(y_val),
Expand All @@ -229,6 +232,57 @@ def test_exp_transform_rv():
)


def test_leaky_relu_switch_logp_scalar():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests should be moved to a test_switch.py file

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add a separate test that shows the failure if x is broadcast by cond or a, or if it's discrete.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests should be moved to a test_switch.py file

oh right, that's my bad

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add a separate test that shows the failure if x is broadcast by cond or a, or if it's discrete.

sure

a = 0.5
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

v_pos = 1.2
np.testing.assert_allclose(
pm.logp(y, v_pos, warn_rvs=False).eval(),
pm.logp(x, v_pos, warn_rvs=False).eval(),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pm.logp(y, v_pos, warn_rvs=False).eval(),
pm.logp(x, v_pos, warn_rvs=False).eval(),
pm.logp(y, v_pos).eval(),
pm.logp(x, v_pos).eval(),

)

v_neg = -2.0
np.testing.assert_allclose(
pm.logp(y, v_neg, warn_rvs=False).eval(),
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're testing with two values, define the logp variable once (or compile a function with the logp as output once), and reuse it. That will avoid duplicated logp inference calls.

pm.logp(x, v_neg / a, warn_rvs=False).eval() - np.log(a),
)

# boundary point (measure-zero for continuous RVs): should still produce a finite logp
assert np.isfinite(pm.logp(y, 0.0, warn_rvs=False).eval())


def test_leaky_relu_switch_logp_vectorized():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use vectorized in the first test, and remove this. It's not conceputally that different

a = 0.5
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
y = pm.math.switch(x > 0, x, a * x)

v = np.array([-2.0, 0.0, 1.5])
expected = pm.logp(x, np.where(v > 0, v, v / a), warn_rvs=False).eval() + np.where(
v > 0, 0.0, -np.log(a)
)
np.testing.assert_allclose(pm.logp(y, v, warn_rvs=False).eval(), expected)


def test_leaky_relu_switch_logp_symbolic_slope_checks_positive():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use symbolic a in the first test and test the error there. Still a pretty straightforward test

a = pt.scalar("a")
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

# positive slope passes
res = pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.5})
expected = pm.logp(x, -1.0 / 0.5, warn_rvs=False).eval() - np.log(0.5)
np.testing.assert_allclose(res, expected)

# non pos slope raises
with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"):
pm.logp(y, -1.0, warn_rvs=False).eval({a: -0.5})

with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"):
pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.0})


def test_log_transform_rv():
base_rv = pt.random.lognormal(0, 1, size=2, name="base_rv")
y_rv = pt.log(base_rv)
Expand Down