Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 118 additions & 19 deletions responses_api_models/local_vllm_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,18 +357,118 @@ def new_create_dp_placement_groups(vllm_config):
"The actual data-parallel-size-local will be auto determined."
)

for _ in range(dp_size - 1):
bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}]

pg_name = f"{self.server_name}_dp_rank_{len(placement_groups)}"
pg = ray.util.placement_group(
name=pg_name,
strategy=placement_strategy,
bundles=bundles,
)

placement_groups.append(pg)
local_dp_ranks.append(0)
# Mirror: vllm/v1/engine/utils.py CoreEngineActorManager.create_dp_placement_groups
# Upstream uses one loop for strict/fill/span; we split span out. NeMo Gym diffs
# from upstream only inside START/END blocks (same style as earlier patches).
if pack_strategy == "span":
"""
START NeMo Gym: span with pre-created head PG — simplified PG loop vs upstream
(upstream interleaves span/collected_bundles with empty initial PG list).
"""
for _ in range(dp_size - 1):
bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}]
pg_name = f"{self.server_name}_dp_rank_{len(placement_groups)}"
pg = ray.util.placement_group(
name=pg_name,
strategy=placement_strategy,
bundles=bundles,
)
ray.get(pg.ready())
placement_groups.append(pg)
local_dp_ranks.append(0)
"""
END NeMo Gym: span with pre-created head PG — simplified PG loop vs upstream
"""
else:
# strict/fill only (span handled above). Body parallels
# vllm/v1/engine/utils.py create_dp_placement_groups for node walk + inner loop.
for node_resources in nodes:
"""
START NeMo Gym: stop once head PG + new PGs reach dp_size
"""
if len(placement_groups) == dp_size:
break
"""
END NeMo Gym: stop once head PG + new PGs reach dp_size
"""
node_ip_keys = [
key for key in node_resources if key != "node:__internal_head__" and key.startswith("node:")
]
assert len(node_ip_keys) == 1, (
f"Zero or multiple node IP keys found in node resources: {node_ip_keys}"
)
node_ip_key = node_ip_keys[0]
node_ip = node_ip_key.split(":")[1]

n_device_on_node = int(node_resources.get(device_str, 0))
dp_size_available = n_device_on_node // world_size

if node_ip == dp_master_ip:
if dp_size_available < dp_size_local:
raise ValueError(
f"Not enough resources to allocate {dp_size_local} DP ranks "
f"on DP master node {dp_master_ip}, possible to fit "
f"{dp_size_available} DP ranks."
)
dp_size_to_allocate = dp_size_local
elif pack_strategy == "strict":
if dp_size_available < dp_size_local:
logger.info(
"Skipping node %s as %s DP ranks could not fit, possible to fit %s DP ranks",
node_ip,
dp_size_local,
dp_size_available,
)
continue
dp_size_to_allocate = dp_size_local
else:
# for "fill" (and upstream "span"; span not in this branch)
# we always take everything that's available
dp_size_to_allocate = dp_size_available

"""
START NeMo Gym: pre-created head PG is rank 0 on master; first inner slot is i=1
Upstream (same file): for i in range(dp_size_to_allocate):
"""
if node_ip == dp_master_ip and len(placement_groups) == 1:
dp_rank_index_range = range(1, dp_size_to_allocate)
else:
dp_rank_index_range = range(dp_size_to_allocate)
"""
END NeMo Gym: pre-created head PG is rank 0 on master; first inner slot is i=1
"""

for i in dp_rank_index_range:
device_bundle = [{device_str: 1.0, "node:" + node_ip: 0.001}]
bundles = device_bundle * world_size + [{"CPU": 1.0}]

"""
START NeMo Gym: per-server PG names; wait for PG before scheduling next
Upstream (same file): name=f"dp_rank_{len(placement_groups)}"
then append without ray.get(pg.ready()).
"""
pg = ray.util.placement_group(
name=f"{self.server_name}_dp_rank_{len(placement_groups)}",
strategy=placement_strategy,
bundles=bundles,
)
ray.get(pg.ready())
"""
END NeMo Gym: per-server PG names; wait for PG before scheduling next
"""
placement_groups.append(pg)
local_dp_ranks.append(i)
if len(placement_groups) == dp_size:
break

"""
START NeMo Gym: outer for-node loop exit when dp_size reached
"""
if len(placement_groups) == dp_size:
break
"""
END NeMo Gym: outer for-node loop exit when dp_size reached
"""

if len(placement_groups) < dp_size:
raise ValueError(
Expand Down Expand Up @@ -449,14 +549,13 @@ def _configure_vllm_serve(self) -> Tuple[Namespace, Dict[str, str]]:
# "Ray backend only works with data parallel size > 1!"
# )

# With our vLLM patches, this is no longer necessary for people to set.
server_args["data_parallel_size_local"] = 1

# TODO multi-node model instances still need to be properly supported
# We get a vLLM error: Exception: Error setting CUDA_VISIBLE_DEVICES: local range: [0, 16) base value: "0,1,2,3,4,5,6,7"
if env_vars.get("VLLM_RAY_DP_PACK_STRATEGY") == "span":
# Unset this flag since it's set by default using span
# Match upstream vLLM: data_parallel_size_local controls ranks per node for
# strict/fill (see create_dp_placement_groups). Default to 1 when unset.
pack = env_vars.get("VLLM_RAY_DP_PACK_STRATEGY")
if pack == "span":
server_args.pop("data_parallel_size_local", None)
elif server_args.get("data_parallel_size_local") is None:
server_args["data_parallel_size_local"] = 1

cli_env_setup()
parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.")
Expand Down
Loading