fix: fix tensor grouping bug & optimize MMD calculation in causal_prediction/algorithms/regularization.py#1371
JPZ4-5 wants to merge 8 commits into py-why:main from
Conversation
Pull request overview
This PR addresses critical bugs in the Regularizer class and introduces performance optimizations to the MMD calculation in causal prediction algorithms. The changes aim to fix a tensor grouping bug that caused incorrect penalty calculations and improve computational efficiency by ~40%.
Key Changes:
- Replaced manual hashing-based grouping logic with `torch.unique(dim=0, return_inverse=True)` to fix dictionary key collision issues across environments (illustrated in the sketch below)
- Introduced `_optimized_mmd_penalty` method with algebraic optimization to reduce control flow overhead
- Added `_compute_conditional_penalty` helper to centralize conditional penalty computation logic
- Enforced fp64 precision during MMD accumulation for numerical stability
- Added `use_optimization` parameter (defaulting to False) to allow gradual migration to optimized implementations
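As a minimal sketch of the first bullet (not the PR's actual code; all tensors and variable names below are invented), `torch.unique(dim=0, return_inverse=True)` assigns the same integer group id to identical attribute rows regardless of which environment they came from:

```python
import torch

# Invented example data: attribute rows collected from two environments.
env0_attrs = torch.tensor([[1, 0], [1, 0], [0, 1]])
env1_attrs = torch.tensor([[1, 0], [0, 1]])
env_ids = torch.tensor([0, 0, 0, 1, 1])  # which environment each row came from
attrs = torch.cat([env0_attrs, env1_attrs], dim=0)

# Identical rows map to the same integer group id regardless of environment,
# so no manual hashing or float/long casting is needed to build grouping keys.
groups, inverse = torch.unique(attrs, dim=0, return_inverse=True)

for g in range(groups.shape[0]):
    mask = inverse == g
    # Rows from env0 and env1 with equal attribute values land in the same pool.
    print(groups[g].tolist(), env_ids[mask].tolist())
```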
Thank you very much @JPZ4-5 for this contribution. The PR looks promising. @jivatneet could you take a look as well? (thank you!)
It seems CI failed with System.IO.IOException: No space left on device. It looks like the runner ran out of disk space. Could you please trigger a re-run?
This PR is stale because it has been open for 60 days with no activity. |
Somehow, I can't retrigger it. @amit-sharma can you try? @jivatneet any chance to also take a look?
The GCM test failure appears unrelated to my changes. This PR only modifies causal_prediction regularization code, with no intersection with the GCM test paths. Ready to help investigate further if needed!
Hey! I noticed the CI is failing due to pyproject.toml and poetry.lock being out of sync (e.g., the version of causal-learn), so I reverted my local changes to those files. My code doesn't touch dependencies, so I think main just needs a refresh.
I'm sorry for the delay! @JPZ4-5 thanks for the great PR and for optimizing the code!
Your comment seems to be consistent with the original code here. I agree the suggested change is inconsistent with the paper.
Which branch in
This PR refactors the `Regularizer` class in `dowhy/causal_prediction/algorithms/regularization.py` to fix a critical logic error in cross-environment grouping and significantly improve computational efficiency. Given that this issue renders the current implementation mathematically incorrect and potentially harmful to model performance without raising errors, prompt review is highly recommended.

Key Improvements
Replaced Grouping Logic:
- Old: grouping keys were built from a manually constructed `factors` vector and a dot product (`grouping_data @ factors`). Since GPU matrix multiplication (`@`) does not support `long` types, this required inefficient type casting between `float` and `long`.
- New: grouping is handled with `torch.unique(dim=0, return_inverse=True)`. This method is more robust, concise, and leverages native PyTorch optimizations without unnecessary type conversions (see the sketch below).
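To make the old-vs-new contrast concrete, here is a small sketch with made-up values (the `factors` and `grouping_data` names mirror the description above but are not the repository's actual code):

```python
import torch

grouping_data = torch.tensor([[1, 0], [1, 0], [0, 1]])  # invented attribute rows

# Old style (sketch): integer keys via a hand-built factors vector and a dot product.
# The matmul has to run in float, so keys bounce between float and long.
factors = torch.tensor([10, 1])
keys_old = (grouping_data.float() @ factors.float()).long()

# New style: torch.unique assigns group ids directly, with no casting round-trip.
_, keys_new = torch.unique(grouping_data, dim=0, return_inverse=True)

print(keys_old.tolist())  # [10, 10, 1]
print(keys_new.tolist())  # [1, 1, 0]
```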
Bug Fix (Dictionary Key Issue):
- Old: tensors were used as dictionary keys, so identical values (e.g., `tensor(1)` from Env0 and `tensor(1)` from Env1) were treated as distinct objects. Consequently, incorrect MMD noise was added to the penalty because keys failed to collide across environments (as shown in the debug screenshot, identical keys from different `envs` were treated as different groups, leading to a wrongly inflated penalty).
- New: grouping keys are taken from the `torch.unique` indices (or scalar keys are handled by value), ensuring data from different environments is correctly merged into the same pool.

Numerical Stability (fp64 Accumulation):
- The MMD penalty is now accumulated in `float64` precision, casting back to the environment's default dtype (e.g., `float32`) only after calculation; see the sketch after the benchmark.
- Empirical evidence and standard parameter search spaces suggest `gamma` is often very small; `float32` can lead to vanishing penalty terms or precision loss, while `float64` ensures sufficient precision for the penalty accumulation.

Benchmark: In local testing, this PR resulted in an approximate 40% speedup in training throughput (increasing from 2.5 it/s to 3.5 it/s). All 6 cases have been tested.
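As a rough illustration of the fp64 accumulation point above (a sketch using a linear-kernel MMD and an invented `gamma`; the actual kernel, weighting, and method names in the PR may differ):

```python
import torch

def linear_mmd2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Toy linear-kernel MMD^2 between two sample batches (sketch only)."""
    diff = x.mean(dim=0) - y.mean(dim=0)
    return diff @ diff

# Invented per-group feature batches standing in for pooled environments.
groups = [(torch.randn(64, 16), torch.randn(64, 16) + 0.05) for _ in range(3)]
gamma = 1e-6  # hypothetical small penalty weight

# Accumulate in float64 so the tiny gamma-weighted terms do not vanish in float32,
# then cast back to the default dtype only once, after the sum.
penalty64 = torch.zeros((), dtype=torch.float64)
for x, y in groups:
    penalty64 = penalty64 + gamma * linear_mmd2(x.double(), y.double())
penalty = penalty64.to(torch.get_default_dtype())
print(penalty.dtype, float(penalty))
```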