Conversation

@rchardx (Collaborator) commented Feb 10, 2026

Description

Cleanup of #912. Consolidate tree attention metadata into TreeAttentionMeta dataclass, pass it end-to-end through model layers, restructure the Archon attention module into a package, and fix two bugs (triton kernel InductorError, missing CP gather for vocab stats).

Related Issue

Follow-up to #912

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

N/A

Additional Context

Key changes:

  • Add TreeAttentionMeta with from_trie() factory for Archon backend selection (a rough sketch follows this list)
  • Pass tree_attn_meta through Model → TransformerBlock → Attention (replaces separate block_mask/triton_attn_data params)
  • Convert attention.py into attention/ package (sdpa.py, varlen.py)
  • Extract _gather_actor_train_outputs, _gather_actor_forward_output, _gather_critic_output helpers in ArchonEngine
  • Add build_tree_attn_kwargs() for FSDP/Megatron dict-based forwarding with tree_ prefix to avoid HF kwarg collisions
  • Fix unused stride_ob in triton _tree_attn_bwd_preprocess kernel (caused InductorError with torch.compile backward)
  • Fix missing vocab_min/vocab_max CP gather in Archon engine
  • Move BLOCK_SIZE alignment into build_packed_tree_batch via parallel_size parameter with math.lcm
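
The following is a minimal sketch of what the TreeAttentionMeta dataclass and its from_trie() factory could look like. It is not the actual module_archon.py implementation: the field names, the use_triton flag, and the two builder helpers are placeholders inferred from the bullets above.

```python
# Hypothetical sketch only -- the builder helpers and field names below are
# placeholders, not the real module_archon.py code.
from dataclasses import dataclass
from typing import Any, Optional


def build_tree_block_mask(trie) -> Any:
    """Placeholder for the real flex-attention BlockMask builder."""
    raise NotImplementedError


def build_triton_tree_attn_data(trie) -> Any:
    """Placeholder for the real Triton tree-attention metadata builder."""
    raise NotImplementedError


@dataclass
class TreeAttentionMeta:
    """Carries exactly one backend-specific attention payload."""

    block_mask: Optional[Any] = None    # flex-attention BlockMask
    triton_data: Optional[Any] = None   # Triton tree-attention kernel metadata

    @classmethod
    def from_trie(cls, trie, *, use_triton: bool) -> "TreeAttentionMeta":
        # Backend selection happens once, here, instead of being branched on
        # in every engine and model layer.
        if use_triton:
            return cls(triton_data=build_triton_tree_attn_data(trie))
        return cls(block_mask=build_tree_block_mask(trie))
```

With a single object like this, each attention layer receives one tree_attn_meta argument and dispatches on whichever field is populated, instead of taking separate block_mask and triton_attn_data parameters.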

@gemini-code-assist (Contributor) commented
Summary of Changes

Hello @rchardx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the tree training codebase by streamlining the handling of tree attention metadata and centralizing backend selection logic. The changes aim to improve code clarity, maintainability, and consistency across different training engines (Archon, FSDP, Megatron) by unifying the parameter interface for tree attention and extracting common utility functions. This cleanup makes the tree training implementation more modular and easier to extend.

Highlights

  • Refactored Tree Training Logic: Extracted focused helper methods (_gather_actor_train_outputs, _gather_actor_forward_output, _gather_critic_output) from the Archon engine's output processing, simplifying _compute_logprobs_and_loss and _compute_forward_result.
  • Unified Tree Attention Parameter Interface: Introduced a new TreeAttnMeta dataclass to encapsulate tree attention metadata (either BlockMask or TreeAttentionData). This replaces separate block_mask and triton_attn_data parameters across the Archon model stack, leading to a cleaner API.
  • Centralized Backend Selection: Added factory functions TreeAttnMeta.from_trie() for Archon and build_tree_attn_kwargs() for FSDP/Megatron to centralize the logic for selecting between Triton kernel and Flex Attention backends based on availability and configuration.
  • Improved Parameter Naming for FSDP/Megatron: Prefixed tree attention-related keyword arguments with tree_ (e.g., tree_block_mask, tree_triton_data) when passed to HuggingFace models in FSDP/Megatron engines to prevent potential collisions with HuggingFace's own internal kwargs (a rough sketch follows this list).
  • Enhanced build_packed_tree_batch: Moved the BLOCK_SIZE alignment validation into the build_packed_tree_batch function and made it more robust by considering additional parallel_size dimensions using math.lcm for stricter alignment requirements.


Changelog
  • areal/engine/fsdp_engine.py
    • Removed direct imports of tree attention constants and builders, replacing them with build_tree_attn_kwargs.
    • Updated FSDPTrainContext.to_dict docstring for clarity on shallow conversion.
    • Refactored forward_backward_batch to use build_tree_attn_kwargs for preparing tree attention metadata and dynamically deleting keys.
    • Removed BLOCK_SIZE assertion and passed parallel_size to build_packed_tree_batch.
  • areal/engine/megatron_engine.py
    • Removed direct imports of tree attention constants and builders, replacing them with build_tree_attn_kwargs.
    • Refactored forward_step to use build_tree_attn_kwargs with dense_mask=True for Megatron's gradient checkpointing compatibility.
    • Removed BLOCK_SIZE assertion and passed parallel_size to build_packed_tree_batch.
  • areal/experimental/engine/archon_engine.py
    • Imported dataclasses for ArchonTrainContext.
    • Removed BLOCK_SIZE import.
    • Modified ArchonTrainContext to use dataclasses.fields for to_dict and made labels optional.
    • Added a NotImplementedError check for tree training with critic models during initialization.
    • Refactored _prepare_mb_inputs to correctly set labels and trie_node based on tree training enablement.
    • Removed cp_enabled check and BLOCK_SIZE assertion from _prepare_mb_list, adding parallel_size to build_packed_tree_batch.
    • Replaced direct logic in _compute_logprobs_and_loss and _compute_forward_result with calls to new helper methods: _gather_actor_train_outputs, _gather_actor_forward_output, and _gather_critic_output.
  • areal/experimental/engine/archon_runner.py
    • Removed direct imports of tree attention constants and builders, replacing them with TreeAttnMeta.
    • Refactored run method to use TreeAttnMeta.from_trie for creating tree attention metadata.
    • Updated calls to ctx.to_dict() to use the new method.
  • areal/experimental/models/archon/attention.py
    • Imported TreeAttnMeta and added it to __all__.
    • Modified SDPAWrapper.forward signature to accept tree_attn_meta instead of generic **kwargs.
  • areal/experimental/models/archon/base.py
    • Imported TreeAttnMeta for type hinting.
    • Updated ArchonModelBase.forward abstract method signature to include tree_attn_meta.
  • areal/experimental/models/archon/qwen2/model/model.py
    • Removed imports of BlockMask and TreeAttentionData.
    • Imported TreeAttnMeta.
    • Updated Qwen2Attention.forward and Qwen2DecoderLayer.forward signatures to use tree_attn_meta.
  • areal/experimental/models/archon/qwen3/model/model.py
    • Removed imports of BlockMask and TreeAttentionData.
    • Imported TreeAttnMeta.
    • Updated Qwen3Attention.forward and Qwen3DecoderLayer.forward signatures to use tree_attn_meta.
  • areal/experimental/models/archon/ulysses.py
    • Updated type hint for cp_group in ulysses_gather_output to allow None.
  • areal/experimental/models/archon/varlen_attention.py
    • Imported TreeAttnMeta.
    • Modified VarlenAttentionWrapper.forward signature to accept tree_attn_meta for interface compatibility.
  • areal/models/tree_attn/module.py
    • Removed BLOCK_SIZE from constants import.
    • Added build_tree_attn_kwargs and TreeAttnMeta to exports in __all__.
  • areal/models/tree_attn/module_archon.py
    • Added TreeAttnMeta dataclass to encapsulate tree attention metadata.
    • Updated TreeAttentionWrapper to use TreeAttnMeta as its primary input for attention metadata.
    • Centralized backend selection logic within TreeAttnMeta.from_trie.
  • areal/models/tree_attn/module_fsdp.py
    • Modified _tree_attn_fwd_func to expect tree attention kwargs prefixed with tree_ (e.g., tree_triton_data, tree_block_mask) to avoid collisions.
  • areal/models/tree_attn/tree.py
    • Imported math, USE_TRITON_TREE_ATTN, and TRITON_AVAILABLE.
    • Added parallel_size parameter to build_packed_tree_batch and updated alignment validation using math.lcm (see the alignment sketch after this changelog).
    • Introduced build_tree_attn_kwargs function to create backend-specific tree attention parameters.
  • areal/utils/logging.py
    • Added TreeAttentionWrapper to the LOG_COLORS dictionary for colored logging output.
  • areal/utils/mcore/packed_context_parallel.py
    • Updated comments and variable names to reflect the tree_ prefixing for Triton tree attention data.
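
A sketch of how the stricter alignment inside build_packed_tree_batch could work. The BLOCK_SIZE value, the helper name, and the exact signature are assumptions; only the math.lcm(BLOCK_SIZE, parallel_size) idea is taken from this PR.

```python
import math

BLOCK_SIZE = 128  # assumed value; the real constant lives in the tree_attn constants module


def _aligned_length(total_tokens: int, parallel_size: int = 1) -> int:
    # The packed sequence must be divisible by both the attention kernel's
    # BLOCK_SIZE and the parallel sharding factor (e.g. context-parallel size),
    # so pad the length up to the least common multiple of the two.
    align = math.lcm(BLOCK_SIZE, parallel_size)
    return ((total_tokens + align - 1) // align) * align
```

Centralizing this check in build_packed_tree_batch removes the per-engine BLOCK_SIZE assertions that the FSDP, Megatron, and Archon engines previously carried.
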
Activity
  • The pull request includes code refactoring with no functional changes.
  • Unit tests require GPU and were not run on the machine where this PR was authored.
  • No critical issues were raised by AI reviewers.

@gemini-code-assist bot left a comment
Code Review

This pull request provides a significant and well-executed refactoring of the tree attention training logic. The changes successfully simplify the codebase by encapsulating backend selection, unifying the parameter interface across different engines with tree_attn_meta, and improving the overall structure by extracting focused helper methods. Moving the BLOCK_SIZE alignment validation into build_packed_tree_batch is also a good change that centralizes responsibility. The code is now cleaner and more maintainable. I have a couple of medium-severity suggestions regarding the change from ValueError to assert for critical checks, which could be less robust in production environments where assertions might be disabled.

@rchardx changed the title from "refactor(tree-attn): flatten tree training branching and encapsulate backend selection" to "refactor(tree-attn): simplify tree attention plumbing and restructure attention package" on Feb 10, 2026
@rchardx force-pushed the rchardx/cleanup branch 2 times, most recently from 3b77aaf to 41b2725 on February 10, 2026 12:10
@rchardx added the safe-to-test (Ready to run unit-tests in a PR.) label on Feb 10, 2026
refactor(tree-attn): simplify tree attention plumbing and restructure attention package

Consolidate tree attention metadata into TreeAttentionMeta dataclass,
pass it end-to-end through model layers, and restructure the Archon
attention module into a package with clear separation of concerns.

Key changes:
- Add TreeAttentionMeta with from_trie() factory to auto-select backend
- Pass tree_attn_meta through Model → TransformerBlock → Attention
  in Archon models (qwen2, qwen3)
- Convert attention.py into attention/ package (sdpa.py, varlen.py)
- Extract _gather_actor_train_outputs, _gather_actor_forward_output,
  _gather_critic_output helpers in ArchonEngine
- Add build_tree_attn_kwargs for FSDP/Megatron dict-based forwarding
  with tree_ prefix to avoid HF kwarg collisions
- Fix unused stride_ob in triton _tree_attn_bwd_preprocess kernel
  (caused InductorError with torch.compile backward)
- Fix missing vocab_min/vocab_max CP gather in Archon engine
- Move BLOCK_SIZE alignment into build_packed_tree_batch via
  parallel_size parameter with math.lcm

Refs: #912
@rchardx added and removed the safe-to-test (Ready to run unit-tests in a PR.) label on Feb 10, 2026
@rchardx deployed to AReaL-unittests on February 10, 2026 12:49 with GitHub Actions
@garrett4wade (Collaborator) left a comment
LGTM. Waiting for CI.

@garrett4wade merged commit 5c300be into main on Feb 10, 2026
8 checks passed
@garrett4wade deleted the rchardx/cleanup branch on February 10, 2026 14:47