@@ -3,10 +3,11 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from typing import List, Union, Tuple, Any
+from typing import List, Tuple, Any, Union
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from .allocator import KvCacheAllocator
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size
@@ -38,27 +39,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         # profile the max total token num if the size is None
         self.profile_size(mem_fraction)
 
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self._mem_state_return = torch.arange(
-            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self._return_start = 0
-        self.mark_start = 0
-        self.mark_end = self.size
-
-        self.can_use_mem_size = self.size
-
-        # Shared via shared memory so the router module can read it for precise scheduling estimates; the nccl port marks a single instance on one machine, to avoid conflicts.
-        from lightllm.utils.envs_utils import get_unique_server_name
+        self.allocator = KvCacheAllocator(self.size)
 
-        rank_in_node = get_current_rank_in_node()
-        self.shared_can_use_token_num = SharedInt(
-            f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
-        )
-
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self._init_buffers(
             self.size,
             dtype,
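
Only the call sites of the new allocator are visible in this diff, so the stub below is an inferred sketch of the interface `.allocator.KvCacheAllocator` must expose to satisfy them; everything beyond the four method names and their arguments is an assumption.

```python
# Inferred interface sketch for KvCacheAllocator, based only on the calls
# visible in this diff; bodies and any extra state are hypothetical.
from typing import List, Union

import torch


class KvCacheAllocator:
    def __init__(self, size: int):
        # `size` is the total number of KV-cache token slots to manage.
        self.size = size

    def alloc(self, need_size: int) -> torch.Tensor:
        """Hand out `need_size` free token indexes as an int32 tensor."""
        raise NotImplementedError

    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
        """Return previously allocated token indexes to the pool."""
        raise NotImplementedError

    def free_all(self) -> None:
        """Mark every token slot as free."""
        raise NotImplementedError

    def resize(self, new_size: int) -> None:
        """Rebuild bookkeeping for a pool of `new_size` slots."""
        raise NotImplementedError
```
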
@@ -83,9 +65,10 @@ def profile_size(self, mem_fraction):
         if self.size is not None:
             return
 
+        torch.cuda.empty_cache()
         world_size = dist.get_world_size()
-        total_memory = get_total_gpu_memory()
-        available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
+
+        available_memory = get_available_gpu_memory(world_size) * mem_fraction
         cell_size = self.get_cell_size()
         self.size = int(available_memory * 1024 ** 3 / cell_size)
         if world_size > 1:
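
The sizing rule becomes "free GPU memory times `mem_fraction`", instead of free memory minus a `(1 - mem_fraction)` reserve of total memory. A quick worked example of the arithmetic, with assumed numbers (30 GB free, fp16 KV cache, 32 layers, 8 KV heads, head_dim 128):

```python
# All numbers here are illustrative assumptions, not values from the PR.
mem_fraction = 0.9
available_memory = 30.0 * mem_fraction        # GB budgeted for the KV cache

# get_cell_size(): bytes one token occupies across K and V for all layers,
# assuming 32 layers, 8 KV heads, head_dim 128, 2-byte (fp16) elements.
cell_size = 2 * 32 * 8 * 128 * 2              # = 131072 bytes per token

size = int(available_memory * 1024 ** 3 / cell_size)
print(size)                                   # 221184 token slots
```
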
@@ -338,57 +321,13 @@ def _free_buffers(self):
         self.kv_buffer = None
 
     def alloc(self, need_size) -> torch.Tensor:
-        if need_size > self.mark_end - self.mark_start:
-            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
-            assert False, "error alloc state"
-
-        start = self.mark_start
-        end = self.mark_start + need_size
-        self.mark_start += need_size
-
-        self.can_use_mem_size -= need_size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        # Return through a staging buffer to avoid memory races under async use.
-        if self._return_start + need_size > self._mem_state_return.shape[0]:
-            self._return_start = 0
-        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
-        ans.copy_(self.mem_state[start:end])
-        self._return_start += need_size
-        return ans
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]):
-        """_summary_
-
-        Args:
-            free_index (torch.Tensor): _description_
-        """
+        return self.allocator.alloc(need_size)
 
-        end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
-
-        if isinstance(free_index, list):
-            self.mem_state.numpy()[start:end] = free_index
-        else:
-            # A GPU-to-CPU copy is a stream-blocking operation.
-            self.mem_state[start:end] = free_index
-
-        self.mark_start -= len(free_index)
-
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.can_use_mem_size == len(self.mem_state):
-            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
-        return
+    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
+        self.allocator.free(free_index)
 
     def free_all(self):
-        self.can_use_mem_size = len(self.mem_state)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
-        self.mark_start = 0
-        self.mark_end = len(self.mem_state)
+        self.allocator.free_all()
 
     def resize_mem(self, new_size):
         """
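
What this hunk deletes is a bump-pointer free list: `mem_state` holds the free token indexes, `[mark_start, mark_end)` marks the free region, and allocations are copied out through a rotating `_mem_state_return` staging buffer so an asynchronous caller's view is not overwritten before it is consumed. The sketch below condenses that removed logic as it presumably now lives inside `KvCacheAllocator`; the class placement is an assumption, the mechanics are taken from the removed lines (the `SharedInt` scheduling counter is omitted for brevity).

```python
# Condensed from the removed alloc/free bodies above; assuming this logic
# moved into KvCacheAllocator (the SharedInt counter is left out here).
import torch


class KvCacheAllocator:
    def __init__(self, size: int):
        self.size = size
        # Free token indexes; [mark_start, mark_end) is the free region.
        self.mem_state = torch.arange(0, size, dtype=torch.int32, device="cpu", pin_memory=True)
        # Rotating staging buffer: results are copied here so async callers
        # don't race on views of mem_state itself.
        self._mem_state_return = torch.arange(0, size * 3, dtype=torch.int32, device="cpu", pin_memory=True)
        self._return_start = 0
        self.mark_start = 0
        self.mark_end = size
        self.can_use_mem_size = size

    def alloc(self, need_size: int) -> torch.Tensor:
        assert need_size <= self.mark_end - self.mark_start, "error alloc state"
        start = self.mark_start
        self.mark_start += need_size
        self.can_use_mem_size -= need_size
        # Wrap the staging cursor, then copy the handed-out slice into it.
        if self._return_start + need_size > self._mem_state_return.shape[0]:
            self._return_start = 0
        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
        ans.copy_(self.mem_state[start : start + need_size])
        self._return_start += need_size
        return ans

    def free(self, free_index) -> None:
        start = self.mark_start - len(free_index)
        assert start >= 0, "error free state"
        if isinstance(free_index, list):
            self.mem_state.numpy()[start : self.mark_start] = free_index
        else:
            # A GPU-to-CPU copy is stream-blocking, so this is safe in-stream.
            self.mem_state[start : self.mark_start] = free_index
        self.mark_start = start
        self.can_use_mem_size += len(free_index)

    def free_all(self) -> None:
        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
        self.mark_start = 0
        self.mark_end = len(self.mem_state)
        self.can_use_mem_size = self.size
```
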
@@ -401,13 +340,8 @@ def resize_mem(self, new_size):
         layer_num = self.layer_num
 
         self.size = new_size
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_start = 0
-        self.mark_end = self.size
-        self.can_use_mem_size = self.size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.allocator.resize(new_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size
         self._free_buffers()
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
         return
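
`resize_mem` now defers index bookkeeping to `allocator.resize(new_size)` and refreshes `HOLD_TOKEN_MEMINDEX` (a sentinel index equal to the new pool size) before rebuilding the KV buffers. Mirroring the removed lines in this hunk, a plausible `resize` method to slot into the sketch above is just a reset of the free list (an assumption; the real body lives in `.allocator`):

```python
# A plausible KvCacheAllocator.resize, mirroring the state reset that the
# removed lines performed inline (assumption; actual body is in .allocator).
def resize(self, new_size: int) -> None:
    self.size = new_size
    self.mem_state = torch.arange(0, new_size, dtype=torch.int32, device="cpu", pin_memory=True)
    self.mark_start = 0
    self.mark_end = new_size
    self.can_use_mem_size = new_size
```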