Skip to content

Commit 5aa7655

Browse files
authored
Refactor fit_gmm to support weighted particle resampling (#181)
1 parent fc8fb30 commit 5aa7655

2 files changed

Lines changed: 101 additions & 0 deletions

File tree

src/flekspy/amrex/particle_data.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -557,6 +557,17 @@ def fit_gmm(
557557
if rdata.size == 0:
558558
raise ValueError("No particles to fit GMM.")
559559

560+
# Check for weights and resample if necessary
561+
resample_indices = None
562+
if "weight" in self.header.real_component_names:
563+
weight_idx = self.header.real_component_names.index("weight")
564+
weights = rdata[:, weight_idx]
565+
# Normalize weights to sum to 1
566+
total_weight = np.sum(weights)
567+
if total_weight > 0:
568+
p = weights / total_weight
569+
resample_indices = np.random.choice(len(rdata), size=len(rdata), p=p)
570+
560571
# --- 2. Apply transformation if provided ---
561572
component_names = self.header.real_component_names
562573
if transform:
@@ -565,6 +576,9 @@ def fit_gmm(
565576
# --- 3 & 4. Extract data columns ---
566577
data = self._extract_variable_columns(rdata, variables, component_names)
567578

579+
if resample_indices is not None:
580+
data = data[resample_indices]
581+
568582
from sklearn.mixture import GaussianMixture
569583

570584
# --- 5. Fit GMM ---

tests/test_weighted_gmm.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
import pytest
3+
from unittest.mock import MagicMock
4+
from flekspy.amrex.particle_data import AMReXParticle
5+
6+
class MockAMReXParticle(AMReXParticle):
    """In-memory stand-in for ``AMReXParticle`` used by the weighted-GMM test.

    The parent constructor is intentionally not invoked; the instance is
    populated directly from the array and header passed in, so no files are
    read from this class's own code.
    """

    def __init__(self, rdata, header):
        """Store the particle array and header without calling the parent init."""
        self._rdata = rdata
        self.header = header
        # Placeholder path; nothing in this mock reads from it.
        self.output_dir = "mock_dir"
        # Empty integer data so the attribute exists
        # (presumably keeps the parent from triggering a load -- confirm).
        self._idata = np.empty((0, 0))

    @property
    def rdata(self):
        """Return the real-valued particle array given at construction."""
        return self._rdata

    def _extract_variable_columns(self, rdata, variables, component_names=None):
        """Return the columns of ``rdata`` named in ``variables``, in that order.

        ``component_names`` falls back to ``self.header.real_component_names``
        when it is ``None``. A name missing from the list raises ``ValueError``
        (from ``list.index``).
        """
        names = (
            self.header.real_component_names
            if component_names is None
            else component_names
        )
        column_ids = list(map(names.index, variables))
        return rdata[:, column_ids]

    def select_particles_in_region(self, x_range=None, y_range=None, z_range=None):
        """Ignore the requested ranges and return every stored particle.

        Region selection is not under test here, so the full array is
        handed back regardless of the arguments.
        """
        return self._rdata
32+
33+
34+
@pytest.fixture
def mock_weighted_data():
    """Build a ``MockAMReXParticle`` with two equally sized, unequally weighted groups.

    * Group 1: ``x`` drawn around 0, every particle has weight 1.
    * Group 2: ``x`` drawn around 10, every particle has weight 100.

    The returned object's real data has columns ``[x, y, weight]``; ``y`` is
    all zeros.
    """
    group_size = 1000
    generator = np.random.default_rng(42)

    # Draw order matters for reproducibility: light group first, then heavy.
    light_x = generator.normal(0, 0.1, group_size)
    heavy_x = generator.normal(10, 0.1, group_size)

    x = np.concatenate([light_x, heavy_x])
    w = np.concatenate([np.ones(group_size), np.full(group_size, 100.0)])

    # Filler second coordinate; the test itself only fits on "x".
    y = np.zeros_like(x)

    header = MagicMock(real_component_names=["x", "y", "weight"])
    return MockAMReXParticle(np.column_stack([x, y, w]), header)
67+
68+
def test_fit_gmm_weighted(mock_weighted_data):
    """Check that ``fit_gmm`` takes particle weights into account.

    The fixture holds two equally sized groups centred at 0 (weight 1) and
    10 (weight 100). A single-component fit on ``x`` therefore gives:

    * weights ignored   -> mean ~ (0 + 10) / 2 = 5
    * weights respected -> mean ~ (1*0 + 100*10) / 101 ~= 9.9

    ``fit_gmm`` resamples with the global NumPy RNG (``np.random.choice``),
    so the global state is seeded for the duration of the call to make the
    result reproducible, and restored afterwards so other tests are not
    affected.
    """
    saved_state = np.random.get_state()
    np.random.seed(0)
    try:
        # variables=['x'] requests a 1D fit.
        gmm = mock_weighted_data.fit_gmm(n_components=1, variables=['x'])
    finally:
        np.random.set_state(saved_state)

    mean = gmm.means_[0][0]

    print(f"GMM Mean: {mean}")

    # Loose bound: 5 (unweighted) vs ~9.9 (weighted) are far apart, so 8.0
    # cleanly separates the two outcomes.
    assert mean > 8.0, f"Mean {mean} is too low, weights likely ignored."

0 commit comments

Comments
 (0)