Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions src/boost_histogram/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,19 @@ def __dir__() -> list[str]:
WeightedMeanHists = TypeVar("WeightedMeanHists", bound="Histogram[bhs.WeightedMean]")


@typing.overload
def _fill_cast(
value: T, *, inner: bool = False
) -> T | np.typing.NDArray[Any] | tuple[T, ...]:
value: tuple[T, ...] | list[T], *, inner: Literal[False] = False
) -> tuple[T | np.typing.NDArray[Any], ...]: ...


@typing.overload
def _fill_cast(value: T, *, inner: bool = False) -> T | np.typing.NDArray[Any]: ...


def _fill_cast(
value: Any, *, inner: bool = False
) -> Any | np.typing.NDArray[Any] | tuple[Any | np.typing.NDArray[Any], ...]:
"""
Convert to NumPy arrays. Some buffer objects do not get converted by forcecast.
If not called by itself (inner=False), then will work through one level of tuple/list.
Expand All @@ -136,7 +146,7 @@ def _fill_cast(
return value

if not inner and isinstance(value, (tuple, list)):
return tuple(_fill_cast(a, inner=True) for a in value) # type: ignore[misc]
return tuple(_fill_cast(a, inner=True) for a in value)

if hasattr(value, "__iter__") or hasattr(value, "__array__"):
return np.asarray(value)
Expand Down Expand Up @@ -713,11 +723,23 @@ def fill(
weight_ars = _fill_cast(weight)
sample_ars = _fill_cast(sample)

# Broadcast scalar positional args to match sample length when sample is an array.
# This allows e.g. h.fill(0, sample=[1, 2, 3]) to work for Mean/WeightedMean storage.
if sample_ars is not None:
sample_arr = np.asarray(sample_ars)
if sample_arr.ndim > 0:
sample_len = len(sample_arr)
if sample_len > 1:
args_ars = tuple(
np.full(sample_len, a) if np.ndim(a) == 0 else a
for a in args_ars
)

if threads == 0:
threads = cpu_count()

if threads is None or threads == 1:
self._hist.fill(*args_ars, weight=weight_ars, sample=sample_ars) # type: ignore[arg-type]
self._hist.fill(*args_ars, weight=weight_ars, sample=sample_ars)
return self

if self._hist._storage_type in {
Expand All @@ -726,24 +748,26 @@ def fill(
}:
raise RuntimeError("Mean histograms do not support threaded filling")

data: list[list[np.typing.NDArray[Any]] | list[str]] = [
np.array_split(a, threads) if not isinstance(a, str) else [a] * threads # type: ignore[arg-type, list-item]
for a in args_ars
]
data: list[list[np.typing.NDArray[Any]] | list[str]] = []
for a in args_ars:
if isinstance(a, str):
data.append([a] * threads)
else:
data.append(np.array_split(np.asarray(a), threads))

weights: list[Any]
if weight is None or np.isscalar(weight):
assert threads is not None
weights = [weight_ars] * threads
else:
weights = np.array_split(weight_ars, threads) # type: ignore[arg-type]
weights = np.array_split(np.asarray(weight_ars), threads)

samples: list[Any]
if sample_ars is None or np.isscalar(sample_ars):
assert threads is not None
samples = [sample_ars] * threads
else:
samples = np.array_split(sample_ars, threads) # type: ignore[arg-type]
samples = np.array_split(np.asarray(sample_ars), threads)

if self._hist._storage_type is _core.storage.atomic_int64:

Expand Down
46 changes: 46 additions & 0 deletions tests/test_profiles.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numpy as np
from pytest import approx

import boost_histogram as bh
Expand All @@ -25,3 +26,48 @@ def test_mean_hist():
assert results[i]["count"] == h[i].count
assert results[i]["value"] == approx(h[i].value)
assert results[i]["variance"] == approx(h[i].variance)


def test_mean_hist_scalar_axis_broadcast():
"""Test that a scalar axis value is broadcast to match a sample array (issue #727)."""
data = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50], [100, 200, 300, 400, 500]])
h = bh.Histogram(bh.axis.Integer(0, 3), storage=bh.storage.Mean())

h.fill(0, sample=data[0])
h.fill(1, sample=data[1])
h.fill(2, sample=data[2])

assert h[0].count == 5
assert h[0].value == approx(3.0)
assert h[1].count == 5
assert h[1].value == approx(30.0)
assert h[2].count == 5
assert h[2].value == approx(300.0)

# Verify equivalence with the explicit workaround
h2 = bh.Histogram(bh.axis.Integer(0, 3), storage=bh.storage.Mean())
h2.fill([0] * len(data[0]), sample=data[0])
h2.fill([1] * len(data[1]), sample=data[1])
h2.fill([2] * len(data[2]), sample=data[2])

for i in range(3):
assert h[i].count == h2[i].count
assert h[i].value == approx(h2[i].value)
assert h[i].variance == approx(h2[i].variance)


def test_weighted_mean_hist_scalar_axis_broadcast():
"""Test scalar axis broadcast for WeightedMean storage (issue #727)."""
data = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50], [100, 200, 300, 400, 500]])
h = bh.Histogram(bh.axis.Integer(0, 3), storage=bh.storage.WeightedMean())

h.fill(0, sample=data[0])
h.fill(1, sample=data[1])
h.fill(2, sample=data[2])

assert h[0].sum_of_weights == approx(5.0)
assert h[0].value == approx(3.0)
assert h[1].sum_of_weights == approx(5.0)
assert h[1].value == approx(30.0)
assert h[2].sum_of_weights == approx(5.0)
assert h[2].value == approx(300.0)
Loading