Re-enable time-dependent z-scoring for Flow Matching #1752
manuelgloeckler merged 22 commits into sbi-dev:main
Conversation
Codecov Report ❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1752 +/- ##
==========================================
- Coverage 87.88% 87.82% -0.06%
==========================================
Files 140 140
Lines 12726 12777 +51
==========================================
+ Hits 11184 11222 +38
- Misses 1542 1555 +13
It seems unrelated, since this failure is in a different test. The actual Flow Matching benchmarks and integration tests for this PR passed successfully, though.

Yes, this is unrelated and popped up here by chance, or because of an unrelated change in a downstream package. I pushed a fix to this branch ✅
Thanks for working on this @satwiksps ! Overall, this looks exactly right. However, after reviewing the code and tracing through the flow matching implementation, I believe the z-scoring formula is inverted relative to the interpolation convention (quite confusing!). With $t=0$ as data and $t=1$ as noise, the interpolation in the loss function is:

$$x_t = (1 - t) \cdot x_{\text{data}} + t \cdot \epsilon$$

So the expected input mean at each time is:

$$\mu_t = (1 - t) \cdot \text{mean}_{\text{data}}$$

Current PR formula:

$$\mu_t = t \cdot \text{mean}_{\text{data}}$$

This gives $\mu_t = 0$ at $t=0$ and $\mu_t = \text{mean}_{\text{data}}$ at $t=1$ — exactly backwards. The correct formula should be:

$$\mu_t = (1 - t) \cdot \text{mean}_{\text{data}}$$

The two formulas only match at $t=0.5$ and are maximally wrong at the boundaries. To verify this, I suggest the following test: the standard linear Gaussian test, but with a uniform prior between 95 and 100, and with data shifted accordingly. Can you confirm this (maybe I got confused with the integration directions after all)?
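The mean mismatch is easy to check numerically. A minimal sketch (assuming the stated convention that $t=0$ is data and $t=1$ is noise; `x_data` stands in for samples from the suggested shifted prior):

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in for prior samples from U(95, 100); empirical mean is ~97.5
x_data = rng.uniform(95.0, 100.0, size=50_000)
noise = rng.standard_normal(50_000)

for t in (0.0, 0.5, 1.0):
    # interpolation convention: t=0 is data, t=1 is noise
    x_t = (1.0 - t) * x_data + t * noise
    # expected mean along the path: mu_t = (1 - t) * mean_data
    expected = (1.0 - t) * x_data.mean()
    print(f"t={t}: empirical mean {x_t.mean():.3f}, expected {expected:.3f}")
```

At $t=0$ the empirical mean is near 97.5, not 0, so a formula yielding $\mu_t = t \cdot \text{mean}_{\text{data}}$ would indeed be inverted with respect to this convention.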
manuelgloeckler left a comment:
Hey @satwiksps !
Thanks for the contribution! I checked against main, and as of now it does, I guess, perform on average very similarly to before, if not a bit worse (although I think that's mostly fine for these tasks).
I wonder if it would make sense to improve the "preconditioning" a bit more (see comments).
Thanks for adding the comparison.
Alright, I looked at it again and I realized that my proposal was actually incorrect. The formulas I proposed would result in total normalization, i.e., "independent" z-scoring, where all time steps have zero mean after z-scoring, and we lose valuable time-dependent information. Sorry @satwiksps, your formulas were actually correct! What Manuel proposed is great: we z-score with respect to the Gaussian baseline, i.e., what one would expect when the posterior is actually Gaussian. Then the flow matching network only has to learn the residual from this ideal baseline (please correct me @manuelgloeckler if this intuition is inaccurate). I tested this locally with the following setup:
Results:
Thus, @satwiksps, I suggest you implement both options, your proposal and Manuel's, add the test as a new z-scoring test, and confirm the results.
@janfb The preconditioning is with respect to the "prior", not the posterior (as the posterior would require regression from x). I don't think that it will "hurt" in almost all cases: FM nets are initialized to output zero, hence the freshly initialized network will effectively sample from a mass-covering Gaussian approximation of the prior (and everything else needs to be learned). Nonetheless, having an option to disable it is always good. Agreed that the benchmark tests are not really sensitive to the z-scoring, but as we usually enable z-scoring by default, it shouldn't hurt performance even if it's not necessary. But as said, the deviation is small enough to be fine (and might improve with the additional baseline).
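The zero-initialization point can be illustrated numerically. This is a hedged sketch, not sbi's implementation: `baseline_velocity` is the standard probability-flow velocity of the Gaussian path $N((1-t)m,\,(1-t)^2 s^2 + t^2)$, with an assumed Gaussian fit of the prior ($m=5$, $s=2$); with a residual net that outputs zero, integrating the ODE from noise at $t=1$ back to $t=0$ already yields samples from this Gaussian approximation.

```python
import numpy as np

m, s = 5.0, 2.0  # assumed Gaussian fit of the prior (illustrative values)

def baseline_velocity(x, t):
    # velocity of the Gaussian path N(mu_t, sigma_t^2) with
    # mu_t = (1 - t) * m and sigma_t^2 = (1 - t)^2 * s^2 + t^2
    mu_t = (1.0 - t) * m
    var_t = (1.0 - t) ** 2 * s ** 2 + t ** 2
    dmu = -m
    dvar = -2.0 * (1.0 - t) * s ** 2 + 2.0 * t
    # v(x, t) = mu_t' + (sigma_t'/sigma_t) * (x - mu_t), sigma'/sigma = var'/(2 var)
    return dmu + 0.5 * dvar / var_t * (x - mu_t)

def zero_net(x, t):
    # stand-in for a freshly initialized residual net that outputs zero
    return np.zeros_like(x)

rng = np.random.default_rng(0)
x = rng.standard_normal(20_000)  # start at t=1 (pure noise)
n_steps = 1000
dt = 1.0 / n_steps
for i in range(n_steps):  # Euler-integrate t from 1 down to 0
    t = 1.0 - i * dt
    x = x - dt * (baseline_velocity(x, t) + zero_net(x, t))

print(x.mean(), x.std())  # close to m=5 and s=2
```

Everything beyond this Gaussian approximation is what the residual network then has to learn during training.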
janfb left a comment:
Thanks for the update @satwiksps ! Looks good. I just have one crucial question on the standard z-scoring formulas again, please check 🙏
manuelgloeckler left a comment:
Thanks for your contribution.
I think the formula is still a bit off (but it also was never very clearly defined by us anyway (: ).
I do have a minor suggestion on the mean/var buffers as well as the Gaussian baseline test, which should be addressed (see comments). Once this is done, we can merge it :)
Kind regards,
Manuel
Hi @manuelgloeckler and @satwiksps, I tested this locally using a linear Gaussian test with the prior shifted to U(95, 105). We have the option to just z-score the input over time or to additionally use the Gaussian baseline assumption. It turns out that the formula that works best is the one that normalizes the mean to 0 across time. I also tested the Gaussian baseline option, once with the formulas derived from the Flow Matching velocity objective.
I added all options in the internal code, along with a smoke test comparing them, but just for reference; in the next commit I will clean things up. So, I suggest we go with the option that performed best.
- fix bug with z_score_x vs y mapping in kwargs setup
- fix formulas after empirical tests with smoke tests
- more failing tests appeared because of previously silent kwargs failures
manuelgloeckler left a comment:
Thanks!
The implementation looks good. I am just a bit confused/concerned about the assumed integration direction. But I could be wrong there; can you point me to the part where this "switch" happens?
@manuelgloeckler thanks for the review and for checking the integration direction again. It should now be fixed. I will run the slow tests again to make sure all is clean, and then we can merge. Do you approve?
janfb left a comment:
Slow tests are passing and formulas have been clarified. Thanks again @satwiksps for the initial hard work on this one! 👏 🚀
I will leave this final review and approval to @manuelgloeckler !
manuelgloeckler left a comment:
Great, thanks all!

Description
This PR re-introduces z-scoring for Flow Matching estimators using a time-dependent normalization approach and adds a Gaussian Baseline for improved training stability.
As discussed in #1623, standard z-scoring is problematic because the network input evolves from data to noise. This implementation provides two distinct normalization modes to handle this evolution while maintaining training stability.
Corrected Normalization Statistics:
Since we define $t=0$ as Data and $t=1$ as Noise, the statistics are handled based on the chosen mode:
- Gaussian Baseline (`gaussian_baseline=True`): Normalizes inputs to $N(0, 1)$ across the entire path. The drift signal is handled by the hard-coded affine baseline.
- Variance Only (`gaussian_baseline=False`): Normalizes variance while preserving the raw data location at $t=0$. This ensures the network can still learn the drift signal when no baseline is used.

Gaussian Baseline:
We implemented an affine vector field baseline (enabled by default). The network now learns the residual vector field with respect to the optimal Gaussian probability path, significantly improving convergence on shifted datasets.
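The two normalization modes described above could be sketched as follows. This is a simplified, hypothetical helper (not sbi's exact code); `mean_1` and `std_1` denote the training-data statistics, and the convention is $t=0$ data, $t=1$ noise:

```python
import torch

def z_score_input(x_t, t, mean_1, std_1, gaussian_baseline=True, eps=1e-5):
    """Hypothetical time-dependent z-scoring sketch (not sbi's exact code).

    Along x_t = (1-t)*x0 + t*noise, the marginal mean is (1-t)*mean_1
    and the marginal std is sqrt((1-t)^2 * std_1^2 + t^2).
    """
    # broadcast the time vector over the event dimensions of x_t
    t = t.reshape(-1, *([1] * (x_t.ndim - 1)))
    std_t = torch.sqrt((1 - t) ** 2 * std_1 ** 2 + t ** 2) + eps
    if gaussian_baseline:
        # normalize to N(0, 1) along the whole path; the drift is
        # handled by the affine baseline, not by the network input
        return (x_t - (1 - t) * mean_1) / std_t
    # variance-only: rescale, but do not subtract the mean, so the
    # network still sees the data location (the drift signal) at t=0
    return x_t / std_t
```

For example, at $t=0$ the first mode maps data to zero mean and unit variance, while the second mode only rescales, leaving the (scaled) location information available to the network.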
Related Issues/PRs
Changes
- `sbi/neural_nets/net_builders/vector_field_nets.py`: Updated `build_vector_field_estimator` to calculate training data statistics, accept the `gaussian_baseline` flag, and pass them to the estimator.
- `sbi/neural_nets/estimators/flowmatching_estimator.py`:
  - Registered `mean_1` and `std_1` as buffers and expanded them to match `input_shape` to ensure compatibility with multi-dimensional data in CI.
  - Updated `forward()` to support both Gaussian Baseline (residual learning) and Variance-only (signal preserving) modes.
  - Added a small epsilon (`1e-5`) to variance calculations to prevent division-by-zero errors.
- `tests/linearGaussian_vector_field_test.py`:
  - `test_fmpe_time_dependent_z_scoring_integration`: Verifies statistics population, buffer registration, and forward pass shapes.
  - `test_fmpe_shifted_data_gaussian_baseline`: Verifies that the Gaussian Baseline outperforms variance-only scaling on shifted data.
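The buffer registration and the residual-learning forward pass can be sketched roughly as follows. This is a hypothetical, simplified module; names like `ResidualVectorField` are illustrative and not sbi's actual classes:

```python
import torch
import torch.nn as nn

class ResidualVectorField(nn.Module):
    """Sketch of residual learning against an affine Gaussian baseline.

    Hypothetical illustration, not sbi's exact implementation: the network
    only predicts the deviation from the velocity of the ideal Gaussian
    path between N(mean_1, std_1^2) at t=0 and N(0, 1) at t=1.
    """

    def __init__(self, net: nn.Module, mean_1: torch.Tensor, std_1: torch.Tensor):
        super().__init__()
        self.net = net
        # statistics stored as buffers so they move with .to(device)
        # and are saved in the state dict
        self.register_buffer("mean_1", mean_1)
        self.register_buffer("std_1", std_1)

    def baseline(self, x, t):
        # affine velocity of the Gaussian path N(mu_t, var_t)
        mu_t = (1 - t) * self.mean_1
        var_t = (1 - t) ** 2 * self.std_1 ** 2 + t ** 2
        dmu = -self.mean_1
        dvar = -2 * (1 - t) * self.std_1 ** 2 + 2 * t
        return dmu + 0.5 * dvar / var_t * (x - mu_t)

    def forward(self, x, t):
        # broadcast time over event dimensions, then add the learned residual
        t = t.reshape(-1, *([1] * (x.ndim - 1)))
        return self.baseline(x, t) + self.net(x, t)
```

With a freshly initialized `net` that outputs zeros, the forward pass reduces to the hard-coded affine baseline.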
Verification
- Verified that `gaussian_baseline=True` achieves lower validation loss and faster convergence than variance-only scaling on a shifted 1D prior (compared against `z_score_x='independent'`).
- Ran the `sbi` benchmarks locally (`pytest --bm --bm-mode fmpe`) to check stability and performance. All 12 tests passed successfully.