## Context
`GenerationMixin._sample` grows `input_ids`, `attention_mask`, `position_ids`, and `cache_position` via `torch.cat` on every decode step. This is problematic for any backend where dynamic tensor shapes carry a cost:
- XLA/torch.compile backends: Static shapes are required for graph caching — dynamic shapes cause retracing. For instance, on Neuron (Trainium/Inferentia), each new shape triggers a full NEFF recompilation (2–60s per step), making generation unusable without workarounds
- CUDA/AMD/XPU: In theory, dynamic allocations could prevent optimal memory planning and fragment the allocator. In practice, benchmarks on CUDA (A10G, Llama-3.1-8B, torch 2.10, 256 new tokens) show no measurable difference in throughput (~32 tok/s) or peak memory (15.4 GB) between the dynamic and static-shape decode loops; the `torch.compile` inductor backend already optimizes the `torch.cat` ops effectively
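The cost model behind the first bullet can be made concrete without any framework: a shape-caching backend compiles one graph per distinct input shape it sees, so `torch.cat`-grown tensors produce one compilation per generated token. A toy sketch (illustrative only, not the real `_sample` internals):

```python
# Toy model of graph caching: each distinct input shape seen by a
# shape-caching backend (XLA, Neuron) is one compilation. The numbers are
# illustrative; real decode inputs have more dimensions than seq_len.
def decode_input_shapes(prompt_len: int, max_new_tokens: int, static: bool) -> set:
    """Collect the sequence lengths the decode loop presents to the backend."""
    shapes = set()
    seq_len = prompt_len
    for _ in range(max_new_tokens):
        if static:
            # Static path: tokens are written into a pre-allocated buffer,
            # so every decode step presents the same fixed shape.
            shapes.add(prompt_len + max_new_tokens)
        else:
            # Dynamic path: torch.cat grows the tensors by one each step.
            seq_len += 1
            shapes.add(seq_len)
    return shapes
```

With the issue's Neuron numbers (1024-token prompt, 128 new tokens), the dynamic path presents 128 distinct shapes (one NEFF recompilation each), the static path exactly one.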
`StaticCache` already exists in transformers and implies "I want fixed-size tensors", but `_sample` completely ignores the cache type. All tensor management is unconditionally dynamic regardless of whether a `StaticCache` or `DynamicCache` is used.
This issue tracks adding `StaticCache`-aware static-shape generation to `_sample`, benefiting all devices that use static shapes.
## Current state
The static-shape generation approach has been validated on Neuron using `model.generate(custom_generate=neuron_sample)`, the official `generate()` extension point for pluggable decoding. `generate()` handles input preparation, cache setup, logits processors, and stopping criteria; `neuron_sample` replaces `_sample` with a static-shape decode loop (~600 lines). The changes below would make this custom callable unnecessary by adding static-shape support to the default `_sample`.
Performance numbers measured on Llama-3.1-8B at TP=2 on trn2.3xlarge, 1024-token prompt, 128 new tokens.
## What must change for static shapes
The default `_sample` (183 lines) has these dynamic-shape points:

| Location | What grows | Static equivalent |
|---|---|---|
| `_sample` L2812 | `input_ids = torch.cat([input_ids, next_tokens], dim=-1)` | Write to pre-allocated buffer at `current_position` |
| `_update_model_kwargs` L929 | `attention_mask = torch.cat([mask, ones], dim=-1)` | Fixed-shape 4D mask, in-place `scatter_` |
| `_update_model_kwargs` L923 | `position_ids = torch.cat(...)` | Explicit from `cache_position` |
| `_update_model_kwargs` L940 | `cache_position = torch.cat(...)` | Scalar increment |
| `_sample` L2755 | `while self._has_unfinished_sequences(...)` | Fixed-count `for` loop (`max_new_tokens` steps) |
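The right-hand column can be illustrated with a minimal framework-free sketch, using plain Python lists in place of tensors (`current_position` comes from the table; everything else here is hypothetical, not the real `_sample` internals):

```python
# One decode step using only fixed-shape, in-place updates. Lists stand in
# for pre-allocated tensors; in the real path these would be tensor writes
# and an in-place scatter_ on a fixed-shape 4D mask.
def static_decode_step(output_ids, attention_mask, cache_position, next_token):
    current_position = cache_position
    output_ids[current_position] = next_token   # buffer write, no torch.cat
    attention_mask[current_position] = 1        # in-place update, shape fixed
    position_ids = current_position             # derived from cache_position
    cache_position = current_position + 1       # scalar increment
    return cache_position, position_ids

max_length = 8
output_ids = [7, 8, 9, 0, 0, 0, 0, 0]        # prompt at 0..2, pad_token_id 0
attention_mask = [1, 1, 1, 0, 0, 0, 0, 0]    # prompt positions already visible
cache_position = 3
for token in (11, 12, 13):                   # pretend these were sampled
    cache_position, _ = static_decode_step(
        output_ids, attention_mask, cache_position, token
    )
```

Every update is an indexed write or a scalar increment, so no tensor ever changes shape across steps.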
The natural trigger for the static path is `isinstance(model_kwargs["past_key_values"], StaticCache)`; no new config flag is needed.
The table above covers the decode loop. Four other parts of the generation flow also need static-shape changes:
### Auto-compilation gate missing Neuron (critical)
`_valid_auto_compile_criteria()` (generation/utils.py L2021) gates auto-compilation on `device.type in ["cuda", "xpu"]`. Neuron is excluded, so `torch.compile` never triggers automatically, even when `StaticCache` is used (which sets `is_compileable = True`). Fix: add `"neuron"` to the valid hardware list (or derive it from `StaticCache` presence).
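A hedged sketch of the gate change (the real `_valid_auto_compile_criteria` takes different arguments and performs additional checks; this toy version isolates only the device-list fix described above):

```python
# Hypothetical reduction of _valid_auto_compile_criteria to the device gate.
# Upstream currently allows ["cuda", "xpu"]; the proposed fix adds "neuron"
# (an alternative is to derive eligibility from StaticCache presence alone).
def valid_auto_compile(device_type: str, cache_is_compileable: bool) -> bool:
    return cache_is_compileable and device_type in ("cuda", "xpu", "neuron")
```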
### Attention dispatch decode degradation (critical)
Conditional dispatch logic (kernel eligibility check, GQA handling) gets traced into the Neuron compilation graph even when the custom kernel path is never executed at decode time. This causes a ~30% decode throughput loss (29.5 → 22.8 tok/s). Needs design discussion; possible approaches: compile-time specialization for prefill vs decode attention, or allowing separate attention implementations to be registered for each phase.
### Static-shape chunked prefill (critical)
`GenerationMixin._prefill` grows the 2D attention mask every iteration (`attention_mask[:, :current_length]`), changing the mask shape on every chunk and triggering a fresh compilation each time. A static-shape chunked prefill path using a fixed-shape 4D causal mask `(batch, 1, chunk_size, max_cache_len)`, updated in place, would let all chunks share one compiled graph.
Impact: 55% prefill time savings (1.38s on 5.56s total at 1024 tokens). Also reduces peak HBM by ~4×.
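The per-chunk recompilation can again be modeled without a framework: the dynamic path slices the mask to `current_length`, so each chunk presents a new shape, while a fixed-shape 4D mask updated in place presents one shape for all chunks. An illustrative sketch (hypothetical helper, not transformers code):

```python
# Count the distinct mask shapes a shape-caching compiler sees during
# chunked prefill (batch size 1 assumed for brevity).
def prefill_mask_shapes(prompt_len, chunk_size, max_cache_len, static):
    shapes = set()
    for start in range(0, prompt_len, chunk_size):
        current_length = min(start + chunk_size, prompt_len)
        if static:
            # Fixed-shape 4D causal mask, updated in place each chunk.
            shapes.add((1, 1, chunk_size, max_cache_len))
        else:
            # attention_mask[:, :current_length] grows with each chunk.
            shapes.add((1, current_length))
    return shapes
```

For a 1024-token prompt split into 128-token chunks, the dynamic path yields 8 distinct shapes (8 compilations); the static path yields 1.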
### Static padding alignment with right-padding (low)
Variable prompt padding length triggers recompilation per unique length. Inputs should be padded to a static, compilation-friendly boundary (e.g., `max_seq_len - max_new_tokens`).
Right-padding is strongly preferred for static-shape generation: with a pre-allocated output buffer `output_ids[batch, max_length]`, right-padded prompts place the actual tokens at positions `0..prompt_len-1`, so each newly generated token is simply written at position `prompt_len + step`, a single indexed assignment. With left-padding, the prompt occupies positions `pad_len..max_length-1`, so new tokens must either be inserted by shifting all prompt tokens left, or the write position must account for a per-sequence offset, complicating the static loop. If inputs arrive left-padded, converting to right-padding before entering the static path avoids this complexity.
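The write-position arithmetic described above, as a toy helper (hypothetical, not a transformers API):

```python
# With right-padding, the write target for generation step `step` is a pure
# function of the per-sequence prompt length: a single indexed assignment
# into the pre-allocated buffer.
def right_padded_write_position(prompt_len: int, step: int) -> int:
    return prompt_len + step

# With left-padding the prompt ends at max_length - 1, so there is no free
# slot to the right; the static loop would have to shift the prompt left or
# carry a per-sequence offset, which is exactly the complexity that
# converting to right-padding avoids.
```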
### Prefill compilation (medium)
`torch.compile` traces only the decode forward pass; the prefill path runs as individual eager operations. Structuring the prefill forward for a separate `torch.compile` reduces prefill time from 3.44s → 1.39s (2.5× faster) and total generation from 7.85s → 6.40s (18.5% faster).
## Option A: New `_static_sample` method (recommended)
Add a separate `_static_sample` method, dispatched from `generate()` when a `StaticCache` is detected.

Dispatch: in `generate()`, after `GENERATION_MODES_MAPPING` resolves to `"_sample"`:

```python
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
    cache = model_kwargs.get("past_key_values")
    if isinstance(cache, StaticCache):
        decoding_method = getattr(type(self), "_static_sample")
```

Key differences from `_sample`:
- Pre-allocate `output_ids = torch.full((batch, max_length), pad_token_id)`
- Build a 4D decode mask `(batch, 1, 1, max_cache_len)`; update via `scatter_`
- Fixed-count `for i in range(max_new_tokens)` loop
- Skip `_update_model_kwargs_for_generation`; update `cache_position` and `position_ids` directly
- Same logits_processor / stopping_criteria / streamer / return_dict logic as `_sample`
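A framework-free skeleton of the loop structure these bullets describe (hypothetical; the real method would operate on tensors and reuse the existing logits_processor / stopping_criteria objects). Note that a fixed-count loop cannot `break` on EOS without changing the trace, so finished sequences keep writing `pad_token_id`:

```python
# Skeleton of a static-shape sampling loop over a pre-allocated buffer.
# next_token_fn stands in for the model forward + logits processing + sampling.
def static_sample(next_token_fn, prompt, max_length, pad_token_id, eos_token_id):
    output_ids = prompt + [pad_token_id] * (max_length - len(prompt))
    cache_position = len(prompt)
    finished = False
    # Fixed-count loop: the iteration count never depends on the sampled
    # tokens, so a shape-caching backend compiles the step exactly once.
    for _ in range(max_length - len(prompt)):
        token = next_token_fn(output_ids, cache_position)
        if finished:
            token = pad_token_id        # masked write instead of breaking
        output_ids[cache_position] = token
        cache_position += 1
        finished = finished or token == eos_token_id
    return output_ids
```

In tensor form the `if finished` branch would be a `torch.where` over a per-sequence `finished` mask, mirroring how `_sample` handles `unfinished_sequences` today.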
Pros:
- Zero changes to `_sample`; existing tests unaffected
- Clean separation of concerns; easy to review
- Can be iterated without risking regressions in the default path

Cons:
- ~60% code duplication with `_sample` (logits processing, token selection, output accumulation)
- Must maintain both methods when upstream changes `_sample`
Non-regression tests:
- All existing `_sample` tests pass unchanged (no modifications to `_sample`)
- ~8-10 new test methods for `_static_sample` (basic, dict output, early EOS, greedy, logits processor, streamer)
- Add a `cache_implementation="static"` sweep to `GenerationTesterMixin`
## Option B: Conditional paths inside `_sample`
Add `isinstance(past_key_values, StaticCache)` checks at the 5 growth points in `_sample` and `_update_model_kwargs_for_generation`.
Pros:
- No code duplication; single method, shared logic
- Changes to `_sample` automatically apply to both paths
- Smaller diff (~45 lines vs ~120 lines)
Cons:
- Modifies the critical `_sample` path; every generation model runs through it
- `_update_model_kwargs_for_generation` is used by ALL decoding methods (`_beam_search`, `_assisted_decoding`); a conditional there affects everything
- Risk of subtle regressions from touching hot code
Non-regression tests:
- ALL existing generation tests must be re-run, since `_sample` is modified
- 150+ model test suites × ~6 generation tests each = ~900 test runs
- Changes to `_update_model_kwargs_for_generation` also require beam search regression tests
## Current workaround
~600 lines of custom generation logic (`_neuron_chunked_prefill` + `neuron_sample`) passed via `model.generate(custom_generate=neuron_sample)`. It works, but duplicates upstream logic and cannot benefit from new features added to the standard path.
## Related
- Non-generation Neuron changes: [Neuron] Improve transformers compatibility with AWS Neuron devices #44741
- Auto-StaticCache on Neuron device: [Neuron] Auto-select StaticCache when device is Neuron #44748 (depends on this issue)