|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | +from unittest.mock import MagicMock |
| 4 | +from flekspy.amrex.particle_data import AMReXParticle |
| 5 | + |
class MockAMReXParticle(AMReXParticle):
    """Test double for ``AMReXParticle`` that serves a preloaded array.

    The constructor intentionally does not call the base ``__init__``; it
    only assigns the attributes this test module relies on, so no particle
    files are read from disk.
    """

    def __init__(self, rdata, header):
        """Store the in-memory real data and the (mocked) header."""
        self._rdata = rdata
        self.header = header
        # Placeholder location; nothing in this mock reads from it.
        self.output_dir = "mock_dir"
        # Empty placeholder so that code checking ``_idata`` does not trigger
        # a load (per the original author's note).
        self._idata = np.empty((0, 0))

    @property
    def rdata(self):
        """Return the array handed to the constructor."""
        return self._rdata

    def _extract_variable_columns(self, rdata, variables, component_names=None):
        """Return the columns of ``rdata`` named by ``variables``.

        Column positions are looked up in ``component_names``; when that is
        ``None`` the header's ``real_component_names`` is used instead. An
        unknown variable name raises ``ValueError`` from ``index``.
        """
        names = (
            self.header.real_component_names
            if component_names is None
            else component_names
        )
        column_ids = list(map(names.index, variables))
        return rdata[:, column_ids]

    def select_particles_in_region(self, x_range=None, y_range=None, z_range=None):
        """Ignore the requested ranges and return every particle.

        Region selection is not under test in this module, so the full
        stored array is returned whatever ranges are passed.
        """
        return self._rdata
| 32 | + |
| 33 | + |
@pytest.fixture
def mock_weighted_data():
    """
    Build a MockAMReXParticle holding two equally sized, unequally weighted
    clusters along ``x``:

    * 1000 particles around x = 0 with weight 1
    * 1000 particles around x = 10 with weight 100

    The real-data columns are ``[x, y, weight]``; ``y`` is all zeros.
    """
    group_size = 1000
    generator = np.random.default_rng(42)

    # Draw order is part of the seeded sequence: the weight-1 cluster is
    # sampled first, then the weight-100 cluster.
    light_x = generator.normal(0, 0.1, group_size)
    heavy_x = generator.normal(10, 0.1, group_size)

    positions = np.concatenate([light_x, heavy_x])
    weights = np.concatenate(
        [np.ones(group_size), np.full(group_size, 100.0)]
    )

    # Filler second coordinate in case the fit needs more than one column.
    filler_y = np.zeros_like(positions)

    # Keyword arguments to MagicMock are set as attributes on the mock.
    header = MagicMock(real_component_names=["x", "y", "weight"])

    return MockAMReXParticle(
        np.column_stack([positions, filler_y, weights]),
        header,
    )
| 67 | + |
def test_fit_gmm_weighted(mock_weighted_data):
    """
    Check that fit_gmm takes particle weights into account.

    The fixture holds equal-sized clusters at x ~ 0 (weight 1) and x ~ 10
    (weight 100). A single-component fit therefore lands at:

    * about 5 when weights are ignored: (0 + 10) / 2
    * about 9.9 when weights are honoured: (1*0 + 100*10) / 101
    """
    # One component fitted on the 1D variable "x" only.
    fitted_model = mock_weighted_data.fit_gmm(n_components=1, variables=["x"])

    first_component = fitted_model.means_[0]
    mean = first_component[0]

    print(f"GMM Mean: {mean}")

    # 8.0 sits well between the two outcomes (5 vs 9.9), so the bound is
    # loose enough to absorb sampling noise while still separating them.
    assert mean > 8.0, f"Mean {mean} is too low, weights likely ignored."
0 commit comments