refactor(tree-attn): simplify tree attention plumbing and restructure attention package #920
Conversation
Summary of Changes

Hello @rchardx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly refactors the tree training codebase by streamlining the handling of tree attention metadata and centralizing backend selection logic. The changes aim to improve code clarity, maintainability, and consistency across different training engines (Archon, FSDP, Megatron) by unifying the parameter interface for tree attention and extracting common utility functions. This cleanup makes the tree training implementation more modular and easier to extend.
Activity
Code Review
This pull request provides a significant and well-executed refactoring of the tree attention training logic. The changes successfully simplify the codebase by encapsulating backend selection, unifying the parameter interface across different engines with tree_attn_meta, and improving the overall structure by extracting focused helper methods. Moving the BLOCK_SIZE alignment validation into build_packed_tree_batch is also a good change that centralizes responsibility. The code is now cleaner and more maintainable. I have a couple of medium-severity suggestions regarding the change from ValueError to assert for critical checks, which could be less robust in production environments where assertions might be disabled.
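The reviewer's concern about replacing ValueError with assert can be illustrated with a minimal sketch (function names and the specific check are hypothetical, not taken from the PR): Python strips `assert` statements entirely when run with the `-O` flag, so an assert-based check silently disappears in optimized production runs, while an explicit exception always fires.

```python
# Hypothetical validation check illustrating the assert-vs-ValueError tradeoff.

def validate_with_assert(block_size: int) -> None:
    # Removed entirely under `python -O`: the check never runs in
    # optimized mode, so bad inputs pass through silently.
    assert block_size % 64 == 0, "block_size must be a multiple of 64"

def validate_with_error(block_size: int) -> None:
    # Survives `-O`; preferable for critical input validation.
    if block_size % 64 != 0:
        raise ValueError("block_size must be a multiple of 64")
```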
… attention package

Consolidate tree attention metadata into a TreeAttentionMeta dataclass, pass it end-to-end through model layers, and restructure the Archon attention module into a package with clear separation of concerns.

Key changes:
- Add TreeAttentionMeta with from_trie() factory to auto-select backend
- Pass tree_attn_meta through Model → TransformerBlock → Attention in Archon models (qwen2, qwen3)
- Convert attention.py into attention/ package (sdpa.py, varlen.py)
- Extract _gather_actor_train_outputs, _gather_actor_forward_output, _gather_critic_output helpers in ArchonEngine
- Add build_tree_attn_kwargs for FSDP/Megatron dict-based forwarding with tree_ prefix to avoid HF kwarg collisions
- Fix unused stride_ob in triton _tree_attn_bwd_preprocess kernel (caused InductorError with torch.compile backward)
- Fix missing vocab_min/vocab_max CP gather in Archon engine
- Move BLOCK_SIZE alignment into build_packed_tree_batch via parallel_size parameter with math.lcm

Refs: #912
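The commit names a TreeAttentionMeta dataclass with a from_trie() factory that auto-selects the attention backend, but the PR text does not show its fields or selection logic. A minimal hypothetical sketch of the pattern (field names and the selection rule are illustrative assumptions, not the PR's actual code):

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class TreeAttentionMeta:
    """Consolidated tree-attention metadata (illustrative sketch)."""
    backend: str                            # e.g. "sdpa" or "varlen"
    block_mask: Optional[Any] = None        # consumed by the SDPA path
    triton_attn_data: Optional[Any] = None  # consumed by the triton/varlen path

    @classmethod
    def from_trie(cls, trie: Any, prefer_varlen: bool = True) -> "TreeAttentionMeta":
        # Hypothetical selection rule: prefer the varlen backend when
        # requested, otherwise fall back to SDPA with a block mask.
        if prefer_varlen:
            return cls(backend="varlen", triton_attn_data=trie)
        return cls(backend="sdpa", block_mask=trie)
```

Passing one such object through Model → TransformerBlock → Attention replaces the separate block_mask/triton_attn_data parameters that each layer previously had to thread through.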
garrett4wade
left a comment
LGTM. Waiting for CI.
Description
Cleanup of #912. Consolidate tree attention metadata into a TreeAttentionMeta dataclass, pass it end-to-end through model layers, restructure the Archon attention module into a package, and fix two bugs (triton kernel InductorError, missing CP gather for vocab stats).

Related Issue
Follow-up to #912
Breaking Change Details (if applicable):
N/A
Additional Context
Key changes:
- TreeAttentionMeta with from_trie() factory for Archon backend selection
- Pass tree_attn_meta through Model → TransformerBlock → Attention (replaces separate block_mask/triton_attn_data params)
- Convert attention.py into attention/ package (sdpa.py, varlen.py)
- Extract _gather_actor_train_outputs, _gather_actor_forward_output, _gather_critic_output helpers in ArchonEngine
- Add build_tree_attn_kwargs() for FSDP/Megatron dict-based forwarding with tree_ prefix to avoid HF kwarg collisions
- Fix unused stride_ob in triton _tree_attn_bwd_preprocess kernel (caused InductorError with torch.compile backward)
- Fix missing vocab_min/vocab_max CP gather in Archon engine
- Move BLOCK_SIZE alignment into build_packed_tree_batch via parallel_size parameter with math.lcm
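Two of the items above can be sketched roughly: the tree_-prefixed kwarg forwarding (so metadata fields never collide with HuggingFace forward() argument names) and the math.lcm-based length alignment (so a packed batch satisfies both the kernel's BLOCK_SIZE and the parallel split size). The dataclass fields and function signatures below are illustrative assumptions, not the PR's actual code.

```python
import math
from dataclasses import dataclass, asdict
from typing import Any, Optional

@dataclass
class TreeAttnMeta:
    """Hypothetical stand-in for the PR's tree-attention metadata."""
    cu_seqlens: Any = None
    max_seqlen: Optional[int] = None

def build_tree_attn_kwargs(meta: TreeAttnMeta) -> dict:
    # Prefix every field with "tree_" so the resulting kwargs cannot
    # collide with HuggingFace forward() arguments of the same name.
    return {f"tree_{k}": v for k, v in asdict(meta).items()}

def aligned_length(seq_len: int, block_size: int, parallel_size: int) -> int:
    # Pad to a multiple of lcm(block_size, parallel_size): the smallest
    # length divisible by both the kernel block and the parallel split.
    align = math.lcm(block_size, parallel_size)
    return -(-seq_len // align) * align  # ceiling division, then scale
```

For example, with BLOCK_SIZE 64 and parallel_size 8 the alignment unit is lcm(64, 8) = 64, so a 1000-token packed sequence would be padded up to 1024.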