Add logprob support for leaky-ReLU switch transforms #7995
Changes from 2 commits
```
@@ -43,6 +43,8 @@
from pytensor.graph.basic import equal_computations

import pymc as pm

from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.distributions.discrete import Bernoulli
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
```
```
@@ -219,6 +221,7 @@ def test_exp_transform_rv():
        logp_fn(y_val),
        sp.stats.lognorm(s=1).logpdf(y_val),
    )

    np.testing.assert_almost_equal(
        logcdf_fn(y_val),
        sp.stats.lognorm(s=1).logcdf(y_val),
```

```
@@ -229,6 +232,57 @@ def test_exp_transform_rv():
    )


def test_leaky_relu_switch_logp_scalar():
```
Member: tests should be moved to a

Member: Also add a separate test that shows the failure if x is broadcast by cond or a, or if it's discrete.

Contributor (Author): oh right, that's my bad

Contributor (Author): sure
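The broadcast failure mode flagged above can be sketched outside of PyMC. In this hypothetical NumPy example (names and slope values are illustrative, not taken from the PR), a scalar base draw is broadcast to a length-3 output by a vector slope, so every output entry is a deterministic function of the same single draw:

```python
import numpy as np

rng = np.random.default_rng(42)

# Scalar base variable, broadcast to length 3 by a vector slope `a`
# (illustrative values, not from the PR).
x = rng.normal(size=(1000, 1))   # 1000 scalar draws
a = np.array([1.0, 0.5, 2.0])    # vector slope broadcasts x
y = np.where(x > 0, x, a * x)    # shape (1000, 3)

# Each column of y is a strictly increasing function of the same scalar
# draw, so the columns are fully dependent: they sort in the same order.
order = np.argsort(y[:, 0])
assert np.array_equal(order, np.argsort(y[:, 1]))
assert np.array_equal(order, np.argsort(y[:, 2]))
```

Summing an elementwise log-density over those three columns would treat one draw as three independent ones, which is why the logp rewrite has to reject this broadcasting pattern rather than silently produce a wrong answer.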
```
    a = 0.5
    x = pm.Normal.dist(mu=0, sigma=1)
    y = pm.math.switch(x > 0, x, a * x)

    v_pos = 1.2
    np.testing.assert_allclose(
        pm.logp(y, v_pos, warn_rvs=False).eval(),
        pm.logp(x, v_pos, warn_rvs=False).eval(),
```
```
    )

    v_neg = -2.0
    np.testing.assert_allclose(
        pm.logp(y, v_neg, warn_rvs=False).eval(),
```
Member: If you're testing with two values, define the logp variable once (or compile a function with the logp as output once), and reuse it. That will avoid duplicated logp inference calls.
```
        pm.logp(x, v_neg / a, warn_rvs=False).eval() - np.log(a),
    )

    # boundary point (measure-zero for continuous RVs): should still produce a finite logp
    assert np.isfinite(pm.logp(y, 0.0, warn_rvs=False).eval())


def test_leaky_relu_switch_logp_vectorized():
```
Member: Use vectorized in the first test, and remove this. It's not conceptually that different.
```
    a = 0.5
    x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
    y = pm.math.switch(x > 0, x, a * x)

    v = np.array([-2.0, 0.0, 1.5])
    expected = pm.logp(x, np.where(v > 0, v, v / a), warn_rvs=False).eval() + np.where(
        v > 0, 0.0, -np.log(a)
    )
    np.testing.assert_allclose(pm.logp(y, v, warn_rvs=False).eval(), expected)


def test_leaky_relu_switch_logp_symbolic_slope_checks_positive():
```
Member: use symbolic a in the first test and test the error there. Still a pretty straightforward test.
```
    a = pt.scalar("a")
    x = pm.Normal.dist(mu=0, sigma=1)
    y = pm.math.switch(x > 0, x, a * x)

    # positive slope passes
    res = pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.5})
    expected = pm.logp(x, -1.0 / 0.5, warn_rvs=False).eval() - np.log(0.5)
    np.testing.assert_allclose(res, expected)

    # non-positive slope raises
    with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"):
        pm.logp(y, -1.0, warn_rvs=False).eval({a: -0.5})

    with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"):
        pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.0})


def test_log_transform_rv():
    base_rv = pt.random.lognormal(0, 1, size=2, name="base_rv")
    y_rv = pt.log(base_rv)
```
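For reference, the change-of-variables rule these tests encode can be checked with SciPy alone. In this minimal sketch (the helper name `leaky_relu_logp` is hypothetical, assuming x ~ Normal(0, 1) and slope a = 0.5 as in the tests), v > 0 maps through the identity while v <= 0 inverts to x = v / a with log-Jacobian -log(a), so the transformed density still integrates to 1:

```python
import numpy as np
from scipy import stats

def leaky_relu_logp(v, a=0.5):
    """Log-density of y = switch(x > 0, x, a * x) for x ~ Normal(0, 1).

    For v > 0 the transform is the identity; for v <= 0 it inverts to
    x = v / a and picks up the Jacobian term -log(a).
    """
    base = stats.norm(0, 1)
    return np.where(v > 0, base.logpdf(v), base.logpdf(v / a) - np.log(a))

# A proper density: the transformed logp must integrate (numerically) to 1,
# and the measure-zero boundary point v = 0 must still yield a finite value.
grid, step = np.linspace(-8.0, 8.0, 160001, retstep=True)
total = np.exp(leaky_relu_logp(grid)).sum() * step
assert np.isclose(total, 1.0, atol=1e-3)
assert np.isfinite(leaky_relu_logp(0.0))
```

This mirrors what the tests above assert through `pm.logp`: identity logp on the positive branch, and `logp(v / a) - log(a)` on the non-positive branch.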