Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ jobs:
run: python -m pip install -e ".[test,plot]"

- name: Test plotting too
run: python -m pytest
run: python -m pytest --mpl
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include LICENSE pyproject.toml README.md setup.py setup.cfg
recursive-include tests *.py
recursive-include tests *.py *.png
recursive-include src *.py
recursive-include src py.typed
1 change: 1 addition & 0 deletions dev-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- jupyterlab >=1.2
- matplotlib >=3.1
- pytest >=5
- pytest-mpl >=0.12
- setuptools >=42
- setuptools_scm >=3.4
- scipy >=1.4.1
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
extras_require["plot"] = [
"matplotlib >=3.0",
"scipy >=1.4",
"uncertainties >=3",
"iminuit >=2",
"mplhep >=0.2.16",
]

extras_require["test"] = [
*extras_require["plot"],
"pytest >=4.6",
"pytest-mpl >=0.12",
Comment thread
henryiii marked this conversation as resolved.
]

extras_require["dev"] = [*extras_require["test"], *extras_require["plot"], "ipykernel"]
Expand Down
162 changes: 140 additions & 22 deletions src/hist/plot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import inspect
import sys
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union

import numpy as np

import hist

from .typing import ArrayLike

try:
import matplotlib.axes
import matplotlib.patches as patches
Expand Down Expand Up @@ -43,6 +46,89 @@ def _filter_dict(
}


def _expr_to_lambda(expr: str) -> Callable[..., Any]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the default for Callable - best to not leave empty Generics.

"""
Converts a string expression like
"a+b*np.exp(-c*x+math.pi)"
into a callable function with 1 variable and N parameters,
lambda x,a,b,c: "a+b*np.exp(-c*x+math.pi)"
`x` is assumed to be the main variable, and preventing symbols
like `foo.bar` or `foo(` from being considered as parameter.
"""
from collections import OrderedDict
from io import BytesIO
from tokenize import NAME, tokenize

varnames = []
g = list(tokenize(BytesIO(expr.encode("utf-8")).readline))
for ix, x in enumerate(g):
toknum = x[0]
tokval = x[1]
if toknum != NAME:
continue
if ix > 0 and g[ix - 1][1] in ["."]:
continue
if ix < len(g) - 1 and g[ix + 1][1] in [".", "("]:
continue
varnames.append(tokval)
varnames = list(OrderedDict.fromkeys([name for name in varnames if name != "x"]))
lambdastr = f"lambda x,{','.join(varnames)}: {expr}"
return eval(lambdastr) # type: ignore


def _curve_fit_wrapper(
func: Callable[..., Any],
xdata: np.ndarray,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrayLike can be a simple number, so xdata[stuff] was not valid. Either use np.asarray, or just require arrays.

ydata: np.ndarray,
yerr: np.ndarray,
likelihood: bool = False,
) -> Tuple[Tuple[float, ...], ArrayLike]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes the typing easier, because NumPy's typing is a bit spotty. It doesn't know that *array is valid, etc. This is small so was a simple fix.

"""
Wrapper around `scipy.optimize.curve_fit`. Initial parameters (`p0`)
can be set in the function definition with defaults for kwargs
(e.g., `func = lambda x,a=1.,b=2.: x+a+b`, will feed `p0 = [1.,2.]` to `curve_fit`)
"""
from scipy.optimize import curve_fit, minimize

params = list(inspect.signature(func).parameters.values())
p0 = [
1 if arg.default is inspect.Parameter.empty else arg.default
for arg in params[1:]
]

mask = yerr != 0.0
popt, pcov = curve_fit(
func,
xdata[mask],
ydata[mask],
sigma=yerr[mask],
absolute_sigma=True,
p0=p0,
)
if likelihood:
from iminuit import Minuit
from scipy.special import gammaln

def fnll(v: Iterable[np.ndarray]) -> float:
ypred = func(xdata, *v)
if (ypred <= 0.0).any():
return 1e6
return ( # type: ignore
ypred.sum() - (ydata * np.log(ypred)).sum() + gammaln(ydata + 1).sum()
)

# Seed likelihood fit with chi2 fit parameters
res = minimize(fnll, popt, method="BFGS")
popt = res.x

# Better hessian from hesse, seeded with scipy popt
m = Minuit(fnll, popt)
m.errordef = 0.5
m.hesse()
pcov = np.array(m.covariance)
return tuple(popt), pcov


def plot2d_full(
self: hist.BaseHist,
*,
Expand Down Expand Up @@ -128,7 +214,8 @@ def plot2d_full(

def plot_pull(
self: hist.BaseHist,
func: Callable[[np.ndarray], np.ndarray],
func: Union[Callable[[np.ndarray], np.ndarray], str],
likelihood: bool = False,
*,
ax_dict: "Optional[Dict[str, matplotlib.axes.Axes]]" = None,
**kwargs: Any,
Expand All @@ -138,18 +225,18 @@ def plot_pull(
"""

try:
from scipy.optimize import curve_fit
from uncertainties import correlated_values, unumpy
from iminuit import Minuit # noqa: F401
from scipy.optimize import curve_fit # noqa: F401
except ImportError:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a ModuleNotFoundError.

print(
"Hist.plot_pull requires scipy and uncertainties. Please install hist[plot] or manually install dependencies.",
"Hist.plot_pull requires scipy and iminuit. Please install hist[plot] or manually install dependencies.",
file=sys.stderr,
)
raise

# Type judgement
if not callable(func):
msg = f"Callable parameter func is supported for {self.__class__.__name__} in plot pull"
if not callable(func) and not type(func) in [str]:
msg = f"Parameter func must be callable or a string for {self.__class__.__name__} in plot pull"
raise TypeError(msg)

if self.ndim != 1:
Expand All @@ -169,27 +256,53 @@ def plot_pull(
pull_ax = fig.add_subplot(grid[1], sharex=main_ax)

# Computation and Fit
values = self.values()
xdata = self.axes[0].centers
ydata = self.values()
variances = self.variances()
if variances is None:
raise RuntimeError(
"Cannot compute from a variance-less histogram, try a Weight storage"
)
yerr = np.sqrt(variances)

# Compute fit values: using func as fit model
popt, pcov = curve_fit(f=func, xdata=self.axes[0].centers, ydata=values)
fit = func(self.axes[0].centers, *popt)
if isinstance(func, str):
if func == "gaus":
# gaussian with reasonable initial guesses for parameters
constant = float(ydata.max())
mean = (ydata * xdata).sum() / ydata.sum()
sigma = (ydata * (xdata - mean) ** 2.0).sum() / ydata.sum()

def func(
x: np.ndarray,
constant: float = constant,
mean: float = mean,
sigma: float = sigma,
) -> np.ndarray:
return constant * np.exp(-((x - mean) ** 2.0) / (2 * sigma ** 2)) # type: ignore

else:
func = _expr_to_lambda(func)

assert not isinstance(func, str)

# Compute uncertainty
copt = correlated_values(popt, pcov)
y_unc = func(self.axes[0].centers, *copt)
y_nv = unumpy.nominal_values(y_unc)
y_sd = unumpy.std_devs(y_unc)
parnames = list(inspect.signature(func).parameters)[1:]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same higher level usage of inspect.signature over .__code__.co_varname.


# Compute fit values: using func as fit model
popt, pcov = _curve_fit_wrapper(func, xdata, ydata, yerr, likelihood=likelihood)
perr = np.diag(pcov) ** 0.5
yfit = func(self.axes[0].centers, *popt)

if np.isfinite(pcov).all():
nsamples = 100
vopts = np.random.multivariate_normal(popt, pcov, nsamples)
sampled_ydata = np.vstack([func(xdata, *vopt).T for vopt in vopts])
yfiterr = np.nanstd(sampled_ydata, axis=0)
else:
yfiterr = np.zeros_like(yerr)

# Compute pulls: containing no INF values
with np.errstate(divide="ignore"):
pulls = (values - y_nv) / yerr
pulls = (ydata - yfit) / yerr

pulls[np.isnan(pulls)] = 0
pulls[np.isinf(pulls)] = 0
Expand All @@ -201,12 +314,17 @@ def plot_pull(
eb_kwargs.setdefault("label", "Histogram Data")

# fit plot keyword arguments
label = "Fit"
for name, value, error in zip(parnames, popt, perr):
label += "\n "
label += rf"{name} = {value:.3g} $\pm$ {error:.3g}"
fp_kwargs = _filter_dict(kwargs, "fp_")
fp_kwargs.setdefault("label", "Fitting Value")
fp_kwargs.setdefault("label", label)

# uncertainty band keyword arguments
ub_kwargs = _filter_dict(kwargs, "ub_")
ub_kwargs.setdefault("label", "Uncertainty")
ub_kwargs.setdefault("alpha", 0.5)

# bar plot keyword arguments
bar_kwargs = _filter_dict(kwargs, "bar_", ignore={"bar_width"})
Expand All @@ -220,16 +338,16 @@ def plot_pull(
raise ValueError(f"{set(kwargs)}' not needed")

# Main: plot the pulls using Matplotlib errorbar and plot methods
main_ax.errorbar(self.axes.centers[0], values, yerr, **eb_kwargs)
main_ax.errorbar(self.axes.centers[0], ydata, yerr, **eb_kwargs)

(line,) = main_ax.plot(self.axes.centers[0], fit, **fp_kwargs)
(line,) = main_ax.plot(self.axes.centers[0], yfit, **fp_kwargs)

# Uncertainty band
ub_kwargs.setdefault("color", line.get_color())
main_ax.fill_between(
self.axes.centers[0],
y_nv - y_sd,
y_nv + y_sd,
yfit - yfiterr,
yfit + yfiterr,
**ub_kwargs,
)
main_ax.legend(loc=0)
Expand Down
Binary file added tests/baseline/test_image_plot_pull.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 41 additions & 7 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from hist import Hist, NamedHist, axis

unp = pytest.importorskip("uncertainties.unumpy")
plt = pytest.importorskip("matplotlib.pyplot")


Expand Down Expand Up @@ -200,15 +199,16 @@ def test_general_plot_pull():
Test general plot_pull -- whether 1d-Hist can be plotted pull properly.
"""

np.random.seed(42)

h = Hist(
axis.Regular(
50, -4, 4, name="S", label="s [units]", underflow=False, overflow=False
)
).fill(np.random.normal(size=10))

def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
exp = unp.exp if a.dtype == np.dtype("O") else np.exp
return a * exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + offset
return a * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + offset

assert h.plot_pull(
pdf,
Expand All @@ -231,6 +231,10 @@ def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
pp_ec=None,
)

pdf_str = "a * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + offset"

assert h.plot_pull(pdf_str)

# dimension error
hh = Hist(
axis.Regular(
Expand All @@ -244,7 +248,7 @@ def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
with pytest.raises(Exception):
hh.plot_pull(pdf)

# not callable
# no eval-able variable
with pytest.raises(Exception):
h.plot_pull("1")

Expand Down Expand Up @@ -490,15 +494,16 @@ def test_named_plot_pull():
Test named plot_pull -- whether 1d-NamedHist can be plotted pull properly.
"""

np.random.seed(42)

h = NamedHist(
axis.Regular(
50, -4, 4, name="S", label="s [units]", underflow=False, overflow=False
)
).fill(S=np.random.normal(size=10))

def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
exp = unp.exp if a.dtype == np.dtype("O") else np.exp
return a * exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + offset
return a * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + offset

assert h.plot_pull(
pdf,
Expand All @@ -521,6 +526,10 @@ def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
pp_ec=None,
)

pdf_str = "a * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + offset"

assert h.plot_pull(pdf_str)

# dimension error
hh = NamedHist(
axis.Regular(
Expand All @@ -534,7 +543,7 @@ def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
with pytest.raises(Exception):
hh.plot_pull(pdf)

# not callable
# no eval-able variable
with pytest.raises(Exception):
h.plot_pull("1")

Expand Down Expand Up @@ -581,3 +590,28 @@ def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
h.plot_pull(pdf, eb_ecolor=1.0, eb_mfc=1.0) # kwargs should be str

plt.close("all")


@pytest.mark.mpl_image_compare(baseline_dir="baseline", savefig_kwargs={"dpi": 50})
def test_image_plot_pull():
"""
Test plot_pull by comparing against a reference image generated via
`pytest --mpl-generate-path=baseline`
"""

np.random.seed(42)

h = Hist(
axis.Regular(
50, -4, 4, name="S", label="s [units]", underflow=False, overflow=False
)
).fill(np.random.normal(size=100))

def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0):
return a * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + offset

fig, ax = plt.subplots()

assert h.plot_pull(pdf)

return fig