[Speculative Decoding] Refactor EAGLE3 training to YAML-based config and recipe system#1134
[Speculative Decoding] Refactor EAGLE3 training to YAML-based config and recipe system#1134
Conversation
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
📝 WalkthroughWalkthroughThe speculative decoding example migrates from CLI-based argument passing to YAML configuration-driven training. Training now uses Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~50 minutes Important Pre-merge checks failedPlease resolve all errors before merging. Addressing warnings is optional. ❌ Failed checks (1 error)
✅ Passed checks (3 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #1134 +/- ##
==========================================
- Coverage 70.19% 70.17% -0.03%
==========================================
Files 230 230
Lines 26044 26053 +9
==========================================
+ Hits 18281 18282 +1
- Misses 7763 7771 +8 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
There was a problem hiding this comment.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)
258-259:⚠️ Potential issue | 🔴 CriticalAdd
weights_only=Truetotorch.load()call for security.The
torch.load(data_args.draft_vocab_cache)at line 258 does not specifyweights_only=True, which allows arbitrary code execution from malicious pickle files. Sinced2tis a pure tensor (int64),weights_only=Trueis both safe and compatible.Proposed fix
- model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) + model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 258 - 259, The torch.load call that assigns model.eagle_module.d2t from data_args.draft_vocab_cache should pass weights_only=True to avoid executing pickled code; update the load call in the code that sets model.eagle_module.d2t to use torch.load(data_args.draft_vocab_cache, weights_only=True) so only tensor data is deserialized.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/launch_train.sh`:
- Around line 49-79: The accelerate launch invocation currently interpolates
unquoted variables into sh -c and omits --num_processes on single-node runs; to
fix, build the command as a bash array (e.g., CMD=()) instead of a single sh -c
string, append required flags using the existing
MULTI_NODE_ARGS/MODEL_ARG/TOTAL_GPU symbols (ensure MULTI_NODE_ARGS always
includes "--num_processes $TOTAL_GPU" even for single-node), and then run the
launch with "${CMD[@]}" so that $CONFIG_FILE, $MODEL, $HEAD_NODE_IP and other
variables are safely quoted and preserved without word-splitting or accidental
expansion.
In `@examples/speculative_decoding/main.py`:
- Line 111: The metadata help string for the dataclass field ar_validate_steps
is incomplete; update the metadata["help"] for ar_validate_steps to a full,
descriptive sentence (e.g., "Number of autoregressive validation steps to run
during evaluation" or similar) so users understand its purpose; locate the
ar_validate_steps field definition and replace the truncated help text with the
completed description.
In `@examples/speculative_decoding/train_eagle3_and_export.sh`:
- Around line 43-48: train_config.yaml is missing the base model identifier so
the generated YAML is not replayable; update the code that writes YAML_FILE
(train_config.yaml) to include the model_name_or_path value (the model used via
the --model override) under model: (e.g., model_name_or_path: "<value>") so the
config fully captures the runtime model selection; ensure the string comes from
the same variable/arg used to parse the --model override and is written when
creating YAML_FILE (preserving YAML_FILE, OUTPUT_DIR, and model_name_or_path
references).
In `@modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml`:
- Around line 3-6: The recipe currently enables trust_remote_code by default in
the model block (fields model_name_or_path: moonshotai/Kimi-K2.5 and
trust_remote_code: true); change that default to false and instead
document/require an explicit opt-in (e.g., a commented flag or
environment-driven toggle) so users must consciously enable trust_remote_code
for the Kimi recipe; update any README or inline comment near the model
configuration and/or the use_fake_base_for_offline handling so it explains how
to opt in (enable trust_remote_code) when the user intentionally trusts the
model's custom HF code.
In `@modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml`:
- Around line 4-6: Replace the unsafe default by changing the YAML key
trust_remote_code from true to false in the model block of the Llama offline
recipe (the block containing model_name_or_path: meta-llama/Llama-3.2-1B);
update the value so the recipe does not silently enable remote code execution
and leave a brief comment if you want to document that users must opt-in to
enable remote code loading manually.
In `@modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml`:
- Around line 4-6: The YAML enables trust_remote_code for a stock Llama model;
remove or set trust_remote_code to false to avoid executing arbitrary repo code.
Edit the model block that contains model_name_or_path: meta-llama/Llama-3.2-1B
and either delete the trust_remote_code line or change it to trust_remote_code:
false so the pipeline uses the standard transformers implementation rather than
allowing remote code execution.
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 138-149: The test currently writes both mix_hidden_states variants
into the same output directory causing runs to clobber each other; modify the
training output_dir construction (where eagle_output_dir /
f"eagle-tinyllama-cp{cp_size}" is used) to include the mix_hidden_states flag
(e.g., append `_mix{mix_hidden_states}` or similar) so each (cp_size,
mix_hidden_states) combination gets a unique checkpoint directory; update any
references that assume the old path (e.g., test_resume_training) to use the new
per-variant output_dir variable.
- Around line 269-273: Parametrize the trust_remote_code flag instead of
hardcoding True: add a test parameter (default False) named trust_remote_code to
the relevant test cases and use it when writing the model YAML dictionary
(replace the hardcoded "trust_remote_code": True with "trust_remote_code":
trust_remote_code) and when calling AutoConfig.from_pretrained (replace the
hardcoded trust_remote_code=True with trust_remote_code=trust_remote_code);
update only the specific test invocations that require remote code execution to
pass trust_remote_code=True. Ensure the new parameter is included in the pytest
parametrization for the test function(s) that build the YAML/model config so
local models keep trust_remote_code=False while remote-model cases explicitly
set it to True.
---
Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 258-259: The torch.load call that assigns model.eagle_module.d2t
from data_args.draft_vocab_cache should pass weights_only=True to avoid
executing pickled code; update the load call in the code that sets
model.eagle_module.d2t to use torch.load(data_args.draft_vocab_cache,
weights_only=True) so only tensor data is deserialized.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c971e970-8ce6-4555-9bd5-f56f417bbb15
📒 Files selected for processing (11)
examples/speculative_decoding/README.mdexamples/speculative_decoding/eagle_config.jsonexamples/speculative_decoding/fsdp_config.jsonexamples/speculative_decoding/launch_train.shexamples/speculative_decoding/main.pyexamples/speculative_decoding/train_eagle3_and_export.shmodelopt_recipes/speculative_decoding/_base_eagle3.yamlmodelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yamlmodelopt_recipes/speculative_decoding/llama3_eagle_offline.yamlmodelopt_recipes/speculative_decoding/llama3_eagle_online.yamltests/examples/speculative_decoding/test_eagle.py
💤 Files with no reviewable changes (2)
- examples/speculative_decoding/eagle_config.json
- examples/speculative_decoding/fsdp_config.json
| # GPU count detection | ||
| if [[ "$NUM_NODES" != "1" ]]; then | ||
| GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} | ||
| TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) | ||
| echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" | ||
| else | ||
| #Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES | ||
| TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())") | ||
| echo "Total GPUs: $TOTAL_GPU (Single Node Training)" | ||
| fi | ||
| # Calculate save_steps | ||
| DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) | ||
|
|
||
| MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} | ||
| MODE=${MODE:-"eagle3"} | ||
| EAGLE_DECODER_TYPE=${EAGLE_DECODER_TYPE:-"llama"} | ||
| # Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path | ||
| MODEL_BASENAME=$(basename "$MODEL") | ||
| OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"} | ||
| NUM_EPOCHS=${NUM_EPOCHS:-1} | ||
| SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} | ||
| LR=${LR:-"1e-4"} | ||
| TRAIN_BS=${TRAIN_BS:-1} | ||
| TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} | ||
| DATA=${DATA:-""} | ||
| OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} | ||
| DISABLE_TQDM=${DISABLE_TQDM:-False} | ||
| VLM_PROCESSOR=${VLM_PROCESSOR:-} | ||
| VLM_IMG_DIR=${VLM_IMG_DIR:-} | ||
| AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} | ||
| ESTIMATE_AR=${ESTIMATE_AR:-False} | ||
| CP_SIZE=${CP_SIZE:-1} | ||
| DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} | ||
| LOG_STEPS=${LOG_STEPS:-100} | ||
| DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} | ||
| MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} | ||
| DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} | ||
| NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} | ||
|
|
||
| USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} | ||
| TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} | ||
| FSDP=${FSDP:-"False"} | ||
|
|
||
| if [[ "$MODE" == "eagle3" ]]; then | ||
| if [[ -n "$EAGLE_CONFIG" ]]; then | ||
| SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" | ||
| else | ||
| SPECULATIVE_ARGS="" | ||
| fi | ||
| else | ||
| echo "Only eagle3 supported for now!" | ||
| exit 1 | ||
| fi | ||
|
|
||
| if [[ "$OFFLINE_DATA_PATH" != "" ]]; then | ||
| if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then | ||
| echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory." | ||
| exit 1 | ||
| else | ||
| DATA_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1" | ||
| fi | ||
| else | ||
| DATA_ARGS="--data_path $DATA" | ||
| fi | ||
|
|
||
|
|
||
| if [[ "$VLM_PROCESSOR" != "" ]]; then | ||
| VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR" | ||
| else | ||
| VLM_ARGS="" | ||
| fi | ||
|
|
||
| if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then | ||
| #Use FSDP2 when multi GPU available | ||
| FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" | ||
| else | ||
| #Otherwise, single GPU training | ||
| FSDP_ARGS="" | ||
| TOTAL_GPU=$(python3 -c "import torch; print(torch.cuda.device_count())") | ||
| echo "Total GPUs: $TOTAL_GPU (single node)" | ||
| fi | ||
|
|
||
|
|
||
| if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then | ||
| DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE" | ||
| else | ||
| DRAFT_VOCAB_CACHE_ARGS="" | ||
| fi | ||
|
|
||
| if [[ "$NUM_NODES" != 1 ]]; then | ||
| # Multi-node routing args (accelerate only; training config comes from the YAML) | ||
| MULTI_NODE_ARGS="" | ||
| if [[ "$NUM_NODES" != "1" ]]; then | ||
| MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ | ||
| --num_machines $NUM_NODES \ | ||
| --machine_rank $SLURM_PROCID \ | ||
| --rdzv_backend c10d \ | ||
| --main_process_ip $HEAD_NODE_IP \ | ||
| --main_process_port 29500" | ||
| else | ||
| MULTI_NODE_ARGS="" | ||
| fi | ||
|
|
||
| # Disable tokenizers parallelism to avoid warning | ||
| MODEL_ARG="" | ||
| if [ -n "$MODEL" ]; then | ||
| MODEL_ARG="--model $MODEL" | ||
| fi | ||
|
|
||
| export TOKENIZERS_PARALLELISM=False | ||
| CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/main.py \ | ||
| --mode $MODE \ | ||
| --eagle_decoder_type $EAGLE_DECODER_TYPE \ | ||
| --model_name_or_path $MODEL \ | ||
| --training_seq_len $TRAINING_SEQ_LEN \ | ||
| --dataloader_drop_last True \ | ||
| --bf16 True \ | ||
| --output_dir $OUTPUT_DIR \ | ||
| --num_train_epochs $NUM_EPOCHS \ | ||
| --per_device_train_batch_size $TRAIN_BS \ | ||
| --per_device_eval_batch_size $TRAIN_BS \ | ||
| --gradient_accumulation_steps 1 \ | ||
| --do_eval False \ | ||
| --eval_accumulation_steps 1 \ | ||
| --save_strategy steps \ | ||
| --save_steps $SAVE_STEPS \ | ||
| --learning_rate $LR \ | ||
| --weight_decay 0.0 \ | ||
| --warmup_steps 100 \ | ||
| --lr_scheduler_type linear \ | ||
| --logging_steps $LOG_STEPS \ | ||
| --tf32 True \ | ||
| $DATA_ARGS \ | ||
| --disable_tqdm $DISABLE_TQDM \ | ||
| --estimate_ar $ESTIMATE_AR \ | ||
| --ar_validate_steps $AR_VALIDATE_STEPS \ | ||
| --mix_hidden_states $MIX_HIDDEN_STATES \ | ||
| --disable_torch_compile $DISABLE_TORCH_COMPILE \ | ||
| --use_fake_base_for_offline $USE_FAKE_BASE_FOR_OFFLINE \ | ||
| --trust_remote_code $TRUST_REMOTE_CODE \ | ||
| $DRAFT_VOCAB_CACHE_ARGS \ | ||
| $VLM_ARGS \ | ||
| $SPECULATIVE_ARGS \ | ||
| $FSDP_ARGS \ | ||
| --cp_size $CP_SIZE \ | ||
| --dp_shard_size $DP_SHARD_SIZE \ | ||
| --num_ttt_steps $NUM_TTT_STEPS \ | ||
| " | ||
|
|
||
| set -x | ||
| start_time=$(date +%s) | ||
| sh -c "$CMD" | ||
| echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" | ||
| sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE $MODEL_ARG" |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, locate and examine the file
find . -name "launch_train.sh" -path "*/speculative_decoding/*" 2>/dev/null | head -5Repository: NVIDIA/Model-Optimizer
Length of output: 113
🏁 Script executed:
# Read the file to verify lines 49-79
if [ -f "examples/speculative_decoding/launch_train.sh" ]; then
echo "=== Full file context ==="
wc -l examples/speculative_decoding/launch_train.sh
echo ""
echo "=== Lines 40-85 ==="
sed -n '40,85p' examples/speculative_decoding/launch_train.sh
fiRepository: NVIDIA/Model-Optimizer
Length of output: 1494
🏁 Script executed:
# Check if there are any quoting or safety mechanisms around the variables
if [ -f "examples/speculative_decoding/launch_train.sh" ]; then
echo "=== Searching for variable assignments and usage ==="
rg -n "(CONFIG_FILE|MODEL|HEAD_NODE_IP|TOTAL_GPU)" examples/speculative_decoding/launch_train.sh -B1 -A1
fiRepository: NVIDIA/Model-Optimizer
Length of output: 1760
Use bash arrays with proper quoting to build the accelerate launch command; always pass --num_processes for both single-node and multi-node runs.
Lines 62–79 interpolate $CONFIG_FILE, $MODEL, and $HEAD_NODE_IP unquoted into sh -c, so spaces or shell metacharacters in these values will break argument parsing or execute unintended commands. Additionally, the single-node path computes TOTAL_GPU (line 55) but never uses it; without --num_processes, Accelerate defaults to its own heuristic instead of the detected GPU count.
Fix: Use bash array to safely build command
-# Multi-node routing args (accelerate only; training config comes from the YAML)
-MULTI_NODE_ARGS=""
-if [[ "$NUM_NODES" != "1" ]]; then
- MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
- --num_machines $NUM_NODES \
- --machine_rank $SLURM_PROCID \
- --rdzv_backend c10d \
- --main_process_ip $HEAD_NODE_IP \
- --main_process_port 29500"
-fi
-
-MODEL_ARG=""
-if [ -n "$MODEL" ]; then
- MODEL_ARG="--model $MODEL"
-fi
+# Build launch command as array
+LAUNCH_ARGS=(accelerate launch --mixed_precision bf16 --num_processes "$TOTAL_GPU")
+if [[ "$NUM_NODES" != "1" ]]; then
+ LAUNCH_ARGS+=(
+ --num_machines "$NUM_NODES"
+ --machine_rank "$SLURM_PROCID"
+ --rdzv_backend c10d
+ --main_process_ip "$HEAD_NODE_IP"
+ --main_process_port 29500
+ )
+fi
+LAUNCH_ARGS+=("${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE")
+if [[ -n "$MODEL" ]]; then
+ LAUNCH_ARGS+=(--model "$MODEL")
+fi
export TOKENIZERS_PARALLELISM=False
set -x
start_time=$(date +%s)
-sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE $MODEL_ARG"
+"${LAUNCH_ARGS[@]}"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/launch_train.sh` around lines 49 - 79, The
accelerate launch invocation currently interpolates unquoted variables into sh
-c and omits --num_processes on single-node runs; to fix, build the command as a
bash array (e.g., CMD=()) instead of a single sh -c string, append required
flags using the existing MULTI_NODE_ARGS/MODEL_ARG/TOTAL_GPU symbols (ensure
MULTI_NODE_ARGS always includes "--num_processes $TOTAL_GPU" even for
single-node), and then run the launch with "${CMD[@]}" so that $CONFIG_FILE,
$MODEL, $HEAD_NODE_IP and other variables are safely quoted and preserved
without word-splitting or accidental expansion.
| default=False, metadata={"help": "Set to False to keep extra args for VLM."} | ||
| default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} | ||
| ) | ||
| ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation ."}) |
There was a problem hiding this comment.
Incomplete help text for ar_validate_steps.
The metadata help string is truncated and ends abruptly with "AR validation ." Consider completing the description.
📝 Suggested fix
- ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation ."})
+ ar_validate_steps: int = field(default=1000, metadata={"help": "Steps interval for AR validation."})📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation ."}) | |
| ar_validate_steps: int = field(default=1000, metadata={"help": "Steps interval for AR validation."}) |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/main.py` at line 111, The metadata help string
for the dataclass field ar_validate_steps is incomplete; update the
metadata["help"] for ar_validate_steps to a full, descriptive sentence (e.g.,
"Number of autoregressive validation steps to run during evaluation" or similar)
so users understand its purpose; locate the ar_validate_steps field definition
and replace the truncated help text with the completed description.
| # Write config to output dir so it's preserved alongside the checkpoint | ||
| YAML_FILE="$OUTPUT_DIR/train_config.yaml" | ||
| cat > "$YAML_FILE" << EOF | ||
| model: | ||
| use_fake_base_for_offline: false | ||
| trust_remote_code: false |
There was a problem hiding this comment.
Persist model_name_or_path in the generated YAML.
Line 43 says this file is preserved with the checkpoint, but the actual base model only exists in the separate --model override on Line 103. That makes train_config.yaml non-replayable by itself and weakens the whole “YAML as the source of truth” refactor.
♻️ Suggested fix
cat > "$YAML_FILE" << EOF
model:
+ model_name_or_path: "$BASE_MODEL"
use_fake_base_for_offline: false
trust_remote_code: false
@@
-./launch_train.sh --config "$YAML_FILE" --model "$BASE_MODEL"
+./launch_train.sh --config "$YAML_FILE"Also applies to: 103-103
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/train_eagle3_and_export.sh` around lines 43 -
48, train_config.yaml is missing the base model identifier so the generated YAML
is not replayable; update the code that writes YAML_FILE (train_config.yaml) to
include the model_name_or_path value (the model used via the --model override)
under model: (e.g., model_name_or_path: "<value>") so the config fully captures
the runtime model selection; ensure the string comes from the same variable/arg
used to parse the --model override and is written when creating YAML_FILE
(preserving YAML_FILE, OUTPUT_DIR, and model_name_or_path references).
| model: | ||
| model_name_or_path: moonshotai/Kimi-K2.5 | ||
| use_fake_base_for_offline: true | ||
| trust_remote_code: true |
There was a problem hiding this comment.
Avoid shipping the shared Kimi recipe with trust_remote_code already enabled.
If this model genuinely needs custom HF code, that should still be an explicit trusted opt-in rather than the checked-in default. As written, anyone running the recipe straight from the repo executes code from the model source by default.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml` around
lines 3 - 6, The recipe currently enables trust_remote_code by default in the
model block (fields model_name_or_path: moonshotai/Kimi-K2.5 and
trust_remote_code: true); change that default to false and instead
document/require an explicit opt-in (e.g., a commented flag or
environment-driven toggle) so users must consciously enable trust_remote_code
for the Kimi recipe; update any README or inline comment near the model
configuration and/or the use_fake_base_for_offline handling so it explains how
to opt in (enable trust_remote_code) when the user intentionally trusts the
model's custom HF code.
| model: | ||
| model_name_or_path: meta-llama/Llama-3.2-1B | ||
| trust_remote_code: true |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the target file and its base recipe
cat -n modelopt_recipes/speculative_decoding/llama3_eagle_online.yamlRepository: NVIDIA/Model-Optimizer
Length of output: 393
🏁 Script executed:
# Check the base recipe it inherits from
cat -n modelopt_recipes/speculative_decoding/_base_eagle3.yamlRepository: NVIDIA/Model-Optimizer
Length of output: 1691
🏁 Script executed:
# Search for other YAML files with trust_remote_code settings to understand patterns
rg "trust_remote_code" modelopt_recipes/ -A 2 -B 2Repository: NVIDIA/Model-Optimizer
Length of output: 1318
🏁 Script executed:
# Check if there are any other Llama recipe files for comparison
find modelopt_recipes -name "*llama*" -o -name "*Llama*" | head -20Repository: NVIDIA/Model-Optimizer
Length of output: 192
🏁 Script executed:
# Search for any documentation or README files about speculative decoding recipes
find modelopt_recipes/speculative_decoding -name "README*" -o -name "*.md" | head -10Repository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Check if there are any comments or documentation in the directory
cat modelopt_recipes/speculative_decoding/README.md 2>/dev/null || echo "No README found"Repository: NVIDIA/Model-Optimizer
Length of output: 81
🏁 Script executed:
# Search for any discussion or comments about trust_remote_code in the codebase
rg "trust_remote_code" --type python -B 3 -A 3 | head -50Repository: NVIDIA/Model-Optimizer
Length of output: 96
🏁 Script executed:
# Check the main training script to understand how the model is loaded
find . -name "main.py" -path "*/speculative*" | head -5Repository: NVIDIA/Model-Optimizer
Length of output: 105
🏁 Script executed:
# Search for any imports or usage of Llama model to understand if custom modeling_* is needed
rg "meta-llama|Llama.*3.*2" --type python | head -20Repository: NVIDIA/Model-Optimizer
Length of output: 96
🏁 Script executed:
# Read the main.py to understand how models are loaded
head -100 ./examples/speculative_decoding/main.pyRepository: NVIDIA/Model-Optimizer
Length of output: 3606
🏁 Script executed:
# Search for trust_remote_code usage without specifying file type
rg "trust_remote_code" -B 3 -A 3Repository: NVIDIA/Model-Optimizer
Length of output: 50379
🏁 Script executed:
# Look for any model initialization code related to Llama models
rg "from_pretrained|load.*model" ./examples/speculative_decoding/main.py -A 2 -B 2Repository: NVIDIA/Model-Optimizer
Length of output: 882
🏁 Script executed:
# Check if there are any comments or documentation explaining why trust_remote_code is needed
rg "Llama.*3.*2.*1B|trust_remote_code.*true" modelopt_recipes/ -B 5 -A 2 | head -80Repository: NVIDIA/Model-Optimizer
Length of output: 1993
🏁 Script executed:
# Search for any custom modeling code or patches that might explain the requirement
find modelopt -name "*llama*" -type f | grep -i modelRepository: NVIDIA/Model-Optimizer
Length of output: 110
🏁 Script executed:
# Check if there's any issue or discussion about Llama models requiring remote code
rg "Llama.*trust|custom.*code.*Llama" modelopt/ --type-not=binary -iRepository: NVIDIA/Model-Optimizer
Length of output: 96
🏁 Script executed:
# Check the mcore_llama.py file
head -50 modelopt/torch/export/plugins/mcore_llama.pyRepository: NVIDIA/Model-Optimizer
Length of output: 1876
🏁 Script executed:
# Search for any custom modeling code in modelopt that might relate to Llama
rg "class.*Llama|def.*llama" modelopt/ -i -A 2 | head -60Repository: NVIDIA/Model-Optimizer
Length of output: 1475
🏁 Script executed:
# Let's see if there are any issues or comments about Llama-3.2 specifically
rg "3\.2.*1B|meta-llama" . --type yaml -B 3 -A 3Repository: NVIDIA/Model-Optimizer
Length of output: 1817
Remove unnecessary trust_remote_code: true setting for the stock Llama model.
The meta-llama/Llama-3.2-1B model uses the standard transformers library implementation and does not require custom modeling code. Shipping the example with trust_remote_code: true unnecessarily enables arbitrary code execution from the model repository as the default behavior for users following the documented path.
Set trust_remote_code: false or remove the field entirely.
Suggested fix
model:
model_name_or_path: meta-llama/Llama-3.2-1B
- trust_remote_code: true
+ trust_remote_code: false📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| model: | |
| model_name_or_path: meta-llama/Llama-3.2-1B | |
| trust_remote_code: true | |
| model: | |
| model_name_or_path: meta-llama/Llama-3.2-1B | |
| trust_remote_code: false |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml` around lines
4 - 6, The YAML enables trust_remote_code for a stock Llama model; remove or set
trust_remote_code to false to avoid executing arbitrary repo code. Edit the
model block that contains model_name_or_path: meta-llama/Llama-3.2-1B and either
delete the trust_remote_code line or change it to trust_remote_code: false so
the pipeline uses the standard transformers implementation rather than allowing
remote code execution.
| "training": { | ||
| "output_dir": str(eagle_output_dir / f"eagle-tinyllama-cp{cp_size}"), | ||
| "num_train_epochs": 0.25, | ||
| "learning_rate": 1e-5, | ||
| "training_seq_len": 128, | ||
| "cp_size": cp_size, | ||
| "per_device_train_batch_size": 1, | ||
| }, | ||
| "eagle": { | ||
| "eagle_mix_hidden_states": mix_hidden_states, | ||
| "eagle_architecture_config": tiny_eagle_arch_config, | ||
| }, |
There was a problem hiding this comment.
Give each mix_hidden_states variant its own checkpoint directory.
Both (cp_size=1, mix_hidden_states=False) and (cp_size=1, mix_hidden_states=True) write to eagle-tinyllama-cp1 (same for cp_size=2), so the later case reuses or overwrites the earlier run instead of testing a clean configuration. That also makes test_resume_training depend on whichever variant ran last.
🧪 One simple way to isolate the matrix
"training": {
- "output_dir": str(eagle_output_dir / f"eagle-tinyllama-cp{cp_size}"),
+ "output_dir": str(
+ eagle_output_dir / f"eagle-tinyllama-cp{cp_size}-mix{int(mix_hidden_states)}"
+ ),
"num_train_epochs": 0.25,
"learning_rate": 1e-5,
"training_seq_len": 128,🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/examples/speculative_decoding/test_eagle.py` around lines 138 - 149,
The test currently writes both mix_hidden_states variants into the same output
directory causing runs to clobber each other; modify the training output_dir
construction (where eagle_output_dir / f"eagle-tinyllama-cp{cp_size}" is used)
to include the mix_hidden_states flag (e.g., append `_mix{mix_hidden_states}` or
similar) so each (cp_size, mix_hidden_states) combination gets a unique
checkpoint directory; update any references that assume the old path (e.g.,
test_resume_training) to use the new per-variant output_dir variable.
| "model": { | ||
| "model_name_or_path": str(model_path), | ||
| "trust_remote_code": True, | ||
| "use_fake_base_for_offline": use_fake_base, | ||
| }, |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's look at the file and surrounding context
cat -n tests/examples/speculative_decoding/test_eagle.py | sed -n '260,280p'Repository: NVIDIA/Model-Optimizer
Length of output: 897
🏁 Script executed:
# Check the full function to understand the context
cat -n tests/examples/speculative_decoding/test_eagle.py | sed -n '250,290p'Repository: NVIDIA/Model-Optimizer
Length of output: 1661
🏁 Script executed:
# Search for where this config is used and if trust_remote_code is parameterized
rg "trust_remote_code" tests/examples/speculative_decoding/test_eagle.py -B 5 -A 5Repository: NVIDIA/Model-Optimizer
Length of output: 1677
🏁 Script executed:
# Check if this config is written to a file and loaded elsewhere
rg -t py "offline_config|use_fake_base_for_offline" tests/examples/speculative_decoding/test_eagle.py -B 3 -A 3Repository: NVIDIA/Model-Optimizer
Length of output: 565
🏁 Script executed:
# Check what happens with the yaml_file after it's written
cat -n tests/examples/speculative_decoding/test_eagle.py | sed -n '289,310p'Repository: NVIDIA/Model-Optimizer
Length of output: 1162
🏁 Script executed:
# Look for test function signatures and parameters
rg -t py "^def test_|@pytest.mark.parametrize" tests/examples/speculative_decoding/test_eagle.py -A 5Repository: NVIDIA/Model-Optimizer
Length of output: 2786
🏁 Script executed:
# Check if there are any test parameters or fixtures for trust_remote_code
rg -t py "model_source|pytest.param" tests/examples/speculative_decoding/test_eagle.py -B 2 -A 2Repository: NVIDIA/Model-Optimizer
Length of output: 911
🏁 Script executed:
# Look at the broader test structure to find parametrization
head -100 tests/examples/speculative_decoding/test_eagle.pyRepository: NVIDIA/Model-Optimizer
Length of output: 3696
Parametrize trust_remote_code instead of hardcoding it to True.
The test writes trust_remote_code: True into every offline training YAML regardless of whether the model is local (tiny_llama) or remote (e.g., moonshotai/Kimi-K2.5). This violates the security guideline: remote models executing with trust_remote_code=True creates an RCE vector if the model source is untrusted.
Thread trust_remote_code as an explicit test parameter defaulting to False, and only set it to True for test cases that specifically require it. Also apply the same fix to line 243 where AutoConfig.from_pretrained() is called with hardcoded trust_remote_code=True.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/examples/speculative_decoding/test_eagle.py` around lines 269 - 273,
Parametrize the trust_remote_code flag instead of hardcoding True: add a test
parameter (default False) named trust_remote_code to the relevant test cases and
use it when writing the model YAML dictionary (replace the hardcoded
"trust_remote_code": True with "trust_remote_code": trust_remote_code) and when
calling AutoConfig.from_pretrained (replace the hardcoded trust_remote_code=True
with trust_remote_code=trust_remote_code); update only the specific test
invocations that require remote code execution to pass trust_remote_code=True.
Ensure the new parameter is included in the pytest parametrization for the test
function(s) that build the YAML/model config so local models keep
trust_remote_code=False while remote-model cases explicitly set it to True.
|
So does the yaml file encode all the information modelopt needs for the eagle3 training? |
Basically yes. The only exception is the accelerate configs (e.g. multinode settings). They need to be passed in addition to the yaml config, e.g.: I think they are orthogonal to the "recipe" and is more convenient to set in this way, since the node ip is often dynamic on slurm jobs. Do you think it's better to put it also in the yaml? |
What does this PR do?
Refactors EAGLE3 training to use a unified YAML-based config system.
Type of change: Refactor
Changes
launch_train.shnow accepts--config <yaml>(required) and--model <path>(optional override). All other settings live in YAML.modelopt_recipes/speculative_decoding/with__base__inheritance support.eagle_config.jsonandfsdp_config.json; architecture config is now nested undereagle.eagle_architecture_configin YAML.examples/speculative_decoding/README.mdfor the new interface.Usage
Summary by CodeRabbit
Documentation
New Features