dlinfer/framework/lmdeploy_ext/device/ascend.py (110 additions, 0 deletions)
@@ -980,3 +980,113 @@ def get_assignment_batch(
cache_engine.CacheEngine = AscendCacheEngine
executor_base.CacheEngine = AscendCacheEngine
model_agent.CacheEngine = AscendCacheEngine


##### patch scheduler #####
##### workaround for uncompleted prefill_attention_with_kvcache #####
from typing import Dict, List

from lmdeploy.pytorch.paging.scheduler import Scheduler
from lmdeploy.pytorch.messages import MessageStatus, SchedulerSequence
from lmdeploy.messages import EventType

MapType = Dict[int, int]
SeqList = List[SchedulerSequence]


def _schedule_prefill_ascend(self, prealloc_size: int = 0):
    """Schedule prefill for Ascend devices with kv-cache constraints.

    This function patches :meth:`Scheduler._schedule_prefill` to add extra
    restrictions when batching requests that use the Ascend
    ``prefill_attention_with_kvcache`` kernel. The original CUDA-based
    implementation can freely mix plain prefill (``num_new_tokens == 0``) and
    decode (``num_new_tokens > 0``) requests in the same batch and allows
    multiple sequences that rely on kv-cache optimizations to be scheduled
    together.

    On Ascend, ``prefill_attention_with_kvcache`` is currently not
    feature-complete: mixing multiple kv-cache-optimized prefill/decode
    requests in a single step may corrupt kv-cache state or exceed hardware
    limits. To work around this, the ``prefill_with_kvcache`` flag and the
    associated early-exit conditions in this method enforce that:

    * the original token-count and batch-size admission logic is preserved;
    * at most one sequence that requires kv-cache-optimized prefill or decode
      is admitted in a scheduling step; and
    * once such a sequence is admitted, additional prefill/decode requests
      are not mixed into the same step.

    This workaround should be removed, and the upstream
    :meth:`Scheduler._schedule_prefill` used as-is, once the Ascend
    implementation of ``prefill_attention_with_kvcache`` supports fully mixed
    prefill and decode batching without these constraints.
    """
    max_batches = (
        self.scheduler_config.max_batches - self.num_ready() - self.num_running()
    )
    eviction_helper = self.eviction_helper
    swap_out_map: MapType = dict()
    swap_in_map: MapType = dict()
    copy_map: MapType = dict()
    running: SeqList = []
    token_count = 0

    def _to_running(seq: SchedulerSequence):
        """Mark a sequence as running and account for its tokens."""
        seq.state.activate()
        running.append(seq)
        nonlocal token_count
        token_count += seq.num_token_ids

    def __evict_for_seq(seq: SchedulerSequence, waiting):
        """Evict blocks until ``seq`` can be appended."""
        from itertools import chain

        hanging = reversed(self.hanging)
        waiting = reversed(waiting)
        evictable = list(chain(hanging, waiting))
        return eviction_helper.evict_for_seq(seq, evictable, prealloc_size)

    def _reorder_waiting():
        """Sort waiting sequences by arrival time."""
        return sorted(self.waiting, key=lambda seq: seq.arrive_time)

    num_waiting = self.seq_manager.num_sequences(MessageStatus.WAITING)
    if len(running) >= max_batches or num_waiting == 0:
        return running, swap_in_map, swap_out_map, copy_map

    waiting = _reorder_waiting()
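    # True while the most recently admitted sequence takes the
    # kv-cache-optimized path; start permissive so the first waiting
    # sequence is always considered.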
    prefill_with_kvcache = True
    while len(waiting) > 0 and len(running) < max_batches:
        seq = waiting.pop(0)

        # A plain prefill has already been admitted in this step: do not
        # mix a kv-cache-optimized request into the same batch.
        if not prefill_with_kvcache and seq.num_new_tokens > 0:
            break
        prefill_with_kvcache = seq.num_new_tokens != 0

        # Stop once this sequence would exceed the prefill token budget.
        if (
            len(running) > 0
            and token_count + seq.num_token_ids
            > self.cache_config.max_prefill_token_num
        ):
            break

        # Match the sequence against cached blocks (prefix cache).
        self.block_trie.match(seq)

        if not __evict_for_seq(seq, waiting):
            break

        # allocate session memory
        self.block_manager.allocate(seq, prealloc_size)
        self.block_trie.allocate(seq)
        if self.is_ssm:
            self.state_manager.allocate(seq)
        _to_running(seq)

        seq.record_event(EventType.SCHEDULED)
        # Admit at most one kv-cache-optimized sequence per step.
        if prefill_with_kvcache:
            break

    return running, swap_in_map, swap_out_map, copy_map


Scheduler._schedule_prefill = _schedule_prefill_ascend
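

# ---------------------------------------------------------------------------
# Illustrative sketch only (an assumption for exposition, not part of the
# patch): replay the kv-cache mixing rule enforced above on hypothetical
# sequences. Here ``num_new_tokens == 0`` denotes a plain prefill and
# ``num_new_tokens > 0`` a kv-cache-optimized request; ``_FakeSeq`` and the
# token counts below are invented for illustration.
from dataclasses import dataclass


@dataclass
class _FakeSeq:
    seq_id: int
    num_new_tokens: int


def _admit(waiting: List[_FakeSeq], max_batches: int = 8) -> List[int]:
    """Replay only the mixing rule from ``_schedule_prefill_ascend``."""
    running: List[_FakeSeq] = []
    prefill_with_kvcache = True
    while waiting and len(running) < max_batches:
        seq = waiting.pop(0)
        # A plain prefill was already admitted: stop before mixing in a
        # kv-cache-optimized request.
        if not prefill_with_kvcache and seq.num_new_tokens > 0:
            break
        prefill_with_kvcache = seq.num_new_tokens != 0
        running.append(seq)
        # A kv-cache-optimized request runs alone in its step.
        if prefill_with_kvcache:
            break
    return [s.seq_id for s in running]


# Plain prefills batch together; the kv-cache request waits for a later step.
assert _admit([_FakeSeq(0, 0), _FakeSeq(1, 0), _FakeSeq(2, 4)]) == [0, 1]
# A kv-cache-optimized request at the head of the queue is scheduled alone.
assert _admit([_FakeSeq(0, 4), _FakeSeq(1, 0)]) == [0]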