Skip to content

Commit 5e7dad2

Browse files
fix(paged_attention): fix O(N^2) thrashing + FCFS priority in PA scheduler
Reapply upstream fixes from PRs EricLBuehler#2031/EricLBuehler#2034: fix quadratic scheduling complexity when sequences are waiting, and add FCFS priority ordering to prevent starvation.
1 parent ad3cb23 commit 5e7dad2

1 file changed

Lines changed: 41 additions & 24 deletions

File tree

mistralrs-core/src/paged_attention/scheduler.rs

Lines changed: 41 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -150,10 +150,16 @@ impl PagedAttentionScheduler {
150150
return buckets.into_values().next().unwrap();
151151
}
152152

153-
// Find the bucket with the shortest sequence length
153+
// Find the bucket containing the OLDEST sequence (lowest timestamp) to ensure FCFS priority
154154
let min_key = *buckets
155-
.keys()
156-
.min_by_key(|(len, _, _)| *len)
155+
.iter()
156+
.min_by_key(|(_, seqs)| {
157+
seqs.iter()
158+
.map(|seq| get_mut_arcmutex!(seq).timestamp())
159+
.min()
160+
.unwrap()
161+
})
162+
.map(|(key, _)| key)
157163
.expect("No sequence buckets");
158164

159165
let selected = buckets.remove(&min_key).unwrap();
@@ -179,6 +185,8 @@ impl PagedAttentionScheduler {
179185
pub fn schedule(&mut self, logger: &IntervalLogger) -> PagedAttentionSchedulerOutput {
180186
let mut scheduled: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
181187
let mut for_waiting_again: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
188+
let mut batched_prompt_tokens = 0;
189+
let mut batched_sequences = 0;
182190
while !self.waiting.is_empty() {
183191
let mut did_ignore = false;
184192
let seq = self.waiting.front().unwrap().clone();
@@ -192,8 +200,16 @@ impl PagedAttentionScheduler {
192200
let tokens = seq_guard.get_toks().to_vec();
193201
let num_tokens = tokens.len();
194202
let mm_features = seq_guard.mm_features().to_vec();
203+
let num_new_tokens = num_tokens.saturating_sub(seq_guard.prefix_cache_len());
195204
drop(seq_guard);
196205

206+
// Halt batch mapping if context size approaches engine limits.
207+
if (batched_prompt_tokens + num_new_tokens > 16384 || batched_sequences >= 10) && batched_sequences > 0 {
208+
break;
209+
}
210+
batched_prompt_tokens += num_new_tokens;
211+
batched_sequences += 1;
212+
197213
// Compute block hashes for prefix cache lookup
198214
self.ensure_block_hashes(seq_id, &tokens, &mm_features);
199215
let block_hashes = self
@@ -232,8 +248,9 @@ impl PagedAttentionScheduler {
232248
*count += 1;
233249

234250
if *count > WAITING_TIMEOUT {
235-
// Try to preempt a running sequence
236-
if let Some(seq_to_preempt) = self.running.pop_back() {
251+
// Continuously preempt running sequences until allocation succeeds
252+
let mut success = false;
253+
while let Some(seq_to_preempt) = self.running.pop_back() {
237254
self._preempt(seq_to_preempt);
238255

239256
// Retry allocation
@@ -242,25 +259,28 @@ impl PagedAttentionScheduler {
242259
kv_mgr.allocate_slots(seq_id, num_tokens, &computed.block_ids);
243260
drop(kv_mgr);
244261

245-
if retry.is_none() {
262+
if retry.is_some() {
263+
self.waiting_counts.remove(&seq_id);
264+
success = true;
265+
break;
266+
}
267+
}
268+
269+
if !success {
270+
// Even after emptying `running`, it doesn't fit.
271+
if self.running.is_empty() {
246272
let id = seq_id;
247273
warn!(
248-
"Sequence {id} with length of {num_tokens} tokens still exceeds KV cache size \
249-
even after evicting another sequence.",
274+
"Sequence {id} with length of {num_tokens} tokens is too long and exceeds max KV cache size. \
275+
Ignored."
250276
);
251277
get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
252278
did_ignore = true;
253279
} else {
254-
self.waiting_counts.remove(&seq_id);
280+
warn!("Sequence {seq_id} still waiting for memory...");
281+
// Safely break the loop to wait for the next iteration without dropping the request!
282+
break;
255283
}
256-
} else {
257-
warn!(
258-
"Sequence {seq_id} with length of {num_tokens} tokens is too long and exceeds KV cache size. \
259-
To fix, increase the maximum sequence length for the KV cache, for example with \
260-
`--max-seq-len`/ `max_seq_len` in automatic device mapping parameters.",
261-
);
262-
get_mut_arcmutex!(seq).set_state(SequenceState::FinishedIgnored);
263-
did_ignore = true;
264284
}
265285
} else {
266286
break;
@@ -334,6 +354,7 @@ impl PagedAttentionScheduler {
334354
self.sort_running_by_priority_fcfs();
335355

336356
let mut running: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
357+
let mut deferred_running: VecDeque<Arc<Mutex<Sequence>>> = VecDeque::new();
337358
while !self.running.is_empty() {
338359
let seq = self.running.pop_front().unwrap();
339360
let mut finished_with_break = false;
@@ -367,16 +388,12 @@ impl PagedAttentionScheduler {
367388
{
368389
running.push_back(seq);
369390
} else {
370-
self.running.push_back(seq);
391+
deferred_running.push_back(seq);
371392
}
372393
}
373394
}
374395
self.running = running;
375-
376-
// Bucket running completions by sequence length
377-
let running_for_bucket = std::mem::take(&mut self.running);
378-
let bucketed = self.bucket_and_preempt_sequences(running_for_bucket);
379-
self.running = bucketed;
396+
self.running.extend(deferred_running);
380397

381398
self.running
382399
.iter()
@@ -494,10 +511,10 @@ impl PagedAttentionScheduler {
494511
}
495512

496513
fn sort_running_by_priority_fcfs(&mut self) {
514+
// Sort oldest-first (true FCFS) — oldest sequences get priority for decode slots
497515
self.running
498516
.make_contiguous()
499517
.sort_by_key(|seq| get_mut_arcmutex!(seq).timestamp());
500-
self.running.make_contiguous().reverse();
501518
}
502519
}
503520

0 commit comments

Comments
 (0)