[Neuron] Static-shape generation loop for compilation-friendly inference #44742

@dacorvo

Description

Context

GenerationMixin._sample grows input_ids, attention_mask, position_ids, and cache_position via torch.cat on every decode step. This is problematic for any backend where dynamic tensor shapes carry a cost:

  • XLA/torch.compile backends: Static shapes are required for graph caching — dynamic shapes cause retracing. For instance, on Neuron (Trainium/Inferentia), each new shape triggers a full NEFF recompilation (2–60s per step), making generation unusable without workarounds
  • CUDA/AMD/XPU: In theory, dynamic allocations could prevent optimal memory planning and fragment the allocator. In practice, benchmarks on CUDA (A10G, Llama-3.1-8B, torch 2.10, 256 new tokens) show no measurable difference in throughput (~32 tok/s) or peak memory (15.4 GB) between the dynamic and static-shape decode loops — the torch.compile inductor backend already optimizes the torch.cat ops effectively

StaticCache already exists in transformers and implies "I want fixed-size tensors" — but _sample completely ignores the cache type. All tensor management is unconditionally dynamic regardless of whether a StaticCache or DynamicCache is used.

This issue tracks adding StaticCache-aware static-shape generation to _sample, benefiting all devices that use static shapes.


Current state

The static-shape generation approach has been validated on Neuron using model.generate(custom_generate=neuron_sample) — the official generate() extension point for pluggable decoding. generate() handles input preparation, cache setup, logits processors, and stopping criteria; neuron_sample replaces _sample with a static-shape decode loop (~600 lines). The changes below would make this custom callable unnecessary by adding static-shape support to the default _sample.

Performance numbers measured on Llama-3.1-8B at TP=2 on trn2.3xlarge, 1024-token prompt, 128 new tokens.


What must change for static shapes

The default _sample (183 lines) has these dynamic-shape points:

| Location | What grows | Static equivalent |
| --- | --- | --- |
| `_sample` L2812 | `input_ids = torch.cat([input_ids, next_tokens], dim=-1)` | Write to pre-allocated buffer at `current_position` |
| `_update_model_kwargs` L929 | `attention_mask = torch.cat([mask, ones], dim=-1)` | Fixed-shape 4D mask, in-place `scatter_` |
| `_update_model_kwargs` L923 | `position_ids = torch.cat(...)` | Computed explicitly from `cache_position` |
| `_update_model_kwargs` L940 | `cache_position = torch.cat(...)` | Scalar increment |
| `_sample` L2755 | `while self._has_unfinished_sequences(...)` | Fixed-count `for` loop (`max_new_tokens` steps) |
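A minimal sketch of the contrast between the two loop styles (shapes and token values here are made up for illustration):

```python
import torch

batch, prompt_len, max_new = 2, 4, 3

# Dynamic (current _sample): input_ids grows via torch.cat, so its shape
# changes on every decode step — each new shape is a new graph for XLA/Neuron.
input_ids = torch.zeros(batch, prompt_len, dtype=torch.long)
for step in range(max_new):
    next_tokens = torch.full((batch, 1), 100 + step, dtype=torch.long)
    input_ids = torch.cat([input_ids, next_tokens], dim=-1)

# Static: write each token into a pre-allocated buffer at the current position;
# every tensor keeps the same shape for the whole loop.
output_ids = torch.zeros(batch, prompt_len + max_new, dtype=torch.long)
for step in range(max_new):
    next_tokens = torch.full((batch,), 100 + step, dtype=torch.long)
    output_ids[:, prompt_len + step] = next_tokens
```

Both loops produce the same tokens; only the static variant keeps shapes constant across steps.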

The natural trigger for the static path is isinstance(model_kwargs["past_key_values"], StaticCache) — no new config flag needed.

The table above covers the decode loop. Five other parts of the generation flow also need static-shape changes:

Auto-compilation gate missing Neuron (critical)

_valid_auto_compile_criteria() (generation/utils.py L2021) gates auto-compilation on device.type in ["cuda", "xpu"]. Neuron is excluded, so torch.compile never triggers automatically — even when StaticCache is used (which sets is_compileable = True). Fix: add "neuron" to the valid hardware list (or derive it from StaticCache presence).
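A hedged sketch of what the gate change could look like (the real `_valid_auto_compile_criteria` signature and call sites differ; this only isolates the device-type check):

```python
# "neuron" is the proposed addition; deriving eligibility from the cache's
# is_compileable flag is the alternative mentioned above.
VALID_AUTO_COMPILE_HARDWARE = ("cuda", "xpu", "neuron")

def valid_auto_compile_criteria(device_type: str, cache_is_compileable: bool) -> bool:
    # Auto-compilation fires only when the cache is compileable AND the
    # device is on the allow-list.
    return cache_is_compileable and device_type in VALID_AUTO_COMPILE_HARDWARE
```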

Attention dispatch decode degradation (critical)

Conditional dispatch logic (kernel eligibility check, GQA handling) gets traced into the Neuron compilation graph even when the custom kernel path is never executed at decode time. This causes ~30% decode throughput loss (29.5 → 22.8 tok/s). Needs design discussion — possible approaches: compile-time specialization for prefill vs decode attention, or allow registering separate attention implementations for each phase.

Static-shape chunked prefill (critical)

GenerationMixin._prefill grows the 2D attention mask every iteration (attention_mask[:, :current_length]), changing shapes per chunk and triggering a fresh compilation per chunk. A static-shape chunked prefill path using a fixed-shape 4D causal mask (batch, 1, chunk_size, max_cache_len) updated in-place would let all chunks share one compiled graph.

Impact: 55% prefill time savings (1.38s on 5.56s total at 1024 tokens). Also reduces peak HBM by ~4×.
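A toy sketch of the fixed-shape 4D chunk mask (all sizes hypothetical): every chunk reuses the same `(batch, 1, chunk_size, max_cache_len)` tensor, updated in place, so shapes never change across chunks:

```python
import torch

batch, chunk_size, max_cache_len = 1, 4, 16
# One mask tensor shared by all prefill chunks; -inf = masked, 0.0 = attend.
mask = torch.full((batch, 1, chunk_size, max_cache_len), float("-inf"))

def prepare_chunk_mask(mask: torch.Tensor, chunk_start: int) -> torch.Tensor:
    """Write the causal pattern for the chunk starting at `chunk_start`, in place."""
    mask.fill_(float("-inf"))
    for i in range(mask.shape[2]):
        # Query at global position chunk_start + i attends to keys 0..chunk_start+i.
        mask[:, :, i, : chunk_start + i + 1] = 0.0
    return mask

prepare_chunk_mask(mask, chunk_start=4)  # second chunk — same tensor, same shape
```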

Static padding alignment with right-padding (low)

Variable prompt padding length triggers recompilation per unique length. Inputs should be padded to a static, compilation-friendly boundary (e.g., max_seq_len - max_new_tokens).

Right-padding is strongly preferred for static-shape generation: with a pre-allocated output buffer output_ids[batch, max_length], right-padded prompts place the actual tokens at positions 0..prompt_len-1, so each newly generated token is simply written at position prompt_len + step — a single indexed assignment. With left-padding, the prompt occupies positions pad_len..max_length-1, so new tokens must either be inserted by shifting all prompt tokens left, or the write position must account for a per-sequence offset, complicating the static loop. If inputs arrive left-padded, converting to right-padding before entering the static path avoids this complexity.
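A small sketch of the right-padded write path described above (pad id, lengths, and token values are made up):

```python
import torch

pad_id, max_length = 0, 8
prompt = torch.tensor([[11, 12, 13]])
prompt_len = prompt.shape[1]

# Right-padded pre-allocated buffer: real tokens at 0..prompt_len-1, pads after.
output_ids = torch.full((1, max_length), pad_id)
output_ids[:, :prompt_len] = prompt

# Each decode step is a single indexed assignment at prompt_len + step —
# no shifting, no per-sequence offsets.
for step, tok in enumerate([21, 22]):
    output_ids[0, prompt_len + step] = tok
```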

Prefill compilation (medium)

torch.compile traces only the decode forward pass. The prefill path runs as individual eager operations. Structuring the prefill forward for separate torch.compile reduces prefill time from 3.44s → 1.39s (2.5× faster), total generation 7.85s → 6.40s (18.5% faster).
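A minimal illustration of compiling a prefill-shaped forward separately (the stand-in function and the `"eager"` backend keep the sketch cheap and portable; a Neuron deployment would use its own `torch.compile` backend and the model's real forward):

```python
import torch

def prefill_forward(input_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for one fixed-shape prefill step of the model's forward.
    return input_ids.float().cumsum(dim=-1)

# dynamic=False asks the compiler to specialize on the static prefill shape.
compiled_prefill = torch.compile(prefill_forward, backend="eager", dynamic=False)
out = compiled_prefill(torch.ones(1, 4))
```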


Option A: New _static_sample method (recommended)

Add a separate _static_sample method, dispatched from generate() when StaticCache is detected.

Dispatch: In generate(), after GENERATION_MODES_MAPPING resolves to "_sample":

```python
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
    cache = model_kwargs.get("past_key_values")
    if isinstance(cache, StaticCache):
        decoding_method = getattr(type(self), "_static_sample")
```

Key differences from _sample:

  1. Pre-allocate output_ids = torch.full((batch, max_length), pad_token_id)
  2. Build 4D decode mask (batch, 1, 1, max_cache_len) — update via scatter_
  3. Fixed-count for i in range(max_new_tokens) loop
  4. Skip _update_model_kwargs_for_generation — update cache_position and position_ids directly
  5. Same logits_processor / stopping_criteria / streamer / return_dict logic as _sample
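Putting the five points together, a toy greedy skeleton of what `_static_sample` could look like (the `forward` callable, mask handling, and kwargs plumbing are stand-ins, not the transformers internals):

```python
import torch

def static_sample(forward, input_ids, max_new_tokens, pad_token_id, max_cache_len):
    batch, prompt_len = input_ids.shape
    max_length = prompt_len + max_new_tokens
    # 1. Pre-allocated output buffer.
    output_ids = torch.full((batch, max_length), pad_token_id, dtype=torch.long)
    output_ids[:, :prompt_len] = input_ids
    # 2. Fixed-shape 4D decode mask, updated in place via scatter_.
    mask = torch.full((batch, 1, 1, max_cache_len), float("-inf"))
    mask[..., :prompt_len] = 0.0
    cache_position = torch.tensor([prompt_len - 1])
    # 3. Fixed-count loop instead of _has_unfinished_sequences.
    for step in range(max_new_tokens):
        logits = forward(output_ids, mask, cache_position)  # (batch, vocab)
        next_tokens = logits.argmax(dim=-1)  # greedy; sampling would go here
        output_ids[:, prompt_len + step] = next_tokens
        # 4. Update cache_position / mask directly (scalar increment + scatter_).
        cache_position = cache_position + 1
        index = cache_position.view(1, 1, 1, 1).expand(batch, 1, 1, 1)
        mask.scatter_(-1, index, 0.0)
    return output_ids
```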

Pros:

  • Zero changes to _sample — existing tests unaffected
  • Clean separation of concerns; easy to review
  • Can be iterated without risking regressions in the default path

Cons:

  • ~60% code duplication with _sample (logits processing, token selection, output accumulation)
  • Must maintain both methods when upstream changes _sample

Non-regression tests:

  • All existing _sample tests pass unchanged (no modifications to _sample)
  • ~8-10 new test methods for _static_sample (basic, dict output, early EOS, greedy, logits processor, streamer)
  • Add cache_implementation="static" sweep to GenerationTesterMixin

Option B: Conditional paths inside _sample

Add isinstance(past_key_values, StaticCache) checks at the 5 growth points in _sample and _update_model_kwargs_for_generation.
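A self-contained sketch of one such branch (the `StaticCache` isinstance check is replaced by a boolean flag so the snippet runs standalone; variable names are illustrative):

```python
import torch

def append_next_tokens(input_ids, next_tokens, current_position, static_cache):
    if static_cache:
        # Static path: in-place write into the pre-allocated buffer — shape is stable.
        input_ids[:, current_position] = next_tokens
        return input_ids
    # Dynamic path: today's behavior, a new tensor (and shape) every step.
    return torch.cat([input_ids, next_tokens[:, None]], dim=-1)
```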

Pros:

  • No code duplication — single method, shared logic
  • Changes to _sample automatically apply to both paths
  • Smaller diff (~45 lines vs ~120 lines)

Cons:

  • Modifies the critical _sample path — every generation model runs through this
  • _update_model_kwargs_for_generation is used by ALL decoding methods (_beam_search, _assisted_decoding) — conditional there affects everything
  • Risk of subtle regressions from touching hot code

Non-regression tests:

  • ALL existing generation tests must be re-run, since _sample itself is modified
  • 150+ model test suites × ~6 generation tests each = ~900 test runs
  • Changes to _update_model_kwargs_for_generation also require beam search regression tests

Current workaround

~600 lines of custom generation logic (_neuron_chunked_prefill + neuron_sample) passed via model.generate(custom_generate=neuron_sample). Works but duplicates upstream logic and cannot benefit from new features added to the standard path.
