
Conversation


@nuzant nuzant commented Feb 9, 2026

Summary

  • Add tree training support to the Archon engine, following the FSDP engine patterns
  • Enable efficient training with shared-prefix sequences via a trie structure
  • Support both the flex_attention (BlockMask) and Triton tree attention backends

Description

This PR adds tree training capabilities to the Archon engine, allowing multiple sequences sharing common prefixes to be packed into a trie structure for efficient training.
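As a rough illustration of the packing idea (the names below are illustrative, not the repository's actual TrieNode API), shared-prefix sequences can be inserted into a trie so each common prefix is stored only once:

```python
# Minimal sketch of prefix sharing via a trie; illustrative only -- the
# TrieNode used by the engine has a different, richer API.
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    token_id: int = -1                                   # -1 marks the root
    children: dict[int, "TrieNode"] = field(default_factory=dict)


def pack_sequences(sequences: list[list[int]]) -> TrieNode:
    """Insert token sequences so that shared prefixes occupy a single path."""
    root = TrieNode()
    for seq in sequences:
        node = root
        for tok in seq:
            node = node.children.setdefault(tok, TrieNode(token_id=tok))
    return root


# Two rollouts sharing the prompt [1, 2, 3] store the prefix nodes once.
root = pack_sequences([[1, 2, 3, 4, 5], [1, 2, 3, 9]])
```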

Key Changes

Core Engine (archon_engine.py):

  • Add tree training imports and trie_node field to ArchonTrainContext
  • Update initialize() with CP validation and pad_to_maximum enforcement
  • Add tree training path in _prepare_mb_list() using build_packed_tree_batch()
  • Update _compute_logprobs_and_loss() with tree training using gather_packed_tree_logprobs_entropy() and gather_packed_tree_vocab_stats()
  • Update _compute_forward_result() and forward_batch() for tree training result handling

Runner (archon_runner.py):

  • Update SequentialRunner to build tree attention data before the model forward pass
  • Pass block_mask and triton_attn_data to the model (a FlexAttention sketch follows this list)
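For the FlexAttention path, building the BlockMask from the packed tree can look roughly like the sketch below. It assumes an `ancestor` lookup table has already been materialized from the trie; the repository's build_block_mask_from_trie helper may construct the mask differently.

```python
# Hedged sketch: a FlexAttention BlockMask for tree attention.
# Assumption: `ancestor` is a [seq_len, seq_len] bool tensor where
# ancestor[q, kv] is True iff packed position kv is position q itself
# or one of its ancestors in the trie.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def build_tree_block_mask(ancestor: torch.Tensor):
    seq_len = ancestor.size(0)

    def tree_mask_mod(b, h, q_idx, kv_idx):
        # Each packed token may attend only to itself and its trie ancestors.
        return ancestor[q_idx, kv_idx]

    return create_block_mask(
        tree_mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
        device=ancestor.device,
    )


# q, k, v: [batch, num_heads, seq_len, head_dim] tensors in packed-tree order:
# out = flex_attention(q, k, v, block_mask=build_tree_block_mask(ancestor))
```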

Model Files:

  • Add block_mask and triton_attn_data parameters to Qwen2/Qwen3 model layers (Attention, TransformerBlock, Model)
  • Update VarlenAttentionWrapper and SDPAWrapper for tree attention support

New Module (module_archon.py):

  • Tree attention helper functions for the Archon engine (a dispatch sketch follows)
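Conceptually, the new module's wrapper dispatches between the two backends based on which metadata is present. The following is an illustrative sketch, not the actual module code; the real Triton call and its arguments live in the tree_attn package.

```python
# Illustrative sketch of backend dispatch; the real TreeAttentionWrapper in
# module_archon.py is organized differently and calls the actual Triton kernel.
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention


class TreeAttentionWrapperSketch(nn.Module):
    def forward(self, q, k, v, *, block_mask=None, triton_attn_data=None, **kwargs):
        if block_mask is not None:
            # FlexAttention path: the trie structure is encoded in the BlockMask.
            return flex_attention(q, k, v, block_mask=block_mask)
        if triton_attn_data is not None:
            # Triton path: a custom kernel consumes the packed-tree metadata.
            raise NotImplementedError("call the Triton tree-attention kernel here")
        raise RuntimeError("tree attention needs block_mask or triton_attn_data")
```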

Tests:

  • Add Archon engine to tree training tests (forward and forward-backward)

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

🤖 Generated with Claude Code

@gemini-code-assist

Summary of Changes

Hello @nuzant, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Archon engine by introducing robust support for tree training. This feature optimizes the training process for large language models by efficiently handling sequences with common prefixes through a trie data structure. The changes ensure that the Archon engine can leverage advanced attention mechanisms, specifically FlexAttention and Triton-based tree attention, leading to more performant and resource-efficient model training.

Highlights

  • Archon Engine Tree Training Support: Added comprehensive tree training capabilities to the Archon engine, enabling efficient training with shared prefix sequences using a trie structure.
  • Attention Backend Integration: Integrated support for both FlexAttention (BlockMask) and Triton tree attention backends within the Archon engine and model layers.
  • Core Engine Modifications: Updated ArchonTrainContext with a trie_node field, modified initialize() for tree training validation and pad_to_maximum enforcement, and adjusted _prepare_mb_list, _compute_logprobs_and_loss, and _compute_forward_result to handle tree-specific logic.
  • Runner Updates: The SequentialRunner now builds and passes tree attention metadata (block_mask or triton_attn_data) to the model's forward pass when tree training is active.
  • Model Layer Adaptations: Qwen2 and Qwen3 model attention, transformer block, and main model layers were updated to accept and utilize block_mask and triton_attn_data for tree attention.
  • New Module for Archon Tree Attention: Introduced areal/models/tree_attn/module_archon.py to house Archon-specific tree attention helper functions and the TreeAttentionWrapper.
  • CLI and Documentation Updates: The attn_type configuration in ArchonEngineConfig now includes 'tree' as a valid choice, and the CLI reference documentation has been updated accordingly.
  • Testing Expansion: Existing tree training tests for forward and forward-backward passes have been extended to include the Archon engine.


Changelog
  • areal/api/cli_args.py
    • Added 'tree' as a new choice for the attn_type parameter in ArchonEngineConfig.
    • Updated the help text for attn_type to indicate its use for tree training.
  • areal/experimental/engine/archon_engine.py
    • Imported necessary tree attention functional and module components, including TrieNode.
    • Added a trie_node field to ArchonTrainContext and implemented a to_dict() method to prevent recursion errors during serialization (a sketch follows this changelog).
    • Modified the initialize() method to validate tree training compatibility with pipeline/context parallelism (disallowing them) and to enforce pad_to_maximum=True for block mask alignment.
    • Updated forward_batch() to correctly determine batch_size for tree training scenarios.
    • Adjusted process_output() to use merge_packed_tree_results when tree training is enabled.
    • Refactored _prepare_mb_inputs() to extract trie_node from inputs, handle labels differently for tree training (computed via trie), and pass trie_node to ArchonTrainContext.
    • Introduced a dedicated tree training path in _prepare_mb_list() that builds packed tree batches and raises an error if context parallelism is active.
    • Enhanced _compute_logprobs_and_loss() to support tree training by using gather_packed_tree_vocab_stats and gather_packed_tree_logprobs_entropy, and to handle dummy tries for gradient connection. Also added _get_vocab_min_max_logits for non-tree paths.
    • Modified _compute_forward_result() to use _gather_packed_tree_logprobs for tree training.
    • Updated _create_model_structure() to automatically set attn_type to 'tree' if tree training is enabled, overriding the configured value.
  • areal/experimental/engine/archon_runner.py
    • Imported TRITON_AVAILABLE, USE_TRITON_TREE_ATTN, build_block_mask_from_trie, and build_triton_attn_data_from_trie for tree attention.
    • Modified the run() method to dynamically build block_mask or triton_attn_data from ctx.trie_node before the model's forward pass.
    • Adjusted run() to pass the generated block_mask and triton_attn_data as keyword arguments to the model.
    • Introduced logic to create dummy cu_seqlens and max_seqlen for model compatibility when tree training is active, as tree attention uses block masks.
    • Added steps to release tree attention metadata after the forward pass to manage memory.
    • Updated the handling of ctx to use ctx.to_dict() for proper serialization, and updated results.append to correctly handle dictionary results from tree training.
  • areal/experimental/models/archon/attention.py
    • Imported TreeAttentionWrapper and FLEX_ATTENTION_AVAILABLE from the new module_archon.
    • Added TreeAttentionWrapper and FLEX_ATTENTION_AVAILABLE to the module's __all__ export list.
    • Modified VarlenAttentionWrapper.forward and SDPAWrapper.forward to accept and ignore **kwargs (including block_mask and triton_attn_data) for API compatibility with tree attention.
  • areal/experimental/models/archon/qwen2/model/model.py
    • Added from __future__ import annotations for improved type hinting.
    • Imported TreeAttentionWrapper for tree attention support.
    • Conditionally imported BlockMask from torch.nn.attention.flex_attention and TreeAttentionData from triton_kernel.
    • Updated Qwen2Attention.__init__ to instantiate TreeAttentionWrapper when model_args.attn_type is 'tree'.
    • Modified Qwen2Attention.forward, Qwen2DecoderLayer.forward, and Qwen2Model.forward signatures to accept block_mask and triton_attn_data and pass them down the call stack.
  • areal/experimental/models/archon/qwen3/model/model.py
    • Added from __future__ import annotations for improved type hinting.
    • Imported TreeAttentionWrapper for tree attention support.
    • Conditionally imported BlockMask from torch.nn.attention.flex_attention and TreeAttentionData from triton_kernel.
    • Updated Qwen3Attention.__init__ to instantiate TreeAttentionWrapper when model_args.attn_type is 'tree'.
    • Modified Qwen3Attention.forward, Qwen3DecoderLayer.forward, and Qwen3Model.forward signatures to accept block_mask and triton_attn_data and pass them down the call stack.
  • areal/experimental/models/archon/varlen_attention.py
    • Modified VarlenAttentionWrapper.forward to accept and ignore **kwargs (including block_mask and triton_attn_data) for API compatibility with tree attention.
  • areal/models/tree_attn/module.py
    • Imported FLEX_ATTENTION_AVAILABLE and TreeAttentionWrapper from areal.models.tree_attn.module_archon.
    • Added these Archon-specific exports to the module's __all__ list.
  • areal/models/tree_attn/module_archon.py
    • New file: Implements TreeAttentionWrapper, a wrapper for tree attention in the Archon Engine.
    • The wrapper provides a unified interface for using either FlexAttention with BlockMask or a Triton kernel for tree attention.
    • It handles the passing of block_mask or triton_attn_data as keyword arguments during the forward pass.
    • Includes logic to check for the availability of FlexAttention and Triton backends and raises errors if no valid backend is found.
  • areal/tests/test_tree_training.py
    • Imported ArchonEngine to include it in tree training tests.
    • Updated _collect_gradients and _collect_parameters functions to correctly handle ArchonEngine instances by iterating over engine.model_parts.
    • Modified _create_engine to include ArchonEngine as a possible return type.
    • Added 'archon' to the engine_type parameterization for test_tree_training_forward and test_tree_training_forward_backward.
    • Ensured areal.models.tree_attn.module_archon.USE_TRITON_TREE_ATTN is correctly set during tests.
  • docs/cli_reference.md
    • Updated the table for ArchonEngineConfig to include 'tree' as a valid choice for attn_type.
    • Revised the description for attn_type to explicitly mention using 'tree' for tree training.
  • examples/vlm_npu/README.md
    • Applied minor formatting adjustments, such as adding blank lines for improved readability.
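To illustrate the to_dict() note in the archon_engine.py entry above (field names here are hypothetical, not the actual ArchonTrainContext): trie nodes hold references to their children, so naive recursive serialization of the context can blow the recursion limit, whereas a shallow conversion that keeps the trie as an opaque reference does not.

```python
# Hedged sketch of a shallow to_dict(); the real ArchonTrainContext has
# different fields. dataclasses.asdict() would recurse into trie_node and
# can hit RecursionError on deep or self-referential tries.
import dataclasses
from typing import Any


@dataclasses.dataclass
class TrainContextSketch:
    input_ids: Any = None
    trie_node: Any = None   # TrieNode with child references; do not recurse

    def to_dict(self) -> dict:
        # Shallow: keep field values (including the trie) as-is.
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
```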
Activity
  • The pull request was created by nuzant to add tree training support to the Archon engine.
  • A detailed summary and description of the key changes, including modifications to the core engine, runner, and model files, were provided by the author.
  • The author completed a checklist indicating that new functionality was added and tested, and the branch is up to date with main.
  • A test plan was included, outlining specific pytest commands to verify forward and forward-backward tree training, and to check for context parallelism incompatibility.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant new functionality by adding tree training support to the Archon engine, which is a valuable addition for training efficiency. The changes are comprehensive, affecting the core engine, runner, and model implementations, and include support for both FlexAttention and Triton backends. The inclusion of tests for this new feature is also a great practice.

My review highlights a few areas for improvement:

  1. Configuration Robustness: There's an opportunity to make the configuration handling for attn_type more robust to prevent potential runtime errors from mismatches.
  2. API Change: The loss_fn signature has been implicitly changed, which could be a breaking change for existing users of the engine.
  3. Test Correctness: The new tests rely on very high tolerances for numerical comparisons, which may mask correctness issues in both the forward and backward passes.

Addressing these points will help ensure the new tree training feature is both robust and numerically correct. Overall, this is a well-executed and important feature addition.

@nuzant nuzant marked this pull request as ready for review February 9, 2026 07:05
nuzant and others added 3 commits February 9, 2026 20:29
Add tree training capabilities to Archon engine following FSDP patterns:

- Add tree training imports and TrieNode field to ArchonTrainContext
- Update initialize() with CP validation and pad_to_maximum enforcement
- Add tree training path in _prepare_mb_list() using build_packed_tree_batch()
- Extract trie_node in _prepare_mb_inputs() and store in context
- Update _compute_logprobs_and_loss() with tree training using
  gather_packed_tree_logprobs_entropy() and gather_packed_tree_vocab_stats()
- Update _compute_forward_result() with _gather_packed_tree_logprobs()
- Update forward_batch() with merge_packed_tree_results()
- Update SequentialRunner to build tree attention data before model forward
- Add block_mask and triton_attn_data parameters to Qwen2 model layers
- Update VarlenAttentionWrapper and SDPAWrapper for tree attention
- Add Archon engine to tree training tests

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Make sure `attn_type=="tree"` and `enable_tree_training` are enabled or disabled at the same time.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@nuzant nuzant added the safe-to-test (Ready to run unit-tests in a PR) label Feb 9, 2026
@nuzant nuzant temporarily deployed to AReaL-unittests February 9, 2026 12:49 — with GitHub Actions Inactive

@rchardx rchardx left a comment


LGTM

@rchardx rchardx merged commit e03f32f into main Feb 10, 2026
8 checks passed
@rchardx rchardx deleted the mzy/archon-tree branch February 10, 2026 05:47
rchardx added a commit that referenced this pull request Feb 10, 2026
…backend selection

Simplify tree training code added in #912 by extracting helpers
and centralizing backend selection in the tree_attn package.

Key changes:
- Extract _gather_actor_train_outputs, _gather_actor_forward_output,
  _gather_critic_output from archon_engine output methods
- Remove thin wrappers _compute_logprobs_entropy, _compute_logprobs
- Add TreeAttnMeta.from_trie() factory for Archon backend selection
- Add build_tree_attn_kwargs() factory for FSDP/Megatron backend selection
- Unify tree attention parameter as tree_attn_meta across Archon model stack
- Prefix FSDP/Megatron kwargs with tree_ to avoid HF collisions
- Move BLOCK_SIZE alignment validation into build_packed_tree_batch
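For the BLOCK_SIZE alignment mentioned above, a small worked example (BLOCK_SIZE = 128 matches FlexAttention's default block size; the helper name and the parallel_size handling are assumptions, not the repository's actual code):

```python
# Hedged sketch: pad the packed length to a multiple that satisfies both
# the attention block size and the parallel degree, via math.lcm.
import math

BLOCK_SIZE = 128  # FlexAttention's default block size


def aligned_length(total_tokens: int, parallel_size: int = 1) -> int:
    multiple = math.lcm(BLOCK_SIZE, parallel_size)
    return ((total_tokens + multiple - 1) // multiple) * multiple


assert aligned_length(300) == 384                   # padded up to 3 * 128
assert aligned_length(300, parallel_size=3) == 384  # lcm(128, 3) = 384
```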
rchardx added a commit that referenced this pull request Feb 10, 2026
… attention package

Consolidate tree attention metadata into TreeAttentionMeta dataclass,
pass it end-to-end through model layers, and restructure the Archon
attention module into a package with clear separation of concerns.

Key changes:
- Add TreeAttentionMeta with from_trie() factory to auto-select backend
- Pass tree_attn_meta through Model → TransformerBlock → Attention
  in Archon models (qwen2, qwen3)
- Convert attention.py into attention/ package (sdpa.py, varlen.py)
- Extract _gather_actor_train_outputs, _gather_actor_forward_output,
  _gather_critic_output helpers in ArchonEngine
- Add build_tree_attn_kwargs for FSDP/Megatron dict-based forwarding
  with tree_ prefix to avoid HF kwarg collisions
- Fix unused stride_ob in triton _tree_attn_bwd_preprocess kernel
  (caused InductorError with torch.compile backward)
- Fix missing vocab_min/vocab_max CP gather in Archon engine
- Move BLOCK_SIZE alignment into build_packed_tree_batch via
  parallel_size parameter with math.lcm

Refs: #912
garrett4wade pushed a commit that referenced this pull request Feb 10, 2026
… attention package (#920)
