fix: fix tensor grouping bug & optimize MMD calculation in causal_prediction/algorithms/regularization.py #1371

Open

JPZ4-5 wants to merge 8 commits into py-why:main from JPZ4-5:fix/cacm-reg

Conversation

@JPZ4-5 (Contributor) commented Nov 27, 2025

This PR refactors the Regularizer class in dowhy/causal_prediction/algorithms/regularization.py to fix a critical logic error in cross-environment grouping and to significantly improve computational efficiency. Because this issue renders the current implementation mathematically incorrect, and potentially harmful to model performance without raising any errors, prompt review is highly recommended.

Key Improvements

  1. Replaced Grouping Logic:

    • Legacy: Relied on a manual hashing approach using a factors vector and a dot product (grouping_data @ factors). Since matrix multiplication (@) on GPU does not support long (integer) tensors, this required inefficient type casting between float and long.
    • New: Adopted torch.unique(dim=0, return_inverse=True) to handle grouping. This method is more robust and concise, and it leverages native PyTorch optimizations without unnecessary type conversions. (See the first sketch after this list.)
  2. Bug Fix (Dictionary Key Issue):

    • Issue: The legacy implementation used PyTorch Tensors as keys in Python dictionaries. Tensors hash by object identity, so in cross-environment settings identical scalar tensors from different environments (e.g., tensor(1) from Env0 and tensor(1) from Env1) were treated as distinct keys. Because keys failed to collide across environments, data that should have been pooled was kept separate, and incorrect MMD noise was added to the penalty (as the debug screenshot below shows, identical keys from different environments formed different groups, inflating the penalty).
    • Fix: The new implementation resolves this naturally by using torch.unique indices (or by ensuring scalar keys are compared by value), so data from different environments is correctly merged into the same pool.
[debug screenshot: identical keys from different environments shown as distinct groups]
  3. Algebraic Optimization & Throughput:
    • Refactored the MMD penalty calculation to use an algebraically optimized form instead of nested Python loops, which significantly reduces control-flow overhead and improves GPU throughput. (See the second sketch after this list.)
    • Formula:

$$ \sum_{i=1}^n\sum_{j=i+1}^n (K_{ii}+K_{jj}-2K_{ij})=(n-1)\sum_{i=1}^n K_{ii}-2\sum_{i=1}^n\sum_{j=i+1}^n K_{ij} $$

  4. Numerical Stability (Enforced fp64):
    • Change: Forced MMD accumulation to use float64 precision, casting back to the environment's default dtype (e.g., float32) only after the calculation (also shown in the second sketch below).
    • Rationale: Empirical evidence and standard parameter search spaces suggest gamma is often very small ($10^{-5}$ to $10^{-7}$). Computing Gaussian kernels with such small values in float32 can lead to vanishing penalty terms or precision loss; float64 ensures sufficient precision for the penalty accumulation.
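For illustration, here is a minimal, self-contained demo of the dict-key pitfall and the torch.unique-based grouping. Variable names are illustrative, not the exact code in regularization.py:

```python
import torch

# Pitfall: torch.Tensor hashes by object identity, so equal-valued scalar
# tensors from different environments never collide as dict keys.
groups = {}
groups[torch.tensor(1)] = "Env0"
groups[torch.tensor(1)] = "Env1"
print(len(groups))  # 2, even though both keys hold the value 1

# Fix: group rows by value instead. `inverse` maps every row of `labels`
# to the index of its unique value, pooling data across environments.
labels = torch.tensor([[1], [0], [1], [0]])  # attribute labels stacked over envs
uniques, inverse = torch.unique(labels, dim=0, return_inverse=True)
pools = [(inverse == g).nonzero(as_tuple=True)[0] for g in range(len(uniques))]
print(pools)  # [tensor([1, 3]), tensor([0, 2])]: rows merged by value
```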

Benchmark: In local testing, this PR yielded an approximate 40% speedup in training throughput (from 2.5 it/s to 3.5 it/s). All 6 cases have been tested.
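For reference, a minimal sketch of the vectorized penalty using the identity above, with fp64 accumulation. It assumes the standard Gaussian kernel; the function name and gamma default are illustrative, not the PR's exact _optimized_mmd_penalty:

```python
import torch

def optimized_gaussian_penalty(x: torch.Tensor, gamma: float = 1e-5) -> torch.Tensor:
    """Sum over i<j of (K_ii + K_jj - 2 K_ij), computed without Python loops
    via the identity (n - 1) * sum_i K_ii - 2 * sum_{i<j} K_ij."""
    x = x.double()  # accumulate in fp64; tiny gamma loses precision in fp32
    n = x.shape[0]
    K = torch.exp(-gamma * torch.cdist(x, x).pow(2))  # Gaussian kernel matrix
    diag = K.diagonal().sum()                         # sum_i K_ii (= n here)
    off = (K.sum() - diag) / 2                        # sum_{i<j} K_ij by symmetry
    penalty = (n - 1) * diag - 2 * off
    return penalty.to(torch.get_default_dtype())      # cast back after accumulating
```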

Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
Copilot AI left a comment

Pull request overview

This PR addresses critical bugs in the Regularizer class and introduces performance optimizations to the MMD calculation in causal prediction algorithms. The changes aim to fix a tensor grouping bug that caused incorrect penalty calculations and improve computational efficiency by ~40%.

Key Changes:

  • Replaced manual hashing-based grouping logic with torch.unique(dim=0, return_inverse=True) to fix dictionary key collision issues across environments
  • Introduced _optimized_mmd_penalty method with algebraic optimization to reduce control flow overhead
  • Added _compute_conditional_penalty helper to centralize conditional penalty computation logic
  • Enforced fp64 precision during MMD accumulation for numerical stability
  • Added use_optimization parameter (defaulting to False) to allow gradual migration to optimized implementations


@emrekiciman (Member) commented:
Thank you very much @JPZ4-5 for this contribution. The PR looks promising. @jivatneet could you take a look as well? (thank you!)

JPZ4-5 added 2 commits November 29, 2025 23:05
Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
@JPZ4-5 (Contributor, Author) commented Nov 29, 2025

  1. Regarding E_eq_A=True Logic:
    The review suggested replacing torch.full(..., i) with attribute_labels[i]. However, according to the original CACM paper, when E_eq_A=True the algorithm is explicitly designed to use the environment index as the sensitive attribute, regardless of what the raw attribute_labels contain. Constructing the labels manually from the environment index i is therefore the intended behavior.

  2. Code Fixes Applied:

    • Corrected features.dtype access.
    • Fixed the initialization of the covariance matrix in the else branch (using torch.zeros instead of .diag() to ensure correct shape for $N=1$).
    • Completed the missing docstrings.
    • Standardized the usage of torch.tensor vs tensor.
  3. Clarification on MMD Calculation & use_optimization:
    To be explicit (for the review bot as well): the critical bug was solely in the grouping stage (using Tensors as dictionary keys), which this PR fixes.
    I retained the unoptimized path because it offers higher readability and makes it easier to extend with future custom kernels. Not every kernel has a straightforward vectorized implementation for pooled data, and developers may prioritize readability and development efficiency over trivial algebraic optimizations for complex kernels (as in my case).
    The use_optimization flag lets developers opt in when using the standard gaussian_kernel (or other kernels with clear gains from vectorization). This parameter can easily be toggled in the CACM class if needed. (A minimal sketch of this dispatch follows below.)
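A minimal sketch of that dispatch, with the names (use_optimization, the Gaussian default) following this discussion; the real Regularizer is more involved:

```python
import torch

def mmd_penalty(x: torch.Tensor, kernel=None, gamma: float = 1e-5,
                use_optimization: bool = False) -> torch.Tensor:
    # Illustrative sketch, not the exact dowhy implementation.
    x = x.double()
    n = x.shape[0]
    if use_optimization:
        # Fast path: valid for the standard Gaussian kernel only.
        K = torch.exp(-gamma * torch.cdist(x, x).pow(2))
        d = K.diagonal().sum()
        return (n - 1) * d - (K.sum() - d)  # identity from the PR description
    # Readable path: any custom kernel k(a, b) -> scalar tensor plugs in here.
    if kernel is None:
        kernel = lambda a, b: torch.exp(-gamma * (a - b).pow(2).sum())
    total = x.new_zeros(())
    for i in range(n):
        for j in range(i + 1, n):
            total = total + kernel(x[i], x[i]) + kernel(x[j], x[j]) \
                          - 2 * kernel(x[i], x[j])
    return total
```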

Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
@JPZ4-5 (Contributor, Author) commented Dec 5, 2025

It seems CI failed with System.IO.IOException: No space left on device; the runner apparently ran out of disk space. Could you please trigger a re-run?

@github-actions bot commented Feb 4, 2026

This PR is stale because it has been open for 60 days with no activity.

@github-actions github-actions bot added the stale label Feb 4, 2026
@bloebp (Member) commented Feb 12, 2026

Somehow, I can't retrigger it. @amit-sharma can you try?

@jivatneet any chance to also take a look?

@github-actions github-actions bot removed the stale label Feb 13, 2026
@amit-sharma (Member) left a comment

Even I'm facing difficulty triggering it. Could it be because the PR is older than a month?

JPZ4-5 added 3 commits February 14, 2026 01:11
Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
@JPZ4-5 (Contributor, Author) commented Feb 14, 2026

The GCM test failure appears unrelated to my changes in regularization.py and is likely a flaky test issue (marked with @flaky).

This PR only modifies the causal_prediction regularization code, with no overlap with the GCM test paths. Happy to help investigate further if needed!

Signed-off-by: JPZ4-5 <yinaibaizi@gmail.com>
@JPZ4-5 (Contributor, Author) commented Feb 14, 2026

Hey! I noticed the CI is failing because pyproject.toml and poetry.lock are out of sync (e.g., the version of causal-learn).

So I reverted my local changes to those files. My code doesn't touch dependencies, so I think main just needs a refresh.

@jivatneet (Contributor) commented:
I'm sorry for the delay! @JPZ4-5 thanks for the great PR and for optimizing the code!

Regarding E_eq_A=True Logic

Your comment seems to be consistent with the original code here. I agree the suggested change is inconsistent with the paper.

Fixing the MMD Calculation

Which branch in regularization.py does this correspond to? Is it the case where E is not equal to A and not self.E_conditioned? As far as I can see, torch.unique was already present in the original code, so maybe the inconsistency arises elsewhere? Thanks for looking into this!

unique_attrs = torch.unique(attribute_labels[i][group_idx_indices])
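For context, a small illustration of the difference: without return_inverse, torch.unique only enumerates distinct values, whereas the grouping fix also needs the inverse mapping back to rows. The tensor here is illustrative:

```python
import torch

attrs = torch.tensor([1, 0, 1, 0])

# Original style: only the distinct values, with no mapping back to rows.
print(torch.unique(attrs))  # tensor([0, 1])

# Grouping style: `inverse` assigns each row its group id, enabling pooling.
uniques, inverse = torch.unique(attrs, return_inverse=True)
print(inverse)  # tensor([1, 0, 1, 0])
```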
