5 changes: 3 additions & 2 deletions .claude/agents/archon-engine-expert.md
@@ -239,8 +239,9 @@ patterns (refer to source code).
constraints
- `areal/experimental/models/archon/activation_checkpoint.py` - Activation checkpointing
- `areal/experimental/models/archon/compile.py` - torch.compile integration
- `areal/experimental/models/archon/varlen_attention.py` - Variable-length attention
- `areal/experimental/models/archon/attention.py` - Attention mechanism implementations
- `areal/experimental/models/archon/attention/` - Attention package
- `attention/sdpa.py` - Scaled dot-product attention
- `attention/varlen.py` - Variable-length attention (custom op registration)

## Resources

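The `attention/varlen.py` entry above mentions custom op registration. As a point of reference for reviewers, here is a minimal sketch of that pattern using `torch.library`; the op name `archon::varlen_attn`, the argument list, and the flash-attn backing call are illustrative assumptions, not the actual contents of the file.

```python
# Hedged sketch of variable-length attention registered as a PyTorch custom op.
# Everything below (namespace, signature, backing kernel) is an assumption.
import torch
from flash_attn import flash_attn_varlen_func


@torch.library.custom_op("archon::varlen_attn", mutates_args=())
def varlen_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
) -> torch.Tensor:
    # Delegate to the flash-attn varlen kernel; causal masking assumed here.
    return flash_attn_varlen_func(
        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True
    )


@varlen_attn.register_fake
def _(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
    # Shape/dtype-only fake implementation so torch.compile can trace the op.
    return torch.empty_like(q)
```

Registering through `torch.library` with a fake implementation is what lets the op be traced by `torch.compile` without graph breaks, which fits the `compile.py` entry in the same list.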
2 changes: 1 addition & 1 deletion .claude/data/pr-review-change-types.md
@@ -45,7 +45,7 @@
| **REWARD** | `areal/reward/` | `reward_fn`, `AsyncRewardWrapper`, `MathVerifyWorker` |
| **DATASET** | `areal/dataset/` | `get_*_dataset`, `DataLoader`, `IterableDataset` |
| **LAUNCHER_SCHEDULER** | `areal/infra/launcher/`, `areal/infra/scheduler/`, `areal/infra/rpc/` | `LaunchConfig`, `Scheduler`, `RayLauncher`, `SlurmLauncher` |
| **ATTENTION** | `attention.py`, `varlen_attention.py` | `flash_attn`, `sdpa`, `varlen`, `causal_mask` |
| **ATTENTION** | `attention/`, `attention/sdpa.py`, `attention/varlen.py` | `flash_attn`, `sdpa`, `varlen`, `causal_mask` |

## LOW Level (Use Haiku)

61 changes: 20 additions & 41 deletions areal/engine/fsdp_engine.py
@@ -73,11 +73,7 @@
merge_packed_tree_results,
)
from areal.models.tree_attn.module import (
BLOCK_SIZE,
TRITON_AVAILABLE,
USE_TRITON_TREE_ATTN,
build_block_mask_from_trie,
build_triton_attn_data_from_trie,
build_tree_attn_kwargs,
patch_fsdp_for_tree_training,
)
from areal.models.tree_attn.tree import TrieNode, build_packed_tree_batch
@@ -150,13 +146,8 @@ class FSDPTrainContext:
trie_node: TrieNode | None = None

def to_dict(self) -> dict[str, Any]:
"""Convert to dict without recursive serialization of trie_node.
Note: We cannot use dataclasses.asdict() here because it recursively
converts all nested objects. The trie_node field contains a TrieNode
with recursive parent/child references, which causes
"RecursionError: maximum recursion depth exceeded" when asdict()
attempts to serialize the entire tree structure.
"""Shallow dict conversion (avoids ``dataclasses.asdict`` which would
recurse into TrieNode and hit ``RecursionError``).
"""
return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
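The new docstring is intentionally terse, so here is a self-contained illustration of the failure it refers to. `Node` and `Ctx` below are hypothetical stand-ins for `TrieNode` and `FSDPTrainContext`; the point is only that `dataclasses.asdict` recurses through parent/child back-references, while the shallow field copy keeps the node object intact.

```python
import dataclasses
from dataclasses import dataclass, field


@dataclass
class Node:  # hypothetical stand-in for TrieNode
    parent: "Node | None" = None
    children: list["Node"] = field(default_factory=list)


@dataclass
class Ctx:  # hypothetical stand-in for FSDPTrainContext
    node: Node | None = None

    def to_dict(self) -> dict:
        # Shallow copy of the fields: the Node is stored as-is, not expanded.
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}


root = Node()
root.children.append(Node(parent=root))  # parent/child back-reference

ctx = Ctx(node=root)
ctx.to_dict()              # fine; the Node is kept without being expanded
# dataclasses.asdict(ctx)  # RecursionError: asdict keeps following the cycle
```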

@@ -548,37 +539,28 @@ def forward_backward_batch(
for mb_item in mb_list:
inputs, ctx = self._prepare_mb_inputs(mb_item)

# Lazily create tree attention metadata just before forward
# Lazily create tree attention metadata just before forward.
# The returned dict keys are prefixed with "tree_" to avoid collisions
# with HuggingFace's own kwargs. The patched _tree_attn_fwd_func in
# module_fsdp.py reads these keys from the **kwargs that transformers
# forwards through.
tree_attn_keys: list[str] = []
if self.enable_tree_training and ctx.trie_node is not None:
padded_size = mb_item.padded_to_length
if padded_size is None:
raise ValueError(
"padded_size must be set for tree training with FSDP."
)
if USE_TRITON_TREE_ATTN and TRITON_AVAILABLE:
triton_attn_data = build_triton_attn_data_from_trie(
ctx.trie_node, padded_size
)
inputs["triton_attn_data"] = triton_attn_data
else:
block_mask = build_block_mask_from_trie(
ctx.trie_node, padded_size, self.device
)
# Pass block_mask as a separate kwarg, not as attention_mask.
# The patched _tree_attn_fwd_func expects block_mask in kwargs,
# which transformers will pass through to the attention function.
inputs["block_mask"] = block_mask
assert padded_size is not None
tree_kwargs = build_tree_attn_kwargs(
ctx.trie_node, padded_size, self.device
)
inputs.update(tree_kwargs)
tree_attn_keys = list(tree_kwargs.keys())

with trace_scope("fsdp_engine.forward"):
outputs = self.model(**inputs)
logits = outputs.logits.squeeze(0)

# Release tree attention metadata after forward pass
if self.enable_tree_training:
if "block_mask" in inputs:
del inputs["block_mask"]
if "triton_attn_data" in inputs:
del inputs["triton_attn_data"]
for key in tree_attn_keys:
del inputs[key]

ctx_dict = ctx.to_dict()
loss = process_output_fn(logits, ctx_dict)
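Since `build_tree_attn_kwargs` itself is not part of this diff, the toy below (hypothetical names, not the repo's code) only illustrates the lifecycle the comments describe: prefixed kwargs are merged into the model inputs, the patched attention path reads them back out of `**kwargs`, and the keys are removed from the inputs dict once the forward pass finishes.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in for the HF model whose attention forward was patched."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, input_ids, **kwargs):
        # The patched attention function reads the prefixed key from **kwargs.
        mask = kwargs.get("tree_block_mask")
        h = self.proj(input_ids)
        return h if mask is None else h.masked_fill(~mask, 0.0)


model = ToyModel()
inputs = {"input_ids": torch.randn(1, 4, 8)}

# Build once per microbatch and merge into the forward kwargs...
tree_kwargs = {"tree_block_mask": torch.ones(1, 4, 8, dtype=torch.bool)}
inputs.update(tree_kwargs)
tree_attn_keys = list(tree_kwargs.keys())

out = model(**inputs)

# ...then drop the metadata again so it does not leak into later passes.
for key in tree_attn_keys:
    del inputs[key]
```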
@@ -1269,17 +1251,13 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:

# Tree training path
if self.enable_tree_training:
sp_size = self.parallel_helper.sp_size
tp_size = self.parallel_helper.tp_size
# Build tree inputs
assert BLOCK_SIZE % (tp_size * sp_size) == 0, (
f"BLOCK_SIZE ({BLOCK_SIZE}) must be divisible by the product of tensor and sequence parallel sizes ({tp_size * sp_size})."
)
mb_list = build_packed_tree_batch(
input_,
mb_spec=self.config.mb_spec,
pad_to_maximum=self.config.pad_to_maximum,
dp_group=self.data_parallel_group,
parallel_size=self.parallel_helper.tp_size
* self.parallel_helper.sp_size,
)
self.logger.info(
f"Packed tree #microbatch: {len(mb_list)}, microbatch #tokens: {mb_list.group_lens}, "
@@ -1397,6 +1375,7 @@ def _prepare_mb_inputs(
This method handles Ulysses SP padding and slicing, returning both
the prepared model inputs and a context object for later processing.
"""
trie_node = None
if self.parallel_helper.sp_size > 1:
input_ids = mb_item.padded_mb["input_ids"]
position_ids = mb_item.padded_mb.get("position_ids", None)
55 changes: 20 additions & 35 deletions areal/engine/megatron_engine.py
@@ -60,11 +60,7 @@
merge_packed_tree_results,
)
from areal.models.tree_attn.module import (
BLOCK_SIZE,
TRITON_AVAILABLE,
USE_TRITON_TREE_ATTN,
build_attention_mask_from_trie,
build_triton_attn_data_from_trie,
build_tree_attn_kwargs,
patch_bridge_for_tree_training,
)
from areal.models.tree_attn.tree import build_packed_tree_batch
@@ -573,42 +569,33 @@ def forward_step(batch_iter, model):

cu_seqlens = mb_input.padded_mb.get("cu_seqlens", None)

# Lazily create tree attention metadata just before forward
# Lazily create tree attention metadata just before forward.
# dense_mask=True because Megatron's gradient checkpointing uses
# save_for_backward() which can only save torch.Tensor objects;
# BlockMask is recreated inside PytorchFlexAttention.forward().
tree_attn_keys: list[str] = []
if self.enable_tree_training:
trie_node = mb_input.padded_mb.get("trie_node", None)
# Ensure trie_node is also in orig_mb for _compute_logprobs_and_loss
if trie_node is not None and "trie_node" not in mb_input.orig_mb:
mb_input.orig_mb["trie_node"] = trie_node
padded_size = mb_input.padded_to_length
if trie_node is not None and padded_size is not None:
if USE_TRITON_TREE_ATTN and TRITON_AVAILABLE:
triton_attn_data = build_triton_attn_data_from_trie(
trie_node, padded_size
)
mb_input.padded_mb["triton_attn_data"] = triton_attn_data
else:
# FIX: Use dense attention mask tensor instead of BlockMask.
# Megatron's gradient checkpointing (tensor_parallel.checkpoint)
# uses save_for_backward() which can only save torch.Tensor objects.
# BlockMask is a custom data structure that cannot be serialized.
# By passing a dense tensor, the checkpoint mechanism can save it,
# and the BlockMask will be created inside PytorchFlexAttention.forward()
# during both forward and recompute (backward) passes.
attention_mask = build_attention_mask_from_trie(
trie_node,
padded_size,
mb_input.padded_mb["input_ids"].device,
)
mb_input.padded_mb["attention_mask"] = attention_mask
if trie_node is not None:
assert padded_size is not None
tree_kwargs = build_tree_attn_kwargs(
trie_node,
padded_size,
mb_input.padded_mb["input_ids"].device,
dense_mask=True,
)
mb_input.padded_mb.update(tree_kwargs)
tree_attn_keys = list(tree_kwargs.keys())

output = packed_context_parallel_forward(model, mb_input.padded_mb)

# Release tree attention metadata after forward pass
if self.enable_tree_training:
if "attention_mask" in mb_input.padded_mb:
del mb_input.padded_mb["attention_mask"]
if "triton_attn_data" in mb_input.padded_mb:
del mb_input.padded_mb["triton_attn_data"]
for key in tree_attn_keys:
del mb_input.padded_mb[key]

def _process_output(input_, output_):
loss = process_output_fn(output_, input_)
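The comment above explains why a dense tensor, not a `BlockMask`, is what crosses the checkpoint boundary. Below is a minimal, hypothetical illustration of the recreate-inside-forward pattern; `PytorchFlexAttention` is not shown in this diff, so the function name, tensor shapes, and mask layout here are assumptions.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def tree_attention_from_dense_mask(q, k, v, dense_mask):
    """q, k, v: [B, H, S, D]; dense_mask: [B, S, S] boolean tensor.

    Only the dense tensor needs to survive save_for_backward(); the BlockMask
    is rebuilt here on both the original forward and the recompute pass.
    """
    B, _, S, _ = q.shape

    def mask_mod(b, h, q_idx, kv_idx):
        # Look allowed positions up in the dense mask (heads share the mask).
        return dense_mask[b, q_idx, kv_idx]

    block_mask = create_block_mask(
        mask_mod, B=B, H=None, Q_LEN=S, KV_LEN=S, device=q.device
    )
    return flex_attention(q, k, v, block_mask=block_mask)
```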
@@ -1383,14 +1370,12 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
assert cp_size == 1, (
"Context parallelism is not supported in tree training."
)
# Build tree inputs
assert BLOCK_SIZE % tp_size == 0, (
f"BLOCK_SIZE ({BLOCK_SIZE}) must be divisible by tensor parallel size ({tp_size})."
)
mb_list = build_packed_tree_batch(
input_,
mb_spec=self.config.mb_spec,
pad_to_maximum=self.config.pad_to_maximum,
dp_group=self.data_parallel_group,
parallel_size=tp_size,
)
recommended_min_n_mbs = 2 * pp_size if pp_size > 1 else 1
self.logger.info(