-
Notifications
You must be signed in to change notification settings - Fork 31
More flexible fitting function, allow likelihood, remove uncertainties dependency #149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
1ded859
93a59ab
155ae98
80cb32c
95fcff6
3d6ca17
0de0869
087a9ca
322bd30
150ed5d
49e02f2
791194c
88afa48
b4f4520
55176fc
78b5070
3b373a7
a149b98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| include LICENSE pyproject.toml README.md setup.py setup.cfg | ||
| recursive-include tests *.py | ||
| recursive-include tests *.py *.png | ||
| recursive-include src *.py | ||
| recursive-include src py.typed |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,13 @@ | ||
| import inspect | ||
| import sys | ||
| from typing import Any, Callable, Dict, Optional, Set, Tuple, Union | ||
| from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union | ||
|
|
||
| import numpy as np | ||
|
|
||
| import hist | ||
|
|
||
| from .typing import ArrayLike | ||
|
|
||
| try: | ||
| import matplotlib.axes | ||
| import matplotlib.patches as patches | ||
|
|
@@ -43,6 +46,89 @@ def _filter_dict( | |
| } | ||
|
|
||
|
|
||
| def _expr_to_lambda(expr: str) -> Callable[..., Any]: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the default for Callable - best to not leave empty Generics. |
||
| """ | ||
| Converts a string expression like | ||
| "a+b*np.exp(-c*x+math.pi)" | ||
| into a callable function with 1 variable and N parameters, | ||
| lambda x,a,b,c: "a+b*np.exp(-c*x+math.pi)" | ||
| `x` is assumed to be the main variable, and preventing symbols | ||
| like `foo.bar` or `foo(` from being considered as parameter. | ||
| """ | ||
| from collections import OrderedDict | ||
| from io import BytesIO | ||
| from tokenize import NAME, tokenize | ||
|
|
||
| varnames = [] | ||
| g = list(tokenize(BytesIO(expr.encode("utf-8")).readline)) | ||
| for ix, x in enumerate(g): | ||
| toknum = x[0] | ||
| tokval = x[1] | ||
| if toknum != NAME: | ||
| continue | ||
| if ix > 0 and g[ix - 1][1] in ["."]: | ||
| continue | ||
| if ix < len(g) - 1 and g[ix + 1][1] in [".", "("]: | ||
| continue | ||
| varnames.append(tokval) | ||
| varnames = list(OrderedDict.fromkeys([name for name in varnames if name != "x"])) | ||
| lambdastr = f"lambda x,{','.join(varnames)}: {expr}" | ||
| return eval(lambdastr) # type: ignore | ||
|
|
||
|
|
||
| def _curve_fit_wrapper( | ||
| func: Callable[..., Any], | ||
| xdata: np.ndarray, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ArrayLike can be a simple number, so |
||
| ydata: np.ndarray, | ||
| yerr: np.ndarray, | ||
| likelihood: bool = False, | ||
| ) -> Tuple[Tuple[float, ...], ArrayLike]: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes the typing easier, because NumPy's typing is a bit spotty. It doesn't know that |
||
| """ | ||
| Wrapper around `scipy.optimize.curve_fit`. Initial parameters (`p0`) | ||
| can be set in the function definition with defaults for kwargs | ||
| (e.g., `func = lambda x,a=1.,b=2.: x+a+b`, will feed `p0 = [1.,2.]` to `curve_fit`) | ||
| """ | ||
| from scipy.optimize import curve_fit, minimize | ||
|
|
||
| params = list(inspect.signature(func).parameters.values()) | ||
| p0 = [ | ||
| 1 if arg.default is inspect.Parameter.empty else arg.default | ||
| for arg in params[1:] | ||
| ] | ||
|
|
||
| mask = yerr != 0.0 | ||
| popt, pcov = curve_fit( | ||
| func, | ||
| xdata[mask], | ||
| ydata[mask], | ||
| sigma=yerr[mask], | ||
| absolute_sigma=True, | ||
| p0=p0, | ||
| ) | ||
| if likelihood: | ||
| from iminuit import Minuit | ||
| from scipy.special import gammaln | ||
|
|
||
| def fnll(v: Iterable[np.ndarray]) -> float: | ||
| ypred = func(xdata, *v) | ||
| if (ypred <= 0.0).any(): | ||
| return 1e6 | ||
| return ( # type: ignore | ||
| ypred.sum() - (ydata * np.log(ypred)).sum() + gammaln(ydata + 1).sum() | ||
| ) | ||
|
|
||
| # Seed likelihood fit with chi2 fit parameters | ||
| res = minimize(fnll, popt, method="BFGS") | ||
| popt = res.x | ||
|
|
||
| # Better hessian from hesse, seeded with scipy popt | ||
| m = Minuit(fnll, popt) | ||
| m.errordef = 0.5 | ||
| m.hesse() | ||
| pcov = np.array(m.covariance) | ||
| return tuple(popt), pcov | ||
|
|
||
|
|
||
| def plot2d_full( | ||
| self: hist.BaseHist, | ||
| *, | ||
|
|
@@ -128,7 +214,8 @@ def plot2d_full( | |
|
|
||
| def plot_pull( | ||
| self: hist.BaseHist, | ||
| func: Callable[[np.ndarray], np.ndarray], | ||
| func: Union[Callable[[np.ndarray], np.ndarray], str], | ||
| likelihood: bool = False, | ||
| *, | ||
| ax_dict: "Optional[Dict[str, matplotlib.axes.Axes]]" = None, | ||
| **kwargs: Any, | ||
|
|
@@ -138,18 +225,18 @@ def plot_pull( | |
| """ | ||
|
|
||
| try: | ||
| from scipy.optimize import curve_fit | ||
| from uncertainties import correlated_values, unumpy | ||
| from iminuit import Minuit # noqa: F401 | ||
| from scipy.optimize import curve_fit # noqa: F401 | ||
| except ImportError: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be a ModuleNotFoundError. |
||
| print( | ||
| "Hist.plot_pull requires scipy and uncertainties. Please install hist[plot] or manually install dependencies.", | ||
| "Hist.plot_pull requires scipy and iminuit. Please install hist[plot] or manually install dependencies.", | ||
| file=sys.stderr, | ||
| ) | ||
| raise | ||
|
|
||
| # Type judgement | ||
| if not callable(func): | ||
| msg = f"Callable parameter func is supported for {self.__class__.__name__} in plot pull" | ||
| if not callable(func) and not type(func) in [str]: | ||
| msg = f"Parameter func must be callable or a string for {self.__class__.__name__} in plot pull" | ||
| raise TypeError(msg) | ||
|
|
||
| if self.ndim != 1: | ||
|
|
@@ -169,27 +256,53 @@ def plot_pull( | |
| pull_ax = fig.add_subplot(grid[1], sharex=main_ax) | ||
|
|
||
| # Computation and Fit | ||
| values = self.values() | ||
| xdata = self.axes[0].centers | ||
| ydata = self.values() | ||
| variances = self.variances() | ||
| if variances is None: | ||
| raise RuntimeError( | ||
| "Cannot compute from a variance-less histogram, try a Weight storage" | ||
| ) | ||
| yerr = np.sqrt(variances) | ||
|
|
||
| # Compute fit values: using func as fit model | ||
| popt, pcov = curve_fit(f=func, xdata=self.axes[0].centers, ydata=values) | ||
| fit = func(self.axes[0].centers, *popt) | ||
| if isinstance(func, str): | ||
| if func == "gaus": | ||
| # gaussian with reasonable initial guesses for parameters | ||
| constant = float(ydata.max()) | ||
| mean = (ydata * xdata).sum() / ydata.sum() | ||
| sigma = (ydata * (xdata - mean) ** 2.0).sum() / ydata.sum() | ||
|
|
||
| def func( | ||
| x: np.ndarray, | ||
| constant: float = constant, | ||
| mean: float = mean, | ||
| sigma: float = sigma, | ||
| ) -> np.ndarray: | ||
| return constant * np.exp(-((x - mean) ** 2.0) / (2 * sigma ** 2)) # type: ignore | ||
|
|
||
| else: | ||
| func = _expr_to_lambda(func) | ||
|
|
||
| assert not isinstance(func, str) | ||
|
|
||
| # Compute uncertainty | ||
| copt = correlated_values(popt, pcov) | ||
| y_unc = func(self.axes[0].centers, *copt) | ||
| y_nv = unumpy.nominal_values(y_unc) | ||
| y_sd = unumpy.std_devs(y_unc) | ||
| parnames = list(inspect.signature(func).parameters)[1:] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same higher level usage of |
||
|
|
||
| # Compute fit values: using func as fit model | ||
| popt, pcov = _curve_fit_wrapper(func, xdata, ydata, yerr, likelihood=likelihood) | ||
| perr = np.diag(pcov) ** 0.5 | ||
| yfit = func(self.axes[0].centers, *popt) | ||
|
|
||
| if np.isfinite(pcov).all(): | ||
| nsamples = 100 | ||
| vopts = np.random.multivariate_normal(popt, pcov, nsamples) | ||
| sampled_ydata = np.vstack([func(xdata, *vopt).T for vopt in vopts]) | ||
| yfiterr = np.nanstd(sampled_ydata, axis=0) | ||
| else: | ||
| yfiterr = np.zeros_like(yerr) | ||
|
|
||
| # Compute pulls: containing no INF values | ||
| with np.errstate(divide="ignore"): | ||
| pulls = (values - y_nv) / yerr | ||
| pulls = (ydata - yfit) / yerr | ||
|
|
||
| pulls[np.isnan(pulls)] = 0 | ||
| pulls[np.isinf(pulls)] = 0 | ||
|
|
@@ -201,12 +314,17 @@ def plot_pull( | |
| eb_kwargs.setdefault("label", "Histogram Data") | ||
|
|
||
| # fit plot keyword arguments | ||
| label = "Fit" | ||
| for name, value, error in zip(parnames, popt, perr): | ||
| label += "\n " | ||
| label += rf"{name} = {value:.3g} $\pm$ {error:.3g}" | ||
| fp_kwargs = _filter_dict(kwargs, "fp_") | ||
| fp_kwargs.setdefault("label", "Fitting Value") | ||
| fp_kwargs.setdefault("label", label) | ||
|
|
||
| # uncertainty band keyword arguments | ||
| ub_kwargs = _filter_dict(kwargs, "ub_") | ||
| ub_kwargs.setdefault("label", "Uncertainty") | ||
| ub_kwargs.setdefault("alpha", 0.5) | ||
|
|
||
| # bar plot keyword arguments | ||
| bar_kwargs = _filter_dict(kwargs, "bar_", ignore={"bar_width"}) | ||
|
|
@@ -220,16 +338,16 @@ def plot_pull( | |
| raise ValueError(f"{set(kwargs)}' not needed") | ||
|
|
||
| # Main: plot the pulls using Matplotlib errorbar and plot methods | ||
| main_ax.errorbar(self.axes.centers[0], values, yerr, **eb_kwargs) | ||
| main_ax.errorbar(self.axes.centers[0], ydata, yerr, **eb_kwargs) | ||
|
|
||
| (line,) = main_ax.plot(self.axes.centers[0], fit, **fp_kwargs) | ||
| (line,) = main_ax.plot(self.axes.centers[0], yfit, **fp_kwargs) | ||
|
|
||
| # Uncertainty band | ||
| ub_kwargs.setdefault("color", line.get_color()) | ||
| main_ax.fill_between( | ||
| self.axes.centers[0], | ||
| y_nv - y_sd, | ||
| y_nv + y_sd, | ||
| yfit - yfiterr, | ||
| yfit + yfiterr, | ||
| **ub_kwargs, | ||
| ) | ||
| main_ax.legend(loc=0) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.