
Tree Training v1 raw-tool SFT ablations #2405

Draft
rasdani wants to merge 13 commits into main from feat/tree-training-v1

Conversation

rasdani (Contributor) commented May 3, 2026

Summary

This PR adds the Tree Training v1 SFT path and the raw-tool ablation controls needed to test it on realistic multi-turn trajectories.

Core pieces in the branch:

  • Tree data model, DFS packing, ancestor-only masks, prev_map, and weighted tree NLL (see the sketch after this list).
  • SFT trainer support for tree batches, grouped branch-equivalence batches, and FlexAttention BlockMask batches.
  • Raw tool trajectory datasets for PrimeIntellect/INTELLECT-5-SFT-Raw, including reasoning_content as reasoning side leaves and visible assistant/tool content on the trunk.
  • A current-RL raw-tool baseline that builds flat assistant-step samples the same way the RL path does after reasoning stripping: visible context in the prompt, reasoning_content plus visible assistant output in the completion, and extension only when the next prompt still extends the previous full sample.
  • A fix for the earlier ablation issue: seq_len now caps root-to-leaf branch paths, while max_packed_tokens separately caps the DFS-packed tree. This lets high-branching trees through instead of rejecting them because the whole packed tree exceeded the branch context length.
  • A deterministic selection_metric = "branching_score" path using datasets.map(..., num_proc=...) to prefer examples with actual reasoning side branches rather than sorting only by num_turns.
  • Repro configs for the current high-branching ablation on INTELLECT-5-SFT-Raw/general_agent_rlm, including tree, per-branch, grouped-equivalence, and current-RL baselines.
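
A minimal sketch of the tree-packing idea referenced in the first bullet. Names such as TreeNode, pack_dfs, and is_ancestor are illustrative assumptions, not the PR's actual API:

from dataclasses import dataclass, field

@dataclass
class TreeNode:
    tokens: list[int]
    children: list["TreeNode"] = field(default_factory=list)

def pack_dfs(root: TreeNode):
    """DFS-flatten a tree; return tokens, per-token node ids, and prev_map."""
    tokens, node_of_token, prev_map = [], [], {}

    def visit(node: TreeNode, parent_id: int) -> None:
        node_id = len(prev_map)
        prev_map[node_id] = parent_id  # parent pointer; -1 marks the root
        for tok in node.tokens:
            tokens.append(tok)
            node_of_token.append(node_id)
        for child in node.children:
            visit(child, node_id)

    visit(root, -1)
    return tokens, node_of_token, prev_map

def is_ancestor(prev_map: dict[int, int], a: int, b: int) -> bool:
    """True if node a equals b or is an ancestor of b."""
    while b != -1:
        if b == a:
            return True
        b = prev_map[b]
    return False

The ancestor-only mask then allows token q to attend to token k iff k <= q and is_ancestor(prev_map, node_of_token[k], node_of_token[q]). Per the commit log below, the per-token loss is weighted by g_t/K to match per-branch training; reading g_t as the number of leaves under the token's node is an assumption here.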

Why This Matters

The initial realistic experiments were biased toward weak branching because raw-tool tree construction rejected any example with packed_len > seq_len. That defeats the point of tree training: the interesting examples are often those where every branch fits the model context, but the full DFS-packed tree is larger because it contains many branches.

The new behavior is:

  • seq_len: maximum root-to-leaf path length, equivalent to normal model context.
  • max_packed_tokens: maximum full packed tree length, bounded by memory/runtime (see the sketch just below).
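
A one-function sketch of this acceptance rule; the helper and argument names are assumptions:

def tree_fits(max_path_len: int, packed_len: int,
              seq_len: int, max_packed_tokens: int) -> bool:
    # Old behavior effectively required packed_len <= seq_len, rejecting
    # high-branching trees. New behavior caps the two lengths separately.
    return max_path_len <= seq_len and packed_len <= max_packed_tokens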

For the high-branching config added here:

[data]
subset = "general_agent_rlm"
seq_len = 8192
max_packed_tokens = 32768
selection_metric = "branching_score"
selection_num_proc = 16
max_examples = 64

The selected examples are much stronger than the earlier rlm_math / rlm_science top-turn runs:

selected trees:             64 / 64 valid
mean K leaves:              20.42
max K leaves:               28
mean packed tokens:         8003.9
max packed tokens:          9126
mean max path tokens:       6614.9
max path tokens:            7743
mean branch/token ratio:    8.11x
max branch/token ratio:     12.91x
mean branch/pair ratio:     5.38x
max branch/pair ratio:      8.68x
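
For reference, "branch/token ratio" is read here as the total token count if every root-to-leaf branch were unrolled separately, divided by the DFS-packed size. This definition is an inference from the stats above, not confirmed against the code:

branch_lens = [6615, 6400, 6200]   # illustrative root-to-leaf path lengths
packed_tokens = 8004               # ~mean packed size from the run above
ratio = sum(branch_lens) / packed_tokens
print(f"branch/token ratio: {ratio:.2f}x")  # >1x means packing shares work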

Subset scan summary for exact-tokenized top candidates, with path <= 8192 and packed <= 65536:

subset                   mean pair ratio (top 16)   max pair ratio
------------------------------------------------------------------
general_agent_rlm        7.50                       8.68
kimina_rlm               4.60                       6.05
rlm_science              3.26                       4.12
rlm_math                 3.19                       3.59
wikispeedia_links_only   2.92                       3.71

Repro: Unit/Correctness Checks

Run from repo root:

UV_NO_SYNC=1 uv run ruff check \
  src/prime_rl/configs/sft.py \
  src/prime_rl/trainer/sft/data.py \
  tests/unit/train/tree/test_tree_pack.py

UV_NO_SYNC=1 uv run pytest \
  tests/unit/train/tree/test_tree_pack.py \
  tests/unit/train/tree/test_tree_equivalence.py \
  -q

Observed locally before the latest current-RL baseline commit:

ruff: all checks passed
pytest: 19 passed

Latest current-RL baseline commit validation:

uv run ruff check \
  src/prime_rl/configs/sft.py \
  src/prime_rl/trainer/sft/data.py \
  tests/unit/train/tree/test_tree_pack.py

uv run pytest tests/unit/train/tree/test_tree_pack.py -q

uv run python -c "
import tomllib
from prime_rl.configs.sft import SFTConfig

path = 'configs/tree_ablation/intellect5_general_agent_current_rl_baseline.toml'
data = tomllib.load(open(path, 'rb'))
cfg = SFTConfig.model_validate(data)
print(cfg.data.type, cfg.data.subset, cfg.data.selection_metric)
"

Observed after the latest current-RL baseline commit:

ruff: all checks passed
pytest: 19 passed
config validation: sft_raw_tool_current_rl_baseline general_agent_rlm branching_score

GPU correctness/perf gates from the earlier Tree Training v1.1 work:

UV_NO_SYNC=1 uv run pytest tests/unit/train/tree/test_tree_flex.py -q -m gpu
UV_NO_SYNC=1 uv run pytest tests/perf/test_tree_attention_speedup.py -v -m "gpu and perf"

Repro: 2-GPU High-Branching Ablation

Hardware used for the run below:

2x NVIDIA RTX PRO 6000 Blackwell Server Edition, 96 GiB each
Model: PrimeIntellect/Qwen3-0.6B
Precision: float32
Dataset: PrimeIntellect/INTELLECT-5-SFT-Raw / general_agent_rlm
Steps: 10
Warmup excluded: step 0

Note: --deployment.num-gpus is the SFT launcher GPU-rank count. It is not an inference deployment flag.

Tree FlexAttention:

CUDA_VISIBLE_DEVICES=0,1 \
UV_NO_SYNC=1 \
TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_tree_flex.toml \
  --deployment.num-gpus 2

Per-branch baseline:

CUDA_VISIBLE_DEVICES=0,1 \
UV_NO_SYNC=1 \
TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_baseline.toml \
  --deployment.num-gpus 2

Observed local results:

Tree FlexAttention:
  median step time excluding step 0: 3.18s
  final loss:                         0.4212
  peak memory:                        45.8 GiB

Per-branch baseline:
  median step time excluding step 0: 11.91s
  final loss:                         0.6017
  peak memory:                        34.3 GiB

Raw wall-clock step speedup:          ~3.75x

Do not use the logged perf/throughput as the headline metric for this comparison: it assumes batch_size * seq_len tokens per step, which is not the right logical token count for variable tree-vs-branch workloads. Wall-clock step time and branch coverage are the useful metrics here.

The per-branch baseline also covers fewer logical tree branches per optimizer step. The selected trees have mean K = 20.42 leaves. With the local configs, tree-flex trains on 4 full trees per optimizer step, while the per-branch baseline trains on 16 individual branches per optimizer step. Normalized by logical branch coverage, the tree path is substantially stronger than the raw 3.75x step-time comparison.
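
The normalization works out as follows, as a back-of-envelope computation from the numbers quoted above:

tree_step_s, trees_per_step, mean_leaves = 3.18, 4, 20.42
branch_step_s, branches_per_step = 11.91, 16

tree_s_per_branch = tree_step_s / (trees_per_step * mean_leaves)  # ~0.039 s
base_s_per_branch = branch_step_s / branches_per_step             # ~0.744 s
print(f"per-branch speedup: {base_s_per_branch / tree_s_per_branch:.1f}x")  # ~19x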

Repro: 8xH200 Cluster

The current configs were written for the 2-GPU local run. For 8 GPUs, batch_size must be divisible by world_size * micro_batch_size, so override batch sizes on the command line.
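
A quick check of that rule; micro_batch_size = 1 is an assumption here, the real value comes from the ablation TOMLs:

world_size, micro_batch_size = 8, 1
for batch_size in (8, 32):  # the overrides used in the commands below
    assert batch_size % (world_size * micro_batch_size) == 0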

Tree FlexAttention on 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
UV_NO_SYNC=1 \
TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_tree_flex.toml \
  --deployment.num-gpus 8 \
  --data.batch-size 8

Current-RL flat-sample baseline on 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
UV_NO_SYNC=1 \
TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_current_rl_baseline.toml \
  --deployment.num-gpus 8 \
  --data.batch-size 32

Per-branch baseline on 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
UV_NO_SYNC=1 \
TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_baseline.toml \
  --deployment.num-gpus 8 \
  --data.batch-size 32

Grouped branch-equivalence baseline on 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
UV_NO_SYNC=1 \
TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_equivalence_baseline.toml \
  --deployment.num-gpus 8 \
  --data.batch-size 8

If grouped equivalence OOMs, retry with activation checkpointing:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
UV_NO_SYNC=1 \
TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_equivalence_baseline.toml \
  --deployment.num-gpus 8 \
  --data.batch-size 8 \
  --model.ac

Grouped Equivalence Baseline Caveat

The grouped baseline is the mathematically clean per-tree baseline: for one tree with K leaves, it forwards every root-to-leaf branch and combines branch losses with weight 1/K before the optimizer step.
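
A hedged sketch of that objective with illustrative names, not the PR's actual implementation:

import torch

def grouped_tree_loss(model, branches: list[dict]) -> torch.Tensor:
    """Forward every root-to-leaf branch of one tree; average with weight 1/K."""
    K = len(branches)
    loss = torch.zeros((), device=next(model.parameters()).device)
    for branch in branches:  # one flat root-to-leaf sample per leaf
        out = model(input_ids=branch["input_ids"], labels=branch["labels"])
        # Each forward's autograd graph stays alive until the single backward,
        # which is exactly what makes this memory-hungry at high K.
        loss = loss + out.loss / K
    return loss

# One optimizer step per tree:
# grouped_tree_loss(model, branches).backward(); optimizer.step()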

This is useful for correctness, but the current implementation is not a practical high-branching perf baseline. It accumulates all branch autograd graphs before a single backward. On the local 2x96 GiB machine:

  • grouped baseline without activation checkpointing: OOM on step 0
  • grouped baseline with --model.ac: still did not finish step 0 after several minutes, so it was killed

For practical apples-to-apples runs against the current production behavior, use sft_raw_tool_current_rl_baseline; it follows the current RL sample construction instead of the tree-equivalence branch set.
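
A hedged sketch of that sample construction. Field names and the prefix-extension test are assumptions based on the summary above; loss masking of the tool/user text folded into an extended sample is elided:

def build_flat_samples(steps: list[dict]) -> list[dict]:
    samples: list[dict] = []
    for step in steps:
        prompt = step["visible_context"]  # reasoning-stripped context so far
        completion = step["reasoning_content"] + step["visible_output"]
        if samples:
            prev = samples[-1]
            full_prev = prev["prompt"] + prev["completion"]
            if prompt.startswith(full_prev):
                # Next prompt still extends the previous full sample: extend
                # in place instead of emitting a new flat sample.
                prev["completion"] += prompt[len(full_prev):] + completion
                continue
        samples.append({"prompt": prompt, "completion": completion})
    return samples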

Validation Performed

UV_NO_SYNC=1 uv run ruff check \
  src/prime_rl/configs/sft.py \
  src/prime_rl/trainer/sft/data.py \
  tests/unit/train/tree/test_tree_pack.py

UV_NO_SYNC=1 uv run pytest \
  tests/unit/train/tree/test_tree_pack.py \
  tests/unit/train/tree/test_tree_equivalence.py \
  -q

NEVER_CLEAN_OUTPUT_DIR=1 UV_NO_SYNC=1 uv run sft \
  @ configs/tree_ablation/intellect5_general_agent_tree_flex.toml \
  --dry-run

NEVER_CLEAN_OUTPUT_DIR=1 UV_NO_SYNC=1 uv run sft \
  @ configs/tree_ablation/intellect5_general_agent_baseline.toml \
  --dry-run

NEVER_CLEAN_OUTPUT_DIR=1 UV_NO_SYNC=1 uv run sft \
  @ configs/tree_ablation/intellect5_general_agent_equivalence_baseline.toml \
  --dry-run

Latest current-RL baseline commit:

uv run ruff check \
  src/prime_rl/configs/sft.py \
  src/prime_rl/trainer/sft/data.py \
  tests/unit/train/tree/test_tree_pack.py

uv run pytest tests/unit/train/tree/test_tree_pack.py -q

uv run python -c "
import tomllib
from prime_rl.configs.sft import SFTConfig

path = 'configs/tree_ablation/intellect5_general_agent_current_rl_baseline.toml'
data = tomllib.load(open(path, 'rb'))
cfg = SFTConfig.model_validate(data)
print(cfg.data.type, cfg.data.subset, cfg.data.selection_metric)
"

GPU runs performed locally:

CUDA_VISIBLE_DEVICES=0,1 UV_NO_SYNC=1 TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_tree_flex.toml \
  --deployment.num-gpus 2

CUDA_VISIBLE_DEVICES=0,1 UV_NO_SYNC=1 TOKENIZERS_PARALLELISM=false \
uv run sft @ configs/tree_ablation/intellect5_general_agent_baseline.toml \
  --deployment.num-gpus 2

Known unrelated warning seen during SFT dry-runs and launches:

[ERROR] `temperature` is part of GptOssForCausalLM.forward's signature, but not documented.

This warning predates this PR path and did not block config parsing or training.

rasdani and others added 13 commits May 1, 2026 20:15
… integration

Adds a tree-aware SFT path (gated by `data.type == "caterpillar_fake"`) that packs
caterpillar-shaped trajectories into a single sequence and computes per-token loss
weighted by g_t/K to match per-branch training. Standalone unit tests assert FP64
gradient equivalence to the per-branch baseline; integration test runs end-to-end
through the SFT trainer with HF dense Qwen3 + SDPA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Side-by-side training script that runs Qwen3-0.6B with both tree mode and
per-branch baseline on identical caterpillar data and compares loss curves +
parameter drift over many steps. 200-step run shows max per-step loss rel diff
of 1.2e-7 and bounded ~1e-4 param drift, validating the v1 implementation at
real-model training scale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…blation

Adds the per-branch counterpart of CaterpillarFakeDataset that yields each
leaf's root-to-leaf path as a flat SFT sample, sharing the same caterpillar
fixture so tree-mode and baseline runs see identical synthetic data unrolled
differently. Adds two TOMLs (configs/tree_ablation/{tree,baseline}.toml)
plus loss/sum and loss/token_count log fields so wandb can render directly
comparable curves.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two related fixes that made the per-branch baseline non-comparable with the
tree mode in the wandb ablation:

- Caterpillar trees have num_turns + 1 leaves (one think branch per turn plus
  the final response with no children). The dataset hardcoded K = num_turns,
  silently dropping the final-response leaf and visiting each tree K times
  instead of K+1 times.
- setup_dataloader's cat packing concatenated multiple branches per micro-batch,
  so batch_size in the config no longer mapped to branches-per-step. Bypass
  packing for caterpillar_per_branch, mirroring the existing caterpillar_fake
  behavior, so each grad-accum step processes exactly one branch.

baseline.toml batch_size bumped from 32 to 40 (= 8 trees * 5 leaves) to match
tree.toml's per-step workload.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds SFTCaterpillarDataset and SFTCaterpillarPerBranchDataset that wrap any HF
SFT dataset whose assistant messages contain <think>...</think>. Each example
becomes a 1-turn caterpillar: user prompt as the trunk node, <think>...</think>
as a leaf, post-</think> content as a sibling leaf. Skips overlong samples
(> seq_len). Ablation configs target PrimeIntellect/INTELLECT-3-SFT-10K math
split with Qwen3-0.6B at seq_len=4096, 200 steps. Smoke tests show clean K=2
loss-sum ratio between tree and baseline modes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Implements PLAN2.md: replace SDPA + materialized [N,N] mask with PyTorch FlexAttention's
BlockMask predicate so the ancestor-only mask compiles to a fused Triton kernel that skips
fully-False blocks (sibling regions) and fast-paths fully-True blocks (trunk attention).

- Add `model.attn = "flex_attention"` to AttnImplementation literal; broaden the tree
  validator to accept it.
- Extend PackedTree with node_of_token and is_ancestor_node helpers.
- New flex_mask.py: tree_mask_mod predicate + build_tree_block_mask(packed) constructor.
- Thread helpers through TreeSample/Batch/cat_collate and SFT compute_loss; build the
  BlockMask once per tree shape (cached).
- SDPA tree path stays the canonical reference — bit-identical to v1.
- New tests/unit/train/tree/test_tree_flex.py: flex-vs-sdpa correctness regression.
- New tests/perf/test_tree_attention_speedup.py: speedup guard tests.
- New configs/tree_ablation/realistic_tree_flex.toml: realistic ablation config.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous tests/unit/train/tree/_tiny_model.py called flex_attention(...) directly
which falls back to PyTorch's unfused, materialized-scores path. That path is slower
than SDPA's math backend due to extra dispatch overhead, and made the perf regression
tests fail (3.3× slower than SDPA, 2.2× slower than per-branch baseline). Wrap
flex_attention with torch.compile once at module load and route BlockMask calls
through it. Production HF Qwen3 already uses transformers' compiled wrapper — this
only affected the perf test harness.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Compiled FlexAttention (introduced in 72475ed) requires head_dim >= 16, while the
previous fixture used hidden_size=64 with num_heads=8 → head_dim=8, which is fine in
the eager FlexAttention path but raises an InductorError once torch.compile is wrapped
around flex_attention. Reduce num_heads to 4 to get head_dim=16 with the same hidden
size; same parameter count, same correctness scope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
