As far as I understand, FlashInfer attention APIs are primarily designed for the standard prompt prefill case, where attention is computed over a contiguous suffix of the prompt (with KV cache already materialized for all preceding tokens).
I have a use case where KV cache is missing for multiple disjoint contiguous spans of a prompt, and I would like to compute attention only for those spans (and materialize/save their KV), while reusing the existing KV cache for the rest of the prompt.
Formalization
Given a prompt of length n with tokens:
t₀, t₁, …, tₙ₋₁
Define indices:
i₀, i₁, …, i₂ₖ₋₁ ∈ {0, …, n−1}
such that:
For every j ∈ {0, …, k−1}: i₂ⱼ ≤ i₂ⱼ₊₁
({t_x | x ∈ [i₂ⱼ,i₂ⱼ₊₁]} is a contiguous span of tokens)
For every j ∈ {1, …, k−1}: i₂ⱼ₋₁ + 1 < i₂ⱼ
(spans are ordered and disjoint, and consecutive spans are separated by at least one token that already has KV cache)
For every x ∈ {0, …, n−1}: tₓ does not have KV cache iff x ∈ [i₂ⱼ, i₂ⱼ₊₁] for some j ∈ {0, …, k−1}
In words:
There are k disjoint contiguous spans of tokens {t_x | x ∈ [i₂ⱼ,i₂ⱼ₊₁]}, j ∈ {0, …, k−1} in the prompt for which KV cache is missing. Attention needs to be computed for all tokens in those spans, and their KV cache should be materialized and saved. All other tokens already have valid KV cache entries and should be reused.
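To make the index convention concrete, here is a minimal pure-Python sketch (standard library only) that turns the flat list i₀, …, i₂ₖ₋₁ into inclusive (start, end) spans, checks the constraints above, and marks which positions lack KV cache. The names `parse_spans` and `missing_kv_mask` are hypothetical helpers for illustration, not part of FlashInfer.

```python
from typing import List, Tuple


def parse_spans(indices: List[int], n: int) -> List[Tuple[int, int]]:
    """Turn [i0, i1, ..., i_{2k-1}] into k inclusive spans (start, end) and
    check the constraints of the formalization for a prompt of length n."""
    if len(indices) % 2 != 0:
        raise ValueError("expected an even number of indices (2k)")
    spans = [(indices[2 * j], indices[2 * j + 1]) for j in range(len(indices) // 2)]
    for j, (start, end) in enumerate(spans):
        # i_{2j} <= i_{2j+1}, both inside [0, n-1]
        if not (0 <= start <= end <= n - 1):
            raise ValueError(f"span {j} = [{start}, {end}] is not valid for n = {n}")
        # i_{2j-1} + 1 < i_{2j}: at least one cached token between consecutive spans
        if j > 0 and not (spans[j - 1][1] + 1 < start):
            raise ValueError(f"span {j} does not start after a gap following span {j - 1}")
    return spans


def missing_kv_mask(spans: List[Tuple[int, int]], n: int) -> List[bool]:
    """missing[x] is True iff token x has no KV cache, i.e. x lies in some span."""
    missing = [False] * n
    for start, end in spans:
        for x in range(start, end + 1):
            missing[x] = True
    return missing
```

For example, with n = 10 and indices [2, 3, 6, 8], the spans are [(2, 3), (6, 8)] and positions 2, 3, 6, 7 and 8 are the ones missing KV.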
What I’m Looking For
An attention backend that can efficiently handle this segmented prefill case, ideally without invoking a separate prefill kernel per span.
Baselines / Possible Approaches
One prefill per span
Invoke an existing prefill kernel independently for each missing span {t_x | x ∈ [i₂ⱼ, i₂ⱼ₊₁]}, processing the spans in order and using the KV cache for the prefix {t_x | x ∈ [0, i₂ⱼ−1]} (which by then also contains the KV materialized for earlier spans). This is correct, but it may incur overhead from multiple kernel launches and repeated setup, especially when k is large or the spans are small.
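The per-span baseline can be sketched in pure Python as a correctness reference, assuming a single attention head, toy list-of-float vectors, and an in-memory KV cache where `None` marks a missing entry. All names here (`attend`, `prefill_per_span`) are hypothetical and unrelated to FlashInfer's API. Spans are processed in order so that each token in span j sees the full prefix [0, x] with no holes.

```python
import math
from typing import Dict, List, Optional, Tuple

Vec = List[float]


def attend(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Naive single-head scaled dot-product attention of one query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def prefill_per_span(
    q_all: List[Vec],
    k_all: List[Vec],
    v_all: List[Vec],
    k_cache: List[Optional[Vec]],
    v_cache: List[Optional[Vec]],
    spans: List[Tuple[int, int]],
) -> Dict[int, Vec]:
    """Baseline: one causal prefill per missing span, spans processed in order.

    q_all/k_all/v_all stand in for the projections of span tokens (only the
    span entries are read). k_cache/v_cache hold None where KV is missing
    and are filled in place. Returns {prompt position: attention output}.
    """
    out: Dict[int, Vec] = {}
    for start, end in spans:
        # Materialize (save) this span's KV so its tokens can attend to each other.
        for x in range(start, end + 1):
            k_cache[x], v_cache[x] = k_all[x], v_all[x]
        # Token x attends to the whole prefix [0, x]. Earlier spans were filled by
        # previous iterations, so the prefix contains no missing entries.
        for x in range(start, end + 1):
            out[x] = attend(q_all[x], k_cache[: x + 1], v_cache[: x + 1])
    return out
```

Each loop iteration has the shape of an ordinary prefill with qo_len = i₂ⱼ₊₁ − i₂ⱼ + 1 queries over kv_len = i₂ⱼ₊₁ + 1 keys, which is why it maps onto existing kernels; the cost is one such call per span.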
Query compaction with positional masking
Pack all query tokens from the missing spans into a contiguous query buffer, and provide an attention mask so that each query attends as if it were at its original position in the prompt, i.e. the query from position x attends to exactly the KV positions 0 through x. This would allow reuse of existing kernels while preserving correctness.
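A minimal sketch of the mask that query compaction needs, again in pure Python with hypothetical helper names and a dense boolean layout chosen for readability (a real backend's mask format may differ). It assumes the KV for the spans has already been written into the n-slot cache before the attention call.

```python
from typing import List, Tuple


def compact_queries(spans: List[Tuple[int, int]]) -> List[int]:
    """Original prompt position of each packed query row, in span order."""
    return [x for start, end in spans for x in range(start, end + 1)]


def positional_causal_mask(q_positions: List[int], n: int) -> List[List[bool]]:
    """mask[r][c] is True iff packed query row r may attend to KV position c.

    Row r behaves like the token at original position q_positions[r] in a
    full causal prefill: it sees exactly KV positions 0..q_positions[r].
    """
    return [[c <= p for c in range(n)] for p in q_positions]


def spans_as_batch(spans: List[Tuple[int, int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical batched description: rows qo_indptr[j]:qo_indptr[j+1] of the
    packed buffer belong to span j, which needs the first kv_lens[j] KV entries."""
    qo_indptr = [0]
    kv_lens = []
    for start, end in spans:
        qo_indptr.append(qo_indptr[-1] + (end - start + 1))
        kv_lens.append(end + 1)
    return qo_indptr, kv_lens
```

Every row of this mask is a prefix of the KV positions, so it is fully determined by one KV length per query. Within span j it reduces to a causal mask whose last query row lines up with the last of kv_len = i₂ⱼ₊₁ + 1 KV entries (row i sees the first kv_len − qo_len + i + 1 entries). That is what `spans_as_batch` records, and it suggests the k spans could be described as k requests over prefixes of the same KV cache instead of with a dense m × n mask.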
Does FlashInfer currently support an efficient mechanism for this “segmented prefill” / multiple disjoint span pattern?
Thanks in advance.