[mxfp8 training] add TP warning#2562

Open
danielvegamyhre wants to merge 1 commit intomainfrom
tpwarning
Conversation


@danielvegamyhre danielvegamyhre commented Mar 12, 2026

Summary

Optional extra details

  • Because DTensor's linear is registered as composite implicit autograd, it decomposes the linear op into `aten.t` + `aten.mm`, which go straight through `__torch_dispatch__` without first passing through `__torch_function__`. This prevents our subclass from intercepting the linear op and dispatching it to the `_to_mxfp8_then_scaled_mm` autograd function. We also cannot intercept at the `__torch_dispatch__` level, because by then autograd would not capture the backward of the autograd function we dispatch to.
  • In contrast, grouped_mm does not have this problem: it has an explicitly registered backward (composite explicit autograd), so we always see `aten._grouped_mm` in `__torch_function__` and can intercept it there.
  • thanks @pianpwk for the help with this!
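
The dispatch asymmetry described in the bullets above can be sketched with a torch-free toy model. Everything here is illustrative: the op table, the `call_op` helper, and the log class are hypothetical stand-ins for PyTorch's real dispatcher, which works quite differently in detail.

```python
# Toy model of the two extension points discussed above.
# Hypothetical names; real DTensor/mxfp8 internals differ.

# Composite-implicit-autograd ops decompose into primitives before a
# subclass can see them whole; explicit-autograd ops are called as-is.
DECOMPOSITIONS = {
    "aten.linear": ["aten.t", "aten.mm"],  # composite implicit autograd
    # "aten._grouped_mm" has an explicit backward, so no decomposition.
}

class InterceptLog:
    """Records which ops each extension point observes."""
    def __init__(self):
        self.torch_function = []   # pre-decomposition: subclass can redirect
        self.torch_dispatch = []   # post-decomposition: too late for autograd

def call_op(op, log):
    decomposed = DECOMPOSITIONS.get(op)
    if decomposed is None:
        # Explicit autograd: the whole op reaches __torch_function__ first,
        # then __torch_dispatch__.
        log.torch_function.append(op)
        log.torch_dispatch.append(op)
    else:
        # Composite implicit autograd: __torch_function__ is bypassed and
        # only the decomposed primitives hit __torch_dispatch__.
        log.torch_dispatch.extend(decomposed)

log = InterceptLog()
call_op("aten.linear", log)       # decomposes: never seen whole by subclass
call_op("aten._grouped_mm", log)  # explicit backward: intercepted intact
```

In this model, only `aten._grouped_mm` shows up at the function level, matching the observation in the PR that the grouped op can be redirected to a custom autograd function while the decomposed linear cannot.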

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 12, 2026

Labels

ciflow/8gpu, CLA Signed

3 participants