
Commit 41b2725

refactor(tree-attn): simplify tree attention plumbing and restructure attention package
Consolidate tree attention metadata into a TreeAttentionMeta dataclass, pass it end-to-end through the model layers, and restructure the Archon attention module into a package with clear separation of concerns.

Key changes:
- Add TreeAttentionMeta with a from_trie() factory that auto-selects the backend
- Pass tree_attn_meta through Model → TransformerBlock → Attention in the Archon models (qwen2, qwen3)
- Convert attention.py into an attention/ package (sdpa.py, varlen.py)
- Extract _gather_actor_train_outputs, _gather_actor_forward_output, and _gather_critic_output helpers in ArchonEngine
- Add build_tree_attn_kwargs for FSDP/Megatron dict-based forwarding, using a tree_ prefix to avoid HF kwarg collisions
- Fix unused stride_ob in the Triton _tree_attn_bwd_preprocess kernel (caused an InductorError with the torch.compile backward)
- Fix missing vocab_min/vocab_max CP gather in the Archon engine
- Move BLOCK_SIZE alignment into build_packed_tree_batch via a parallel_size parameter and math.lcm

Refs: #912
1 parent: 7b20898

18 files changed

Lines changed: 378 additions & 396 deletions
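Only two of the 18 changed files are shown below. As orientation, here is a rough sketch of the consolidated metadata path the commit message describes: TreeAttentionMeta.from_trie() picking a backend, and build_tree_attn_kwargs() wrapping the result in "tree_"-prefixed kwargs. The dataclass fields, the dispatch logic, and the kwarg key are illustrative assumptions; only the function names, the dense_mask flag, and the tree_ prefix come from this commit.

    # Sketch only -- not the actual implementation in areal/models/tree_attn/.
    # The helpers referenced here are the ones this commit stops importing in the
    # engines (assumed to still live in tree_attn.module after the refactor).
    import dataclasses
    from typing import Any

    from areal.models.tree_attn.module import (  # assumed module layout
        TRITON_AVAILABLE,
        USE_TRITON_TREE_ATTN,
        build_attention_mask_from_trie,
        build_block_mask_from_trie,
        build_triton_attn_data_from_trie,
    )

    @dataclasses.dataclass
    class TreeAttentionMeta:
        backend: str                  # "triton" or "flex" -- assumed field
        triton_attn_data: Any = None  # set when the Triton kernel is selected
        attn_mask: Any = None         # dense mask / BlockMask for FlexAttention

        @classmethod
        def from_trie(cls, trie_node, padded_size, device, dense_mask=False):
            # Auto-select the backend: prefer the Triton kernel when available,
            # otherwise fall back to mask-based (Flex)attention.
            if USE_TRITON_TREE_ATTN and TRITON_AVAILABLE:
                data = build_triton_attn_data_from_trie(trie_node, padded_size)
                return cls(backend="triton", triton_attn_data=data)
            build = build_attention_mask_from_trie if dense_mask else build_block_mask_from_trie
            return cls(backend="flex", attn_mask=build(trie_node, padded_size, device))

    def build_tree_attn_kwargs(trie_node, padded_size, device, dense_mask=False) -> dict[str, Any]:
        # The "tree_" prefix keeps these keys from colliding with HuggingFace's
        # own kwargs when the dict is splatted into model(**inputs).
        meta = TreeAttentionMeta.from_trie(trie_node, padded_size, device, dense_mask)
        return {"tree_attn_meta": meta}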


areal/engine/fsdp_engine.py

Lines changed: 20 additions & 41 deletions
@@ -73,11 +73,7 @@
     merge_packed_tree_results,
 )
 from areal.models.tree_attn.module import (
-    BLOCK_SIZE,
-    TRITON_AVAILABLE,
-    USE_TRITON_TREE_ATTN,
-    build_block_mask_from_trie,
-    build_triton_attn_data_from_trie,
+    build_tree_attn_kwargs,
     patch_fsdp_for_tree_training,
 )
 from areal.models.tree_attn.tree import TrieNode, build_packed_tree_batch
@@ -150,13 +146,8 @@ class FSDPTrainContext:
     trie_node: TrieNode | None = None

     def to_dict(self) -> dict[str, Any]:
-        """Convert to dict without recursive serialization of trie_node.
-
-        Note: We cannot use dataclasses.asdict() here because it recursively
-        converts all nested objects. The trie_node field contains a TrieNode
-        with recursive parent/child references, which causes
-        "RecursionError: maximum recursion depth exceeded" when asdict()
-        attempts to serialize the entire tree structure.
+        """Shallow dict conversion (avoids ``dataclasses.asdict`` which would
+        recurse into TrieNode and hit ``RecursionError``).
         """
         return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
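As a standalone illustration of the failure mode described in the removed docstring (toy classes only, not the project's TrieNode/FSDPTrainContext):

    import dataclasses

    @dataclasses.dataclass
    class Node:                      # stand-in for TrieNode
        value: int
        parent: "Node | None" = None
        children: list["Node"] = dataclasses.field(default_factory=list)

    @dataclasses.dataclass
    class Ctx:                       # stand-in for FSDPTrainContext
        trie_node: Node | None = None

    root = Node(0)
    root.children.append(Node(1, parent=root))
    ctx = Ctx(trie_node=root)

    # dataclasses.asdict(ctx) would recurse through parent/children forever and
    # raise "RecursionError: maximum recursion depth exceeded".
    shallow = {f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}
    assert shallow["trie_node"] is root  # the node is kept as an object, not expanded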

@@ -548,37 +539,28 @@ def forward_backward_batch(
         for mb_item in mb_list:
             inputs, ctx = self._prepare_mb_inputs(mb_item)

-            # Lazily create tree attention metadata just before forward
+            # Lazily create tree attention metadata just before forward.
+            # The returned dict keys are prefixed with "tree_" to avoid collisions
+            # with HuggingFace's own kwargs. The patched _tree_attn_fwd_func in
+            # module_fsdp.py reads these keys from the **kwargs that transformers
+            # forwards through.
+            tree_attn_keys: list[str] = []
             if self.enable_tree_training and ctx.trie_node is not None:
                 padded_size = mb_item.padded_to_length
-                if padded_size is None:
-                    raise ValueError(
-                        "padded_size must be set for tree training with FSDP."
-                    )
-                if USE_TRITON_TREE_ATTN and TRITON_AVAILABLE:
-                    triton_attn_data = build_triton_attn_data_from_trie(
-                        ctx.trie_node, padded_size
-                    )
-                    inputs["triton_attn_data"] = triton_attn_data
-                else:
-                    block_mask = build_block_mask_from_trie(
-                        ctx.trie_node, padded_size, self.device
-                    )
-                    # Pass block_mask as a separate kwarg, not as attention_mask.
-                    # The patched _tree_attn_fwd_func expects block_mask in kwargs,
-                    # which transformers will pass through to the attention function.
-                    inputs["block_mask"] = block_mask
+                assert padded_size is not None
+                tree_kwargs = build_tree_attn_kwargs(
+                    ctx.trie_node, padded_size, self.device
+                )
+                inputs.update(tree_kwargs)
+                tree_attn_keys = list(tree_kwargs.keys())

             with trace_scope("fsdp_engine.forward"):
                 outputs = self.model(**inputs)
             logits = outputs.logits.squeeze(0)

             # Release tree attention metadata after forward pass
-            if self.enable_tree_training:
-                if "block_mask" in inputs:
-                    del inputs["block_mask"]
-                if "triton_attn_data" in inputs:
-                    del inputs["triton_attn_data"]
+            for key in tree_attn_keys:
+                del inputs[key]

             ctx_dict = ctx.to_dict()
             loss = process_output_fn(logits, ctx_dict)
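For orientation, a hedged sketch of the consumer side of these kwargs. This is not the actual _tree_attn_fwd_func from module_fsdp.py; the key name and the dense-mask fallback are assumptions carried over from the sketch above:

    import torch
    import torch.nn.functional as F

    def tree_attn_fwd_sketch(module, query, key, value, attention_mask=None, **kwargs):
        # transformers forwards unrecognized kwargs from model(**inputs) down to
        # the attention implementation, so the "tree_"-prefixed entries added by
        # build_tree_attn_kwargs() arrive here without clashing with HF arguments.
        tree_meta = kwargs.get("tree_attn_meta", None)  # assumed key name
        if tree_meta is None:
            # Tree training disabled: ordinary SDPA path.
            out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        else:
            # Simplified: treat the metadata as a boolean mask over the packed
            # sequence. The real code dispatches to the Triton kernel or
            # FlexAttention depending on the selected backend.
            out = F.scaled_dot_product_attention(query, key, value, attn_mask=tree_meta.attn_mask)
        return out, None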
@@ -1269,17 +1251,13 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:

         # Tree training path
         if self.enable_tree_training:
-            sp_size = self.parallel_helper.sp_size
-            tp_size = self.parallel_helper.tp_size
-            # Build tree inputs
-            assert BLOCK_SIZE % (tp_size * sp_size) == 0, (
-                f"BLOCK_SIZE ({BLOCK_SIZE}) must be divisible by the product of tensor and sequence parallel sizes ({tp_size * sp_size})."
-            )
             mb_list = build_packed_tree_batch(
                 input_,
                 mb_spec=self.config.mb_spec,
                 pad_to_maximum=self.config.pad_to_maximum,
                 dp_group=self.data_parallel_group,
+                parallel_size=self.parallel_helper.tp_size
+                * self.parallel_helper.sp_size,
             )
             self.logger.info(
                 f"Packed tree #microbatch: {len(mb_list)}, microbatch #tokens: {mb_list.group_lens}, "
@@ -1397,6 +1375,7 @@ def _prepare_mb_inputs(
         This method handles Ulysses SP padding and slicing, returning both
         the prepared model inputs and a context object for later processing.
         """
+        trie_node = None
         if self.parallel_helper.sp_size > 1:
             input_ids = mb_item.padded_mb["input_ids"]
             position_ids = mb_item.padded_mb.get("position_ids", None)

areal/engine/megatron_engine.py

Lines changed: 20 additions & 35 deletions
@@ -60,11 +60,7 @@
     merge_packed_tree_results,
 )
 from areal.models.tree_attn.module import (
-    BLOCK_SIZE,
-    TRITON_AVAILABLE,
-    USE_TRITON_TREE_ATTN,
-    build_attention_mask_from_trie,
-    build_triton_attn_data_from_trie,
+    build_tree_attn_kwargs,
     patch_bridge_for_tree_training,
 )
 from areal.models.tree_attn.tree import build_packed_tree_batch
@@ -573,42 +569,33 @@ def forward_step(batch_iter, model):

        cu_seqlens = mb_input.padded_mb.get("cu_seqlens", None)

-        # Lazily create tree attention metadata just before forward
+        # Lazily create tree attention metadata just before forward.
+        # dense_mask=True because Megatron's gradient checkpointing uses
+        # save_for_backward() which can only save torch.Tensor objects;
+        # BlockMask is recreated inside PytorchFlexAttention.forward().
+        tree_attn_keys: list[str] = []
        if self.enable_tree_training:
            trie_node = mb_input.padded_mb.get("trie_node", None)
            # Ensure trie_node is also in orig_mb for _compute_logprobs_and_loss
            if trie_node is not None and "trie_node" not in mb_input.orig_mb:
                mb_input.orig_mb["trie_node"] = trie_node
            padded_size = mb_input.padded_to_length
-            if trie_node is not None and padded_size is not None:
-                if USE_TRITON_TREE_ATTN and TRITON_AVAILABLE:
-                    triton_attn_data = build_triton_attn_data_from_trie(
-                        trie_node, padded_size
-                    )
-                    mb_input.padded_mb["triton_attn_data"] = triton_attn_data
-                else:
-                    # FIX: Use dense attention mask tensor instead of BlockMask.
-                    # Megatron's gradient checkpointing (tensor_parallel.checkpoint)
-                    # uses save_for_backward() which can only save torch.Tensor objects.
-                    # BlockMask is a custom data structure that cannot be serialized.
-                    # By passing a dense tensor, the checkpoint mechanism can save it,
-                    # and the BlockMask will be created inside PytorchFlexAttention.forward()
-                    # during both forward and recompute (backward) passes.
-                    attention_mask = build_attention_mask_from_trie(
-                        trie_node,
-                        padded_size,
-                        mb_input.padded_mb["input_ids"].device,
-                    )
-                    mb_input.padded_mb["attention_mask"] = attention_mask
+            if trie_node is not None:
+                assert padded_size is not None
+                tree_kwargs = build_tree_attn_kwargs(
+                    trie_node,
+                    padded_size,
+                    mb_input.padded_mb["input_ids"].device,
+                    dense_mask=True,
+                )
+                mb_input.padded_mb.update(tree_kwargs)
+                tree_attn_keys = list(tree_kwargs.keys())

        output = packed_context_parallel_forward(model, mb_input.padded_mb)

        # Release tree attention metadata after forward pass
-        if self.enable_tree_training:
-            if "attention_mask" in mb_input.padded_mb:
-                del mb_input.padded_mb["attention_mask"]
-            if "triton_attn_data" in mb_input.padded_mb:
-                del mb_input.padded_mb["triton_attn_data"]
+        for key in tree_attn_keys:
+            del mb_input.padded_mb[key]

        def _process_output(input_, output_):
            loss = process_output_fn(output_, input_)
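The dense_mask comment above boils down to: only plain tensors survive Megatron's activation checkpointing, so the BlockMask has to be rebuilt from the saved dense mask at attention time, both in the forward pass and again during recompute. A hedged sketch of that step; apart from the torch FlexAttention APIs, the names here are assumptions, not the actual PytorchFlexAttention code:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def tree_flex_attention(query, key, value, dense_mask: torch.Tensor):
        # dense_mask: [seq_len, seq_len] bool tensor built from the trie. Being a
        # plain tensor, it can be stashed by save_for_backward(); the BlockMask is
        # rebuilt here on every call, so it also exists during the checkpointed
        # recompute of the backward pass.
        seq_len = dense_mask.shape[-1]

        def mask_mod(b, h, q_idx, kv_idx):
            return dense_mask[q_idx, kv_idx]

        block_mask = create_block_mask(
            mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
            device=dense_mask.device,
        )
        return flex_attention(query, key, value, block_mask=block_mask)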
@@ -1383,14 +1370,12 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
            assert cp_size == 1, (
                "Context parallelism is not supported in tree training."
            )
-            # Build tree inputs
-            assert BLOCK_SIZE % tp_size == 0, (
-                f"BLOCK_SIZE ({BLOCK_SIZE}) must be divisible by tensor parallel size ({tp_size})."
-            )
            mb_list = build_packed_tree_batch(
                input_,
                mb_spec=self.config.mb_spec,
                pad_to_maximum=self.config.pad_to_maximum,
+                dp_group=self.data_parallel_group,
+                parallel_size=tp_size,
            )
            recommended_min_n_mbs = 2 * pp_size if pp_size > 1 else 1
            self.logger.info(
