Compute buffer and KV-cache aware layer distribution for multi-GPU inference #14484
borebot wants to merge 2 commits into ggml-org:master
Conversation
…layer distribution for heterogeneous multi-GPU inference. Solves the problem of attempting to run setups with mixed VRAM sizes (e.g. 24GB cards alongside 6GB cards); previously, layers were assigned without accounting for the compute buffer, causing failures when one or more smaller GPUs could not hold it.

- Add `requested_n_ctx` parameter to `llama_model_params`
- Implement a 3-pass allocation algorithm that accounts for compute buffers
- Exclude devices with insufficient memory (GPUs too small to hold 1 layer + KV cache + compute buffer)
- Redistribute layers to make equitable use of the included GPUs (may not be truly optimal)
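The three passes described above can be sketched roughly as follows. This is an illustrative simplification in Python, not the PR's actual C++ code; the function name, the flat per-GPU compute-buffer figure, and the rebalancing heuristic are all assumptions made for the sketch.

```python
def split_layers(free_vram, n_layers, layer_bytes, kv_bytes, cbuf_bytes):
    """Sketch of a compute-buffer/KV-aware layer split (hypothetical, simplified).

    free_vram:  free bytes per GPU
    kv_bytes:   per-layer KV cache at the requested context length
    cbuf_bytes: compute buffer reserved on each participating GPU
    Returns {gpu_index: layers_assigned} over the included GPUs.
    """
    per_layer = layer_bytes + kv_bytes
    # Pass 1: exclude GPUs that cannot hold 1 layer + its KV slice + the
    # compute buffer; compute each surviving GPU's layer capacity.
    cap = {i: (free - cbuf_bytes) // per_layer
           for i, free in enumerate(free_vram)
           if free >= per_layer + cbuf_bytes}
    if sum(cap.values()) < n_layers:
        raise MemoryError("model + context do not fit on the included GPUs")
    # Pass 2: greedy fill in descending-capacity order.
    assign = dict.fromkeys(cap, 0)
    remaining = n_layers
    for i in sorted(cap, key=cap.get, reverse=True):
        take = min(cap[i], remaining)
        assign[i], remaining = take, remaining - take
    # Pass 3: rebalance for a more equitable split — shift layers from the
    # fullest GPU to the emptiest GPU that still has headroom.
    while True:
        hi = max(assign, key=assign.get)
        lo = min((i for i in assign if assign[i] < cap[i]),
                 key=assign.get, default=None)
        if lo is None or assign[hi] - assign[lo] <= 1:
            break
        assign[hi] -= 1
        assign[lo] += 1
    return assign
```

For example, with a 24GB-class GPU, a 6GB-class GPU, and one GPU too small to hold even a single layer plus overheads, the third device is excluded and the other two share the layers.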
This is a very welcome addition. As of now, tensor splitting with multiple GPUs is right on the border of unusable: it takes a huge amount of patience to manually reload, adjust the tensor split and context size, try again, crash with an out-of-memory error, adjust, try again, crash... extremely frustrating. Also needed is the ability to specify a pre-allocation on a per-device basis, so that available device memory prior to waterfilling can be adjusted to account for loading a speculator or an mmproj. Related discussion in #13314.
I think it's preferable to do dummy allocations on each device and then iteratively adjust the runtime parameters, see #14067. Previously the feedback from other maintainers was that duplicating the logic for memory use is unacceptable in terms of maintenance.
I don't agree. It is possible to predict how much total context memory a model will use given the number of desired KV tokens (#10068). I use this approach in my downstream shell-based model loading and inference platform, and it works well. This prediction can be passed in as the requested_n_ctx parameter, and the layer split will then take KV memory into account for the layers being loaded, which is exactly what is wanted to avoid OOMs at load time.
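The KV cache size prediction mentioned above is straightforward for standard attention layouts: one K and one V tensor per layer per token. A minimal sketch, assuming an unquantized f16 KV cache and uniform KV heads across layers (GQA models vary in these parameters, so the example figures are illustrative):

```python
def kv_cache_bytes(n_layers, n_ctx, n_kv_heads, head_dim, bytes_per_elem=2):
    """Predicted KV cache size: K and V tensors for every layer and token.

    bytes_per_elem=2 corresponds to an f16 KV cache; quantized KV types
    shrink this proportionally.
    """
    return n_layers * n_ctx * 2 * n_kv_heads * head_dim * bytes_per_elem

# Example: a 70B-class GQA model (80 layers, 8 KV heads, head_dim 128)
# at 8192 context needs 2.5 GiB of KV cache in f16.
print(kv_cache_bytes(80, 8192, 8, 128) / 2**30)  # → 2.5
```

Dividing this total by the layer count gives the per-layer KV contribution that a memory-aware layer split needs to budget for on each GPU.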
For this addition I primarily had users new to llama.cpp in mind (like myself; I'm only about four days in), who would benefit from a plug-and-play default that works out of the box. I'm new to the codebase, so I'm sure I've missed some things; I'm happy to take feedback from maintainers on how to better align with project goals and structure.
I have a 2x3090 + 2x3060 setup, and my current solution to maximize VRAM usage is to use -ts with some strange values.
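One way those hand-tuned -ts values can be derived is by reserving a rough fixed overhead (compute buffer, fragmentation slack) per GPU and splitting the weights over what remains. A hedged sketch of that heuristic, with the overhead figure purely assumed for illustration:

```python
def tensor_split(free_vram_gib, overhead_gib):
    """Rough -ts proportions: split model weights over the VRAM left on each
    GPU after reserving a fixed per-GPU overhead (illustrative heuristic)."""
    usable = [max(v - overhead_gib, 0.0) for v in free_vram_gib]
    total = sum(usable)
    return [round(u / total, 2) for u in usable]

# 2x3090 (24 GiB) + 2x3060 (12 GiB), assuming ~1.5 GiB overhead per GPU:
print(tensor_split([24, 24, 12, 12], 1.5))  # → [0.34, 0.34, 0.16, 0.16]
```

The point of the PR is to make exactly this kind of manual arithmetic unnecessary, since it also accounts for the per-layer KV cache automatically.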
TESTING DETAILS:
Primary server node:
./llama.cpp/build/bin/llama-server -m ./llama.cpp/models/YOUR_LLM_MODEL.gguf --rpc WORKER_IP_1:PORT1,WORKER_IP_2:PORT2,WORKER_IP_3:PORT3 --host 0.0.0.0 --port LLAMACPP_SERVER_PORT -ngl NUMBER_OF_LAYERS_TO_DISTRIBUTE_TO_GPUS -c DESIRED_CONTEXT_LENGTH
Worker node (run separately per GPU ID, even if on the same machine):
cd /YOUR_PATH_TO/llama.cpp && CUDA_VISIBLE_DEVICES=GPU_ID_NUMBER ./build/bin/rpc-server --host 0.0.0.0 --port WORKER_PORT
LLMs tested:
Devstral-Small-2505-Q5_K_M.gguf
DeepSeek-R1-Distill-Llama-70B-Q4_K_M.gguf
gemma-3-27b-it-q4_0.gguf
Architecture tested:
Primary server machine: Ubuntu 24.04, 2 GPUs (NVIDIA RTX 3090 (24GB) + NVIDIA RTX 2060 (6GB))
Worker host machine: Proxmox + 2 VMs:
Various -c context lengths were tested for each model; small GPUs are properly excluded when the KV cache and compute buffer don't fit. If the model + context length is too large to fit in the setup, the launch fails (as expected). Not yet tested with offload to CPU + system RAM.