… integration

Adds a tree-aware SFT path (gated by `data.type == "caterpillar_fake"`) that packs caterpillar-shaped trajectories into a single sequence and computes per-token loss weighted by g_t/K to match per-branch training. Standalone unit tests assert FP64 gradient equivalence to the per-branch baseline; the integration test runs end-to-end through the SFT trainer with HF dense Qwen3 + SDPA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
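A minimal sketch of the g_t/K weighting this commit describes, not the PR's actual code; the helper name and tensor layout are illustrative. In per-branch training, a token shared by g_t of a tree's K root-to-leaf branches is seen g_t times at weight 1/K each, so the packed-tree loss weights that token's NLL by g_t/K once:

```python
import torch
import torch.nn.functional as F

def weighted_tree_nll(logits: torch.Tensor, targets: torch.Tensor,
                      leaf_counts: torch.Tensor, num_leaves: int) -> torch.Tensor:
    """logits: [N, V]; targets: [N]; leaf_counts: g_t per packed token, [N]."""
    nll = F.cross_entropy(logits, targets, reduction="none")  # per-token NLL
    weights = leaf_counts.float() / num_leaves                # g_t / K
    return (weights * nll).sum()
```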
Side-by-side training script that runs Qwen3-0.6B with both tree mode and the per-branch baseline on identical caterpillar data and compares loss curves + parameter drift over many steps. A 200-step run shows a max per-step loss rel diff of 1.2e-7 and bounded ~1e-4 param drift, validating the v1 implementation at real-model training scale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…blation
Adds the per-branch counterpart of CaterpillarFakeDataset that yields each
leaf's root-to-leaf path as a flat SFT sample, sharing the same caterpillar
fixture so tree-mode and baseline runs see identical synthetic data unrolled
differently. Adds two TOMLs (configs/tree_ablation/{tree,baseline}.toml)
plus loss/sum and loss/token_count log fields so wandb can render directly
comparable curves.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
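As an illustration of the unrolling above (hypothetical tree representation; the actual dataset classes differ), each leaf of a caterpillar tree yields one flat root-to-leaf token sequence:

```python
def iter_branches(nodes: dict, parent: dict):
    """nodes: {node_id: token list}; parent: {node_id: parent_id or None}."""
    internal = {p for p in parent.values() if p is not None}
    leaves = [n for n in nodes if n not in internal]
    for leaf in leaves:
        path, node = [], leaf
        while node is not None:          # walk leaf -> root
            path.append(nodes[node])
            node = parent[node]
        # flatten root -> leaf into one per-branch SFT sample
        yield [tok for seg in reversed(path) for tok in seg]
```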
Two related fixes that made the per-branch baseline non-comparable with the tree mode in the wandb ablation:

- Caterpillar trees have num_turns + 1 leaves (one think branch per turn plus the final response with no children). The dataset hardcoded K = num_turns, silently dropping the final-response leaf and visiting each tree K times instead of K + 1 times.
- setup_dataloader's cat packing concatenated multiple branches per micro-batch, so batch_size in the config no longer mapped to branches-per-step. Bypass packing for caterpillar_per_branch, mirroring the existing caterpillar_fake behavior, so each grad-accum step processes exactly one branch.

baseline.toml batch_size bumped from 32 to 40 (= 8 trees * 5 leaves) to match tree.toml's per-step workload.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds SFTCaterpillarDataset and SFTCaterpillarPerBranchDataset that wrap any HF SFT dataset whose assistant messages contain <think>...</think>. Each example becomes a 1-turn caterpillar: user prompt as the trunk node, <think>...</think> as a leaf, post-</think> content as a sibling leaf. Skips overlong samples (> seq_len).

Ablation configs target the PrimeIntellect/INTELLECT-3-SFT-10K math split with Qwen3-0.6B at seq_len=4096, 200 steps. Smoke tests show a clean K=2 loss-sum ratio between tree and baseline modes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
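A sketch of the 1-turn caterpillar split described above (function name hypothetical; the real datasets operate on chat messages rather than raw strings):

```python
def split_assistant(content: str):
    """Split one assistant message into its two sibling leaves."""
    end = content.find("</think>")
    if end == -1:
        return None                    # no reasoning span; sample is skipped
    end += len("</think>")
    think_leaf = content[:end]         # <think>...</think> leaf
    response_leaf = content[end:]      # visible post-</think> sibling leaf
    return think_leaf, response_leaf
```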
Implements PLAN2.md: replace SDPA + materialized [N,N] mask with PyTorch FlexAttention's BlockMask predicate so the ancestor-only mask compiles to a fused Triton kernel that skips fully-False blocks (sibling regions) and fast-paths fully-True blocks (trunk attention).

- Add `model.attn = "flex_attention"` to the AttnImplementation literal; broaden the tree validator to accept it.
- Extend PackedTree with node_of_token and is_ancestor_node helpers.
- New flex_mask.py: tree_mask_mod predicate + build_tree_block_mask(packed) constructor.
- Thread helpers through TreeSample/Batch/cat_collate and SFT compute_loss; build the BlockMask once per tree shape (cached).
- SDPA tree path stays the canonical reference (bit-identical to v1).
- New tests/unit/train/tree/test_tree_flex.py: flex-vs-sdpa correctness regression.
- New tests/perf/test_tree_attention_speedup.py: speedup guard tests.
- New configs/tree_ablation/realistic_tree_flex.toml: realistic ablation config.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
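For reference, a sketch of what an ancestor-only mask_mod plus BlockMask constructor can look like. It assumes a per-token node-id tensor and an ancestor lookup table in the spirit of the node_of_token / is_ancestor_node helpers named above; the actual flex_mask.py signatures may differ.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def build_tree_block_mask(node_of_token: torch.Tensor,  # [N] node id per token
                          ancestor: torch.Tensor,       # [M, M] bool, ancestor-or-self
                          seq_len: int, device: str = "cuda"):
    def tree_mask_mod(b, h, q_idx, kv_idx):
        # kv is visible to q iff it is causally earlier in the packed order
        # AND kv's node lies on q's root-to-leaf path.
        return (kv_idx <= q_idx) & ancestor[node_of_token[kv_idx],
                                            node_of_token[q_idx]]

    # B=None / H=None broadcast the mask over batch and heads; fully-False
    # blocks (sibling regions) are skipped entirely by the fused kernel.
    return create_block_mask(tree_mask_mod, B=None, H=None,
                             Q_LEN=seq_len, KV_LEN=seq_len, device=device)
```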
The previous tests/unit/train/tree/_tiny_model.py called flex_attention(...) directly, which falls back to PyTorch's unfused, materialized-scores path. That path is slower than SDPA's math backend due to extra dispatch overhead, and it made the perf regression tests fail (3.3× slower than SDPA, 2.2× slower than the per-branch baseline). Wrap flex_attention with torch.compile once at module load and route BlockMask calls through it. Production HF Qwen3 already uses transformers' compiled wrapper; this only affected the perf test harness.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
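The fix amounts to compiling the kernel entry point once at import time; a minimal version (variable names illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile once at module load: eager flex_attention materializes scores, and
# calling torch.compile inside the forward would pay compile overhead repeatedly.
_flex_attention_compiled = torch.compile(flex_attention)

def tree_attention(q, k, v, block_mask):
    # q/k/v: [B, H, N, D]; block_mask: the cached tree BlockMask.
    return _flex_attention_compiled(q, k, v, block_mask=block_mask)
```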
Compiled FlexAttention (introduced in 72475ed) requires head_dim >= 16, while the previous fixture used hidden_size=64 with num_heads=8 → head_dim=8, which is fine in the eager FlexAttention path but raises an InductorError once torch.compile is wrapped around flex_attention. Reduce num_heads to 4 to get head_dim=16 with the same hidden size; same parameter count, same correctness scope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
This PR adds the Tree Training v1 SFT path and the raw-tool ablation controls needed to test it on realistic multi-turn trajectories.
Core pieces in the branch:
- … `prev_map`, and weighted tree NLL.
- … `BlockMask` batches.
- `PrimeIntellect/INTELLECT-5-SFT-Raw`, including `reasoning_content` as reasoning side leaves and visible assistant/tool content on the trunk.
- … `reasoning_content` plus visible assistant output in the completion, and extension only when the next prompt still extends the previous full sample.
- `seq_len` now caps root-to-leaf branch paths, while `max_packed_tokens` separately caps the DFS-packed tree. This lets high-branching trees through instead of rejecting them because the whole packed tree exceeded the branch context length.
- `selection_metric = "branching_score"` path using `datasets.map(..., num_proc=...)` to prefer examples with actual reasoning side branches rather than sorting only by `num_turns`.
- … `INTELLECT-5-SFT-Raw/general_agent_rlm`, including tree, per-branch, grouped-equivalence, and current-RL baselines.

Why This Matters
The initial realistic experiments were biased toward weak branching because raw-tool tree construction rejected any example with `packed_len > seq_len`. That defeats the point of tree training: the interesting examples are often those where every branch fits the model context, but the full DFS-packed tree is larger because it contains many branches.

The new behavior is:
seq_len: maximum root-to-leaf path length, equivalent to normal model context.max_packed_tokens: maximum full packed tree length, bounded by memory/runtime.For the high-branching config added here:
For the high-branching config added here:

The selected examples are much stronger than the earlier `rlm_math`/`rlm_science` top-turn runs:

Subset scan summary for exact-tokenized top candidates, with `path <= 8192` and `packed <= 65536`:

Repro: Unit/Correctness Checks
Run from repo root:
Observed locally before the latest current-RL baseline commit:
Latest current-RL baseline commit validation:
```bash
uv run ruff check \
  src/prime_rl/configs/sft.py \
  src/prime_rl/trainer/sft/data.py \
  tests/unit/train/tree/test_tree_pack.py
uv run pytest tests/unit/train/tree/test_tree_pack.py -q
uv run python -c "import tomllib; from prime_rl.configs.sft import SFTConfig; data = tomllib.load(open('configs/tree_ablation/intellect5_general_agent_current_rl_baseline.toml', 'rb')); cfg = SFTConfig.model_validate(data); print(cfg.data.type, cfg.data.subset, cfg.data.selection_metric)"
```

Observed after the latest current-RL baseline commit:
GPU correctness/perf gates from the earlier Tree Training v1.1 work:
```bash
UV_NO_SYNC=1 uv run pytest tests/unit/train/tree/test_tree_flex.py -q -m gpu
UV_NO_SYNC=1 uv run pytest tests/perf/test_tree_attention_speedup.py -v -m "gpu and perf"
```

Repro: 2-GPU High-Branching Ablation
Hardware used for the run below:
Note: `--deployment.num-gpus` is the SFT launcher GPU-rank count. It is not an inference deployment flag.

Tree FlexAttention:
Per-branch baseline:
Observed local results:
Do not use the logged `perf/throughput` as the headline metric for this comparison. It assumes `batch_size * seq_len` tokens per step, which is not the right logical token count for variable tree-vs-branch workloads. Wall-clock step time and branch coverage are the useful metrics here.

The per-branch baseline also covers fewer logical tree branches per optimizer step. The selected trees have mean `K = 20.42` leaves. With the local configs, tree-flex trains on 4 full trees per optimizer step, while the per-branch baseline trains on 16 individual branches per optimizer step. Normalized by logical branch coverage, the tree path is substantially stronger than the raw 3.75x step-time comparison; the arithmetic is worked out below.
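Worked arithmetic behind that claim, using the numbers quoted above:

```python
mean_K = 20.42                       # mean leaves per selected tree
tree_branches = 4 * mean_K           # tree-flex: 4 full trees/step ~= 81.7 branches
baseline_branches = 16               # per-branch baseline: 16 branches/step
coverage_ratio = tree_branches / baseline_branches
print(coverage_ratio)                # ~= 5.1x more branch coverage per step,
                                     # outweighing the raw 3.75x step-time gap
```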
Repro: 8xH200 Cluster

The current configs were written for the 2-GPU local run. For 8 GPUs, `batch_size` must be divisible by `world_size * micro_batch_size`, so override batch sizes on the command line; a quick check is sketched below.
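A quick check of that constraint (the example numbers are placeholders; take micro_batch_size from the chosen TOML):

```python
def valid_batch_size(batch_size: int, world_size: int, micro_batch_size: int) -> bool:
    return batch_size % (world_size * micro_batch_size) == 0

print(valid_batch_size(20, world_size=2, micro_batch_size=2))  # True on 2 GPUs
print(valid_batch_size(20, world_size=8, micro_batch_size=2))  # False -> override
```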
Tree FlexAttention on 8 GPUs:

Current-RL flat-sample baseline on 8 GPUs:
Per-branch baseline on 8 GPUs:
Grouped branch-equivalence baseline on 8 GPUs:
If grouped equivalence OOMs, retry with activation checkpointing:
Grouped Equivalence Baseline Caveat
The grouped baseline is the mathematically clean per-tree baseline: for one tree with `K` leaves, it forwards every root-to-leaf branch and combines branch losses with weight `1/K` before the optimizer step.

This is useful for correctness, but the current implementation is not a practical high-branching perf baseline: it accumulates all branch autograd graphs before a single backward (sketched below). On the local 2x96 GiB machine:

- `--model.ac`: still did not finish step 0 after several minutes, so it was killed
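A sketch of the memory behavior (hypothetical harness; `model(b).loss` stands in for the real SFT forward): the grouped step keeps K live autograd graphs until one combined backward, while ordinary grad accumulation frees each branch's graph eagerly.

```python
import torch

def grouped_step(model, branches):
    # Grouped equivalence: K live graphs held simultaneously -> OOM risk.
    losses = [model(b).loss for b in branches]
    torch.stack(losses).mean().backward()        # 1/K weighting, single backward

def accumulated_step(model, branches):
    # Standard grad accumulation: same gradients, one graph alive at a time.
    for b in branches:
        (model(b).loss / len(branches)).backward()
```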
For practical apples-to-apples runs against the current production behavior, use `sft_raw_tool_current_rl_baseline`; it follows the current RL sample construction instead of the tree-equivalence branch set.

Validation Performed
Latest current-RL baseline commit:
```bash
uv run ruff check \
  src/prime_rl/configs/sft.py \
  src/prime_rl/trainer/sft/data.py \
  tests/unit/train/tree/test_tree_pack.py
uv run pytest tests/unit/train/tree/test_tree_pack.py -q
uv run python -c "import tomllib; from prime_rl.configs.sft import SFTConfig; data = tomllib.load(open('configs/tree_ablation/intellect5_general_agent_current_rl_baseline.toml', 'rb')); cfg = SFTConfig.model_validate(data); print(cfg.data.type, cfg.data.subset, cfg.data.selection_metric)"
```

GPU runs performed locally:
Known unrelated warning seen during SFT dry-runs and launches:
This warning predates this PR path and did not block config parsing or training.