Commit 41b2725
refactor(tree-attn): simplify tree attention plumbing and restructure attention package
Consolidate tree attention metadata into TreeAttentionMeta dataclass,
pass it end-to-end through model layers, and restructure the Archon
attention module into a package with clear separation of concerns.
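
The consolidation described above can be illustrated with a minimal sketch. The class name `TreeAttnMetaSketch`, its fields, and the helper `to_prefixed_kwargs` are hypothetical stand-ins, not the names or layout used in this commit. Only two ideas come from the commit message: a single metadata object handed through the layers, and a `tree_` prefix when that metadata has to travel as a kwargs dict.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class TreeAttnMetaSketch:
    """Hypothetical stand-in for a tree attention metadata bundle."""

    backend: str                      # assumed field: which kernel to use
    mask: Optional[Any] = None        # assumed field: attention mask data
    block_size: Optional[int] = None  # assumed field: kernel block size


def to_prefixed_kwargs(meta: TreeAttnMetaSketch, prefix: str = "tree_") -> dict:
    """Flatten the fields that are set into a dict with prefixed keys,
    so they cannot collide with a model's own keyword arguments."""
    return {
        f"{prefix}{f.name}": getattr(meta, f.name)
        for f in fields(meta)
        if getattr(meta, f.name) is not None
    }


def attention(x, tree_attn_meta=None):
    # Innermost layer: the only place that inspects the metadata.
    backend = tree_attn_meta.backend if tree_attn_meta else "default"
    return (x, backend)


def block(x, tree_attn_meta=None):
    # Intermediate layer: forwards the object untouched.
    return attention(x, tree_attn_meta=tree_attn_meta)


def model(x, tree_attn_meta=None):
    # Outermost layer: accepts one optional object instead of many loose args.
    return block(x, tree_attn_meta=tree_attn_meta)
```

Passing one optional object keeps intermediate layers unaware of its contents, while the prefixed dict form suits code paths that can only forward `**kwargs`.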
Key changes:
- Add TreeAttentionMeta with from_trie() factory to auto-select backend
- Pass tree_attn_meta through Model → TransformerBlock → Attention
  in Archon models (qwen2, qwen3)
- Convert attention.py into attention/ package (sdpa.py, varlen.py)
- Extract _gather_actor_train_outputs, _gather_actor_forward_output,
  _gather_critic_output helpers in ArchonEngine
- Add build_tree_attn_kwargs for FSDP/Megatron dict-based forwarding
  with tree_ prefix to avoid HF kwarg collisions
- Fix unused stride_ob in Triton _tree_attn_bwd_preprocess kernel
  (caused InductorError with torch.compile backward)
- Fix missing vocab_min/vocab_max CP gather in Archon engine
- Move BLOCK_SIZE alignment into build_packed_tree_batch via
  parallel_size parameter with math.lcm
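
The last item, alignment via math.lcm, can be sketched as follows. This is an illustration under stated assumptions, not the code in build_packed_tree_batch: the function name `aligned_length` and the block size of 128 are hypothetical. The idea is to round a token count up to a length that is a multiple of both the kernel block size and the parallel size.

```python
import math

BLOCK_SIZE = 128  # assumed value, for illustration only


def aligned_length(num_tokens: int, parallel_size: int = 1,
                   block_size: int = BLOCK_SIZE) -> int:
    """Round num_tokens up to a multiple of lcm(block_size, parallel_size)."""
    multiple = math.lcm(block_size, parallel_size)
    # Ceiling division in pure integer arithmetic, then scale back up.
    return -(-num_tokens // multiple) * multiple
```

Using the least common multiple rather than the product keeps padding minimal when the two sizes share factors: with a block size of 128 and a parallel size of 4 the multiple stays at 128 instead of growing to 512.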
Refs: #9121
Parent 7b20898, commit 41b2725
18 files changed
Lines changed: 378 additions & 396 deletions
File tree
- areal
- engine
- experimental
- engine
- models/archon
- attention
- qwen2/model
- qwen3/model
- models/tree_attn
- utils
- mcore