feat(archon): Add tree training support for Archon engine #912
Conversation
Summary of Changes

Hello @nuzant, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Archon engine by introducing robust support for tree training. This feature optimizes the training process for large language models by efficiently handling sequences with common prefixes through a trie data structure. The changes ensure that the Archon engine can leverage advanced attention mechanisms, specifically FlexAttention and Triton-based tree attention, leading to more performant and resource-efficient model training.

Highlights
Activity
Code Review
This pull request introduces significant new functionality by adding tree training support to the Archon engine, which is a valuable addition for training efficiency. The changes are comprehensive, affecting the core engine, runner, and model implementations, and include support for both FlexAttention and Triton backends. The inclusion of tests for this new feature is also a great practice.
My review highlights a few areas for improvement:

- Configuration Robustness: There's an opportunity to make the configuration handling for `attn_type` more robust to prevent potential runtime errors from mismatches.
- API Change: The `loss_fn` signature has been implicitly changed, which could be a breaking change for existing users of the engine.
- Test Correctness: The new tests rely on very high tolerances for numerical comparisons, which may mask correctness issues in both the forward and backward passes.
Addressing these points will help ensure the new tree training feature is both robust and numerically correct. Overall, this is a well-executed and important feature addition.
Add tree training capabilities to Archon engine following FSDP patterns:

- Add tree training imports and `TrieNode` field to `ArchonTrainContext`
- Update `initialize()` with CP validation and `pad_to_maximum` enforcement
- Add tree training path in `_prepare_mb_list()` using `build_packed_tree_batch()`
- Extract `trie_node` in `_prepare_mb_inputs()` and store in context
- Update `_compute_logprobs_and_loss()` with tree training using `gather_packed_tree_logprobs_entropy()` and `gather_packed_tree_vocab_stats()`
- Update `_compute_forward_result()` with `_gather_packed_tree_logprobs()`
- Update `forward_batch()` with `merge_packed_tree_results()`
- Update `SequentialRunner` to build tree attention data before model forward
- Add `block_mask` and `triton_attn_data` parameters to Qwen2 model layers
- Update `VarlenAttentionWrapper` and `SDPAWrapper` for tree attention
- Add Archon engine to tree training tests

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Ensure `attn_type == "tree"` and `enable_tree_training` are enabled or disabled together.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
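The consistency rule above can be enforced at construction time so a mismatch fails fast instead of surfacing mid-training. A minimal sketch, assuming a dataclass-style config; the class name `ArchonConfig` and the non-tree `attn_type` values are illustrative, not the engine's real API:

```python
from dataclasses import dataclass


@dataclass
class ArchonConfig:
    """Hypothetical config sketch; only attn_type and
    enable_tree_training come from the PR, the rest is assumed."""
    attn_type: str = "varlen"          # e.g. "varlen", "sdpa", or "tree"
    enable_tree_training: bool = False

    def __post_init__(self):
        # Reject mismatched settings early instead of failing at runtime.
        if (self.attn_type == "tree") != self.enable_tree_training:
            raise ValueError(
                "attn_type == 'tree' and enable_tree_training must be "
                f"enabled or disabled together, got attn_type="
                f"{self.attn_type!r}, "
                f"enable_tree_training={self.enable_tree_training}"
            )


# Consistent combinations construct fine; mismatches raise immediately.
ok = ArchonConfig(attn_type="tree", enable_tree_training=True)
try:
    ArchonConfig(attn_type="tree", enable_tree_training=False)
    raised = False
except ValueError:
    raised = True
```

Putting the check in `__post_init__` keeps the invariant with the data it guards, so any code path that builds a config gets the validation for free.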
rchardx left a comment:
LGTM
…backend selection

Simplify tree training code added in #912 by extracting helpers and centralizing backend selection in the tree_attn package. Key changes:

- Extract `_gather_actor_train_outputs`, `_gather_actor_forward_output`, `_gather_critic_output` from archon_engine output methods
- Remove thin wrappers `_compute_logprobs_entropy`, `_compute_logprobs`
- Add `TreeAttnMeta.from_trie()` factory for Archon backend selection
- Add `build_tree_attn_kwargs()` factory for FSDP/Megatron backend selection
- Unify tree attention parameter as `tree_attn_meta` across Archon model stack
- Prefix FSDP/Megatron kwargs with `tree_` to avoid HF collisions
- Move BLOCK_SIZE alignment validation into `build_packed_tree_batch`
… attention package

Consolidate tree attention metadata into TreeAttentionMeta dataclass, pass it end-to-end through model layers, and restructure the Archon attention module into a package with clear separation of concerns. Key changes:

- Add `TreeAttentionMeta` with `from_trie()` factory to auto-select backend
- Pass `tree_attn_meta` through Model → TransformerBlock → Attention in Archon models (qwen2, qwen3)
- Convert attention.py into attention/ package (sdpa.py, varlen.py)
- Extract `_gather_actor_train_outputs`, `_gather_actor_forward_output`, `_gather_critic_output` helpers in ArchonEngine
- Add `build_tree_attn_kwargs` for FSDP/Megatron dict-based forwarding with `tree_` prefix to avoid HF kwarg collisions
- Fix unused `stride_ob` in triton `_tree_attn_bwd_preprocess` kernel (caused InductorError with torch.compile backward)
- Fix missing `vocab_min`/`vocab_max` CP gather in Archon engine
- Move BLOCK_SIZE alignment into `build_packed_tree_batch` via `parallel_size` parameter with `math.lcm`

Refs: #912
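The `from_trie()` factory pattern described above can be sketched in plain Python. Everything here is an assumption for illustration: the real `TreeAttentionMeta`, trie node type, and backend probe live in the tree_attn package and carry far more state (block masks, Triton kernel metadata) than this toy:

```python
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    """Illustrative stand-in for the packed-tree trie node."""
    token: int = -1
    children: dict = field(default_factory=dict)  # token id -> TrieNode


def _flex_attention_available() -> bool:
    # Placeholder probe; the real check would test for a usable
    # FlexAttention build before falling back to the Triton kernel.
    return False


@dataclass
class TreeAttentionMeta:
    backend: str       # "flex" or "triton"
    num_nodes: int

    @classmethod
    def from_trie(cls, root: TrieNode) -> "TreeAttentionMeta":
        # Centralized backend selection: callers pass a trie and get
        # back metadata tagged with whichever backend is usable.
        backend = "flex" if _flex_attention_available() else "triton"

        def count(node: TrieNode) -> int:
            return 1 + sum(count(c) for c in node.children.values())

        return cls(backend=backend, num_nodes=count(root))


root = TrieNode()
root.children[5] = TrieNode(token=5)
meta = TreeAttentionMeta.from_trie(root)
```

The design point the commit makes is that model layers only ever see one `tree_attn_meta` object, so swapping backends never touches layer signatures again.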
Summary
Description
This PR adds tree training capabilities to the Archon engine, allowing multiple sequences sharing common prefixes to be packed into a trie structure for efficient training.
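The prefix-sharing payoff can be sketched with a toy trie. This is not the PR's `TrieNode` or `build_packed_tree_batch()`; it is a minimal pure-Python illustration of why packing sequences with a common prefix into a trie shrinks the number of token positions the model must process:

```python
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    children: dict = field(default_factory=dict)  # token id -> TrieNode


def insert(root: TrieNode, seq) -> None:
    """Insert a token sequence; shared prefixes reuse existing nodes."""
    node = root
    for tok in seq:
        node = node.children.setdefault(tok, TrieNode())


def packed_length(node: TrieNode) -> int:
    """Token positions needed once sequences are packed as a trie."""
    return sum(1 + packed_length(c) for c in node.children.values())


# Two rollouts sharing the prompt [1, 2, 3]:
seqs = [[1, 2, 3, 4, 5], [1, 2, 3, 9]]
root = TrieNode()
for s in seqs:
    insert(root, s)

flat = sum(len(s) for s in seqs)   # positions without sharing
packed = packed_length(root)       # shared prefix stored once
```

Here the shared three-token prefix is stored once, so the packed batch has 6 positions instead of 9; with many rollouts per prompt (as in RL-style training) the savings compound.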
Key Changes

Core Engine (`archon_engine.py`):
- Add `trie_node` field to `ArchonTrainContext`
- Update `initialize()` with CP validation and `pad_to_maximum` enforcement
- Add tree training path in `_prepare_mb_list()` using `build_packed_tree_batch()`
- Update `_compute_logprobs_and_loss()` with tree training using `gather_packed_tree_logprobs_entropy()` and `gather_packed_tree_vocab_stats()`
- Update `_compute_forward_result()` and `forward_batch()` for tree training result handling

Runner (`archon_runner.py`):
- Update `SequentialRunner` to build tree attention data before model forward
- Pass `block_mask` and `triton_attn_data` to model

Model Files:
- Add `block_mask` and `triton_attn_data` parameters to Qwen2/Qwen3 model layers (Attention, TransformerBlock, Model)
- Update `VarlenAttentionWrapper` and `SDPAWrapper` for tree attention support

New Module (`module_archon.py`):

Tests:
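The `block_mask` threaded through the model layers encodes which packed positions may attend to which. In FlexAttention terms that is a `mask_mod` predicate compiled into a `BlockMask`; the tree rule is that a query position may attend only to its ancestors in the trie (including itself). A pure-Python sketch of that predicate, using an assumed `parent` array rather than the PR's real data structures:

```python
# parent[i] is the packed index of position i's parent (-1 for a root).
def tree_mask(parent, q_idx: int, kv_idx: int) -> bool:
    """Causal tree mask: q_idx may attend to kv_idx iff kv_idx lies on
    the path from q_idx back to the root of its tree."""
    node = q_idx
    while node != -1:
        if node == kv_idx:
            return True
        node = parent[node]
    return False


# Packed layout for sequences [1,2,3,4,5] and [1,2,3,9]:
# positions 0,1,2 = shared prefix; 3,4 = first branch; 5 = second branch.
parent = [-1, 0, 1, 2, 3, 2]
mask = [[tree_mask(parent, q, kv) for kv in range(6)] for q in range(6)]
```

Within one branch this reduces to the usual causal mask, while positions on different branches (here 4 and 5) never see each other, which is exactly what keeps the packed rollouts independent.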
Type of Change
Checklist
🤖 Generated with Claude Code