
Commit 70cdb07

refactor(kv-cache): embed KvCacheAllocator in MemoryManager as allocator + test model skills. (#1301)
1 parent 8bcd28b commit 70cdb07

31 files changed

Lines changed: 1778 additions & 121 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ dist
 .vscode
 tmp/
 requirements-musa.txt
+logs/

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ APIServer 参数详解
 - ``running_max_req_size`` 为 3
 - ``batch_max_tokens`` 为 2048 (2k)
 - ``chunked_prefill_size`` 为 1024 (1k)
-- ``mem_fraction`` 为 0.85
 
 .. option:: --host

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ Basic Configuration Parameters
 - ``running_max_req_size`` to 3
 - ``batch_max_tokens`` to 2048 (2k)
 - ``chunked_prefill_size`` to 1024 (1k)
-- ``mem_fraction`` to 0.85
 
 .. option:: --host

lightllm/common/basemodel/attention/nsa/flashmla_sparse.py

Lines changed: 4 additions & 4 deletions
@@ -79,17 +79,17 @@ def _nsa_prefill_att(
     from sgl_kernel.flash_mla import flash_mla_sparse_fwd
 
     nsa_dict = att_control.nsa_prefill_dict
-    topk_indices = nsa_dict["topk_indices"]
+    topk_mem_indices = nsa_dict["topk_mem_indices"]
     softmax_scale = nsa_dict["softmax_scale"]
     kv_lora_rank = nsa_dict["kv_lora_rank"]
 
-    if topk_indices.ndim == 2:
-        topk_indices = topk_indices.unsqueeze(1)
+    if topk_mem_indices.ndim == 2:
+        topk_mem_indices = topk_mem_indices.unsqueeze(1)
 
     mla_out, _, _ = flash_mla_sparse_fwd(
         q=q,
         kv=kv,
-        indices=topk_indices,
+        indices=topk_mem_indices,
         sm_scale=softmax_scale,
         d_v=kv_lora_rank,
     )

lightllm/common/kv_cache_mem_manager/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from .allocator import KvCacheAllocator
 from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager
 from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
@@ -9,6 +10,7 @@
 from .qwen3next_mem_manager import Qwen3NextMemManager
 
 __all__ = [
+    "KvCacheAllocator",
     "MemoryManager",
     "ReadOnlyStaticsMemoryManager",
     "PPLINT4KVMemoryManager",

lightllm/common/kv_cache_mem_manager/mem_manager.py

Lines changed: 12 additions & 78 deletions
@@ -3,10 +3,11 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from typing import List, Union, Tuple, Any
+from typing import List, Tuple, Any, Union
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from .allocator import KvCacheAllocator
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size
@@ -38,27 +39,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         # profile the max total token num if the size is None
         self.profile_size(mem_fraction)
 
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self._mem_state_return = torch.arange(
-            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self._return_start = 0
-        self.mark_start = 0
-        self.mark_end = self.size
-
-        self.can_use_mem_size = self.size
-
-        # Shared through shared memory so the router module can read it for precise scheduling estimates; the nccl port marks a single instance on one machine to avoid conflicts.
-        from lightllm.utils.envs_utils import get_unique_server_name
+        self.allocator = KvCacheAllocator(self.size)
 
-        rank_in_node = get_current_rank_in_node()
-        self.shared_can_use_token_num = SharedInt(
-            f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
-        )
-
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self._init_buffers(
             self.size,
             dtype,
@@ -83,9 +65,10 @@ def profile_size(self, mem_fraction):
         if self.size is not None:
             return
 
+        torch.cuda.empty_cache()
         world_size = dist.get_world_size()
-        total_memory = get_total_gpu_memory()
-        available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
+
+        available_memory = get_available_gpu_memory(world_size) * mem_fraction
         cell_size = self.get_cell_size()
         self.size = int(available_memory * 1024 ** 3 / cell_size)
         if world_size > 1:
@@ -338,57 +321,13 @@ def _free_buffers(self):
         self.kv_buffer = None
 
     def alloc(self, need_size) -> torch.Tensor:
-        if need_size > self.mark_end - self.mark_start:
-            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
-            assert False, "error alloc state"
-
-        start = self.mark_start
-        end = self.mark_start + need_size
-        self.mark_start += need_size
-
-        self.can_use_mem_size -= need_size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        # Return through a staging buffer to avoid memory races under asynchronous execution.
-        if self._return_start + need_size > self._mem_state_return.shape[0]:
-            self._return_start = 0
-        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
-        ans.copy_(self.mem_state[start:end])
-        self._return_start += need_size
-        return ans
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]):
-        """_summary_
-
-        Args:
-            free_index (torch.Tensor): _description_
-        """
+        return self.allocator.alloc(need_size)
 
-        end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
-
-        if isinstance(free_index, list):
-            self.mem_state.numpy()[start:end] = free_index
-        else:
-            # Copies from GPU to CPU are stream-blocking operations.
-            self.mem_state[start:end] = free_index
-
-        self.mark_start -= len(free_index)
-
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.can_use_mem_size == len(self.mem_state):
-            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
-        return
+    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
+        self.allocator.free(free_index)
 
     def free_all(self):
-        self.can_use_mem_size = len(self.mem_state)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
-        self.mark_start = 0
-        self.mark_end = len(self.mem_state)
+        self.allocator.free_all()
 
     def resize_mem(self, new_size):
         """
@@ -401,13 +340,8 @@ def resize_mem(self, new_size):
         layer_num = self.layer_num
 
         self.size = new_size
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_start = 0
-        self.mark_end = self.size
-        self.can_use_mem_size = self.size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.allocator.resize(new_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size
         self._free_buffers()
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
         return
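
MemoryManager now delegates its index bookkeeping to KvCacheAllocator, imported from .allocator (that module is not part of this excerpt). A minimal sketch of the interface the calls above assume, written as if the allocator simply carries over the stack-style logic removed from MemoryManager; the names and details of the real allocator.py may differ:

import torch

class KvCacheAllocator:
    # Hypothetical stand-in for lightllm/common/kv_cache_mem_manager/allocator.py (not shown in this diff).
    def __init__(self, size: int):
        self.resize(size)

    def resize(self, new_size: int) -> None:
        # Pool of free token indices, kept on pinned CPU memory like the old MemoryManager.mem_state.
        self.size = new_size
        self.mem_state = torch.arange(0, new_size, dtype=torch.int32, device="cpu", pin_memory=True)
        self.mark_start = 0
        self.can_use_mem_size = new_size

    def alloc(self, need_size: int) -> torch.Tensor:
        assert need_size <= self.can_use_mem_size, "error alloc state"
        ans = self.mem_state[self.mark_start : self.mark_start + need_size].clone()
        self.mark_start += need_size
        self.can_use_mem_size -= need_size
        return ans

    def free(self, free_index) -> None:
        n = len(free_index)
        self.mark_start -= n
        if isinstance(free_index, list):
            self.mem_state.numpy()[self.mark_start : self.mark_start + n] = free_index
        else:
            # Copying a GPU index tensor back into the pinned CPU pool blocks within the stream.
            self.mem_state[self.mark_start : self.mark_start + n] = free_index
        self.can_use_mem_size += n

    def free_all(self) -> None:
        self.resize(self.size)

The old implementation also mirrored can_use_mem_size into a SharedInt so the router could read it for scheduling; whether that now lives inside the allocator is not visible in this diff.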

lightllm/server/api_cli.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--mem_fraction",
         type=float,
-        default=0.9,
+        default=0.8,
         help="""Memory usage ratio, default is 0.9, you can specify a smaller value if OOM occurs at runtime.
        If max_total_token_num is not specified, it will be calculated automatically based on this value.""",
     )
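
Two related knobs change in this commit: the --mem_fraction default drops from 0.9 to 0.8, and profile_size (above) now sizes the token pool as free GPU memory times mem_fraction rather than free memory minus a reserved share of total memory. A back-of-the-envelope example of the new arithmetic, with made-up numbers:

# Illustrative only; the values below are assumptions, not measurements from the repo.
free_gb = 60.0                       # get_available_gpu_memory(world_size), in GB
mem_fraction = 0.8                   # new default of --mem_fraction
cell_size = 2 * 32 * 8 * 128 * 2     # k+v * layers * kv-heads * head_dim * fp16 bytes, hypothetical model
available = free_gb * mem_fraction   # 48.0 GB usable for the kv cache
max_total_token_num = int(available * 1024 ** 3 / cell_size)
print(max_total_token_num)           # 393216 tokens for these numbers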

lightllm/server/api_openai.py

Lines changed: 9 additions & 4 deletions
@@ -153,8 +153,13 @@ def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int:
     return idx
 
 
-def _get_reasoning_from_request(request: ChatCompletionRequest) -> bool:
-    """Judge whether the request needs reasoning"""
+def _is_force_thinking_mode(request: ChatCompletionRequest) -> bool:
+    """Whether this request uses forced thinking / reasoning (parser + template)."""
+    from .build_prompt import tokenizer_supports_force_thinking
+
+    if not tokenizer_supports_force_thinking():
+        return False
+
     reasoning_parser = get_env_start_args().reasoning_parser
     if not reasoning_parser:
         return False
@@ -175,7 +180,7 @@ def _process_reasoning_stream(
 ) -> tuple[Optional[str], str]:
     """Process reasoning content in streaming response"""
     if index not in reasoning_parser_dict:
-        request_enable_reasoning = _get_reasoning_from_request(request)
+        request_enable_reasoning = _is_force_thinking_mode(request)
         reasoning_parser_dict[index] = ReasoningParser(
             get_env_start_args().reasoning_parser,
             request.stream_reasoning,
@@ -376,7 +381,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
     reasoning_text = None
     reasoning_parser = get_env_start_args().reasoning_parser
     if reasoning_parser:
-        request_enable_reasoning = _get_reasoning_from_request(request)
+        request_enable_reasoning = _is_force_thinking_mode(request)
         try:
             parser = ReasoningParser(
                 model_type=reasoning_parser,
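
Net effect of the rename plus the new guard: forced thinking is only reported for a request when the chat template actually knows about thinking and a reasoning parser is configured. A condensed sketch of that decision as standalone pseudocode; the real _is_force_thinking_mode also contains per-request checks that fall outside this hunk:

def is_force_thinking_mode_sketch(template_mentions_thinking: bool, reasoning_parser: str) -> bool:
    # template_mentions_thinking stands in for tokenizer_supports_force_thinking();
    # reasoning_parser stands in for get_env_start_args().reasoning_parser.
    if not template_mentions_thinking:
        return False
    if not reasoning_parser:
        return False
    # ... remaining per-request logic from the original function (not shown in this diff) ...
    return True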

lightllm/server/api_start.py

Lines changed: 4 additions & 5 deletions
@@ -127,15 +127,14 @@ def normal_or_p_d_start(args):
 
     # handle the performance_mode argument
     if args.performance_mode == "personal":
-        args.running_max_req_size = 3
+        args.running_max_req_size = 6
         args.batch_max_tokens = 2048
         args.chunked_prefill_size = 1024
-        if args.mem_fraction > 0.82:
-            args.mem_fraction = 0.82
-        args.graph_max_batch_size = 32
+        args.embed_cache_storage_size = 0.8
+        args.graph_max_batch_size = 6
         logger.info(
             f"performance_mode is personal, set running_max_req_size to 3,"
-            f"batch_max_tokens to 2048, chunked_prefill_size to 1024, mem_fraction to 0.82,"
+            f"batch_max_tokens to 2048, chunked_prefill_size to 1024,"
             f"graph_max_batch_size to 32"
         )

lightllm/server/build_prompt.py

Lines changed: 40 additions & 0 deletions
@@ -2,6 +2,7 @@
 import json
 from lightllm.server.tokenizer import get_tokenizer
 from lightllm.utils.log_utils import init_logger
+from functools import lru_cache
 
 logger = init_logger(__name__)
 
@@ -45,6 +46,32 @@ def init_tokenizer(args):
     return
 
 
+@lru_cache(maxsize=1)
+def tokenizer_supports_force_thinking() -> bool:
+    """Whether this tokenizer supports thinking / reasoning."""
+
+    assert tokenizer is not None
+
+    try:
+        ans = "thinking" in tokenizer.chat_template or "enable_thinking" in tokenizer.chat_template
+        logger.debug(f"chat_template: {tokenizer.chat_template}")
+        logger.info(f"tokenizer_supports_force_thinking : {ans}")
+        return ans
+    except:
+        pass
+
+    try:
+        ans = "thinking" in tokenizer.tokenizer.chat_template or "enable_thinking" in tokenizer.tokenizer.chat_template
+        logger.debug(f"tokenizer.tokenizer.chat_template: {tokenizer.tokenizer.chat_template}")
+        logger.info(f"tokenizer_supports_force_thinking : {ans}")
+        return ans
+    except:
+        pass
+
+    logger.info("tokenizer_supports_force_thinking : False")
+    return False
+
+
 def _normalize_tool_call_arguments(messages: list) -> None:
     # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
     # Qwen35's chat template expects arguments to be a dict (uses |items filter)
@@ -94,6 +121,19 @@ async def build_prompt(request, tools) -> str:
     if request.chat_template_kwargs:
         kwargs.update(request.chat_template_kwargs)
 
+    # Fix: some parser types enable thinking by default, but the tokenizer is sometimes unaware that
+    # thinking is on, which misaligns the behavior of the reasoning parser and the tokenizer.
+    from .api_openai import _is_force_thinking_mode
+
+    thinking = _is_force_thinking_mode(request)
+
+    kwargs["thinking"] = thinking
+    kwargs["enable_thinking"] = thinking
+
+    # TODO: there should be three thinking modes: forced thinking, forced no-thinking, and an adaptive
+    # mode where the model decides by itself. The current code only implements the two forced modes; a
+    # more complete handling should detect from the tokenizer which modes the model supports before setting these.
+
     try:
         input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools)
     except BaseException as e:
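
For reference, the same thinking switches can be exercised directly against a Hugging Face tokenizer; this standalone sketch mirrors what build_prompt now does with kwargs. The model name is illustrative, and whether "thinking" / "enable_thinking" actually change the rendered prompt depends on the model's chat template, which is exactly what tokenizer_supports_force_thinking() probes:

from transformers import AutoTokenizer

# Assumed model; any chat model whose template reads an enable_thinking/thinking flag behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
messages = [{"role": "user", "content": "Why is the sky blue?"}]

thinking = True  # what _is_force_thinking_mode(request) would decide for this request
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    thinking=thinking,          # extra kwargs are forwarded to the Jinja chat template
    enable_thinking=thinking,
)
print(prompt)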
