forked from SRSWTI/bodega-inference-engine
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_http_concurrency.py
More file actions
516 lines (443 loc) · 17.9 KB
/
benchmark_http_concurrency.py
File metadata and controls
516 lines (443 loc) · 17.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
#!/usr/bin/env python3
"""
HTTP Concurrency Benchmark for bodega_mlx_engine
================================================
Benchmarks the running server (localhost:44468) with different concurrency levels
to measure continuous batching benefits. Tests max_concurrency 8, 16, 32 with
the same total number of queries for fair comparison.
Usage:
# Benchmark 90m model at concurrency 8, 16, 32 (default)
python scripts/benchmark_http_concurrency.py --base-url http://localhost:44468
# Custom model and query count
python scripts/benchmark_http_concurrency.py \
--model bodega-orion-0.6b \
--num-queries 32 \
--max-tokens 128
# Compare with sequential (requires reloading model without CB first)
python scripts/benchmark_http_concurrency.py --compare-sequential
Prerequisites:
- Server running with bodega-orion-0.6b loaded (with continuous_batching: true)
- For clean benchmark: restart with config_benchmark_90m.yaml
"""
from __future__ import annotations
import argparse
import asyncio
import json
import statistics
import sys
import time
from dataclasses import dataclass, field
from typing import Any
import os
import httpx
# Test prompts (varied lengths for realistic load).
# Short, medium, and long prompts are mixed so one benchmark run exercises
# both quick replies and token-heavy generations; benchmark_concurrency
# cycles through this list until the requested query count is reached.
PROMPTS = [
    "Hello, how are you?",
    "What is 2+2?",
    "Say hello in Spanish.",
    "What is the capital of France and why is it historically significant?",
    "Write a Python function to calculate fibonacci numbers with memoization.",
    "Explain the difference between a list and a tuple in Python.",
    (
        "Explain quantum computing in detail, covering: qubits, superposition, "
        "entanglement, potential applications in cryptography and drug discovery."
    ),
    (
        "Write a comprehensive guide to building a production REST API in Python. "
        "Include: project structure, routing, authentication with JWT."
    ),
]
@dataclass
class RequestResult:
    """Per-request measurement produced by the benchmark request workers."""

    prompt_tokens: int = 0      # prompt size, taken from the server's `usage` field
    completion_tokens: int = 0  # generated tokens (from `usage`, or approximated for streams)
    ttft_ms: float = 0.0        # time to first token in ms (total time for non-streaming)
    total_time_s: float = 0.0   # wall-clock duration of the whole request, seconds
    tps: float = 0.0            # per-request decode throughput, tokens/second
    error: str | None = None    # error description; None means the request succeeded
@dataclass
class ConcurrencyResult:
    """Aggregated results for one benchmark run at a fixed concurrency level."""

    concurrency: int          # max in-flight requests during the run
    num_requests: int         # total requests issued
    total_wall_time_s: float  # wall-clock duration of the entire run
    results: list[RequestResult] = field(default_factory=list)

    @property
    def successful(self) -> list[RequestResult]:
        """Results that completed without an error."""
        return [r for r in self.results if r.error is None]

    @property
    def throughput_tps(self) -> float:
        """System-wide throughput: completed output tokens divided by wall time."""
        total_out = sum(r.completion_tokens for r in self.successful)
        # 0.0 (not int 0) keeps the return type consistent with the annotation.
        return total_out / self.total_wall_time_s if self.total_wall_time_s > 0 else 0.0

    @property
    def mean_ttft_ms(self) -> float:
        """Mean time-to-first-token across successful requests (0.0 if none)."""
        ttfts = [r.ttft_ms for r in self.successful]
        return statistics.mean(ttfts) if ttfts else 0.0

    @property
    def p95_ttft_ms(self) -> float:
        """Approximate 95th-percentile TTFT via nearest-rank on the sorted sample."""
        ttfts = sorted(r.ttft_ms for r in self.successful)
        if not ttfts:
            return 0.0
        # Clamp so tiny samples still index a valid element.
        idx = min(int(len(ttfts) * 0.95), len(ttfts) - 1)
        return ttfts[idx]

    @property
    def total_completion_tokens(self) -> int:
        """Total output tokens generated across successful requests."""
        return sum(r.completion_tokens for r in self.successful)
async def run_one_request(
    client: httpx.AsyncClient,
    base_url: str,
    model: str,
    prompt: str,
    max_tokens: int,
) -> RequestResult:
    """Run a single non-streaming chat completion and measure timings.

    Without streaming the first token cannot be observed separately, so the
    full response time is reported as the TTFT approximation. Returns a
    RequestResult; transport and HTTP errors are captured in `error` rather
    than raised, so one failing request never aborts the benchmark.
    """
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "stream": False,
    }
    t0 = time.perf_counter()
    try:
        resp = await client.post(url, json=payload, timeout=120.0)
        total_time = time.perf_counter() - t0
        if resp.status_code != 200:
            return RequestResult(
                error=f"HTTP {resp.status_code}: {resp.text[:80]}",
                total_time_s=total_time,
            )
        # Only `usage` is needed from the body (the unused `choices` lookup
        # from the original version is dropped).
        usage = resp.json().get("usage", {})
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        # For non-streaming we don't get per-token TTFT; use total_time as proxy
        # (first token arrives with full response in non-streaming).
        ttft_ms = total_time * 1000
        gen_time = total_time
        # Decode rate excludes the first token, which belongs to prefill/TTFT.
        tps = (completion_tokens - 1) / gen_time if gen_time > 0 and completion_tokens > 1 else 0
        return RequestResult(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            ttft_ms=ttft_ms,
            total_time_s=total_time,
            tps=tps,
        )
    except Exception as e:
        # Best-effort benchmark worker: report the failure, never raise.
        return RequestResult(
            error=str(e),
            total_time_s=time.perf_counter() - t0,
        )
async def run_streaming_request(
    client: httpx.AsyncClient,
    base_url: str,
    model: str,
    prompt: str,
    max_tokens: int,
) -> RequestResult:
    """Run a streaming request to measure actual TTFT.

    Parses the SSE ("data: ...") stream line by line: TTFT is stamped on the
    first content delta, tokens are counted per delta, and server-reported
    `usage` (if present in any chunk) overrides the local counts. Errors are
    captured in RequestResult.error rather than raised.
    """
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "stream": True,
    }
    t0 = time.perf_counter()
    ttft = None              # ms; stamped when the first content delta arrives
    prompt_tokens = 0
    completion_tokens = 0
    last_content = ""        # accumulated text, used for a word-count fallback
    try:
        async with client.stream("POST", url, json=payload, timeout=120.0) as resp:
            if resp.status_code != 200:
                text = await resp.aread()
                return RequestResult(
                    error=f"HTTP {resp.status_code}: {text.decode()[:80]}",
                    total_time_s=time.perf_counter() - t0,
                )
            # Chunks may split SSE lines arbitrarily; buffer and re-split on "\n".
            buffer = ""
            async for chunk in resp.aiter_text():
                buffer += chunk
                while "\n" in buffer:
                    line, buffer = buffer.split("\n", 1)
                    line = line.strip()
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]
                    if data_str == "[DONE]":
                        # NOTE: breaks only the inner line loop; the outer
                        # stream loop ends when the server closes the stream.
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue  # skip malformed chunks
                    choices = data.get("choices", [])
                    if not choices:
                        continue
                    delta = choices[0].get("delta", {})
                    content = delta.get("content")
                    if content:
                        if ttft is None:
                            ttft = (time.perf_counter() - t0) * 1000
                        last_content += content
                        completion_tokens += 1  # one delta ~ one token
                    # Prefer server-reported usage over local counting.
                    usage = data.get("usage")
                    if usage:
                        if usage.get("prompt_tokens"):
                            prompt_tokens = usage["prompt_tokens"]
                        if usage.get("completion_tokens"):
                            completion_tokens = usage["completion_tokens"]
                    if choices[0].get("finish_reason"):
                        break
        total_time = time.perf_counter() - t0
        if ttft is None:
            # No content delta was ever seen; fall back to total time.
            ttft = total_time * 1000
        # Approximate completion tokens if not in usage
        if completion_tokens == 0 and last_content:
            completion_tokens = max(1, len(last_content.split()))
        # Decode time excludes TTFT (prefill); guard against ttft > total_time.
        gen_time = total_time - (ttft / 1000) if total_time > ttft / 1000 else total_time
        tps = (completion_tokens - 1) / gen_time if gen_time > 0 and completion_tokens > 1 else 0
        return RequestResult(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            ttft_ms=ttft,
            total_time_s=total_time,
            tps=tps,
        )
    except Exception as e:
        # Best-effort benchmark worker: report the failure, never raise.
        return RequestResult(
            error=str(e),
            total_time_s=time.perf_counter() - t0,
        )
class IncompatibleModelError(Exception):
    """Raised when the requested model lacks an MLX tag on HuggingFace."""
    pass
def open_mactop_window():
    """Open `mactop` in a new macOS Terminal window and bring it to front.

    Best effort: runs AppleScript via `osascript`; any failure (non-macOS
    host, osascript/mactop missing) is swallowed so the benchmark proceeds
    without the telemetry window.
    """
    import subprocess

    script = '''tell application "Terminal"
do script "mactop"
activate
end tell'''
    try:
        # An argument list avoids the shell entirely; the previous os.system
        # call interpolated the script into a shell command line.
        subprocess.run(
            ["osascript", "-e", script],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError:
        # osascript not found (e.g. not on macOS): keep going, like the old
        # os.system call which only returned a non-zero exit status.
        pass
def get_model_type(model_path: str) -> str:
    """Detect model_type from config.json (lm | multimodal | whisper | embeddings).

    Delegates to the project-local `detect_model_type` module, imported at
    call time (function scope) rather than at module import.
    """
    from detect_model_type import detect_model_type
    return detect_model_type(model_path)
async def _check_mlx_tag(client: httpx.AsyncClient, model_path: str) -> None:
    """Verify model has MLX tag on HuggingFace. Raises IncompatibleModelError if not.

    Best effort: local paths (no "/"), network errors, and non-200 API
    responses are silently ignored so offline runs can proceed.
    """
    if "/" not in model_path:
        # Not an org/name HF repo id (likely a local path) — nothing to check.
        return
    try:
        resp = await client.get(f"https://huggingface.co/api/models/{model_path}", timeout=6.0)
        if resp.status_code == 200:
            tags = [t.lower() for t in resp.json().get("tags", [])]
            if not any("mlx" in t for t in tags):
                raise IncompatibleModelError(
                    f"Model '{model_path}' does not have an MLX tag on HuggingFace. "
                    f"Only MLX-format models are compatible with the Bodega Inference Engine."
                )
    except IncompatibleModelError:
        # Our own signal must propagate; only infrastructure errors are swallowed.
        raise
    except Exception:
        # Offline / API hiccup: skip the check rather than block the benchmark.
        pass
def _confirm_multimodal(mtype: str) -> bool:
    """Warn about the multimodal sequential fallback and ask the user to confirm.

    Returns True when benchmarking should continue: non-multimodal models
    always continue; multimodal only on an explicit yes.
    """
    if mtype != "multimodal":
        return True
    print(" [!] Note: Continuous batching for 'multimodal' models is coming soon to Bodega.\n"
          " The engine currently falls back to sequential execution for vision models.", flush=True)
    choice = input(" Continue anyway? [y/N]: ")
    return choice.lower() in ['y', 'yes']


async def manage_model(client: httpx.AsyncClient, base_url: str, action: str, model_path: str, model_id: str) -> bool:
    """Helper to dynamically load/unload the model. Uses config.json for model_type — no retry.

    action: "load" or "unload"; any other value returns False. Returns True
    on success (an HTTP 409 on load counts as success: already loaded).
    """
    full_url = base_url.rstrip("/")
    if action == "load":
        print(f" [+] Loading model {model_path} into {model_id}...")
        try:
            await _check_mlx_tag(client, model_path)
        except IncompatibleModelError as e:
            print(f" [!] {e}")
            return False
        mtype = get_model_type(model_path)
        print(f" [->] Detected model_type from config.json: {mtype}")
        payload = {
            "model_path": model_path,
            "model_id": model_id,
            "model_type": mtype,
            "context_length": 8192,
            "continuous_batching": True,
            "cb_max_num_seqs": 128
        }
        resp = await client.post(f"{full_url}/v1/admin/load-model", json=payload, timeout=120.0)
        if resp.status_code == 409:
            # 409: a model with this id is already resident — treat as success.
            print(f" [✓] Model already loaded (as {mtype}). Continuing.")
            return _confirm_multimodal(mtype)
        if resp.status_code in [200, 201]:
            print(f" [✓] Loaded as {mtype}.")
            return _confirm_multimodal(mtype)
        print(f" [!] Load failed: {resp.status_code} {resp.text}")
        return False
    elif action == "unload":
        print(f" [-] Unloading model {model_id}...")
        resp = await client.delete(f"{full_url}/v1/admin/unload-model/{model_id}", timeout=30.0)
        if resp.status_code not in [200, 204]:
            print(f" [!] Unload failed: {resp.status_code} {resp.text}")
        return resp.status_code in [200, 204]
    # Unknown action: the old code fell through and implicitly returned None
    # despite the bool annotation; make the falsy result explicit.
    return False
async def benchmark_concurrency(
    base_url: str,
    model: str,
    concurrency: int,
    num_queries: int,
    max_tokens: int,
    stream: bool = True,
) -> ConcurrencyResult:
    """Issue num_queries requests while capping in-flight requests at `concurrency`.

    Prompts are cycled from PROMPTS until the requested count is reached.
    A semaphore bounds concurrency; each request gets its own HTTP client.
    """
    repeats = num_queries // len(PROMPTS) + 1
    prompts = (PROMPTS * repeats)[:num_queries]
    request_fn = run_streaming_request if stream else run_one_request
    gate = asyncio.Semaphore(concurrency)

    async def bounded(prompt: str) -> RequestResult:
        # Acquire a concurrency slot, then open a fresh client per request
        # (connections are not shared between requests).
        async with gate:
            async with httpx.AsyncClient() as client:
                return await request_fn(client, base_url, model, prompt, max_tokens)

    started = time.perf_counter()
    outcomes = await asyncio.gather(*(bounded(p) for p in prompts))
    elapsed = time.perf_counter() - started
    return ConcurrencyResult(
        concurrency=concurrency,
        num_requests=len(prompts),
        total_wall_time_s=elapsed,
        results=list(outcomes),
    )
def print_result(r: ConcurrencyResult) -> None:
    """Print a human-readable summary of one concurrency run to stdout."""
    ok = r.successful
    failed = len(r.results) - len(ok)
    print(f"\n Concurrency {r.concurrency} ({len(ok)}/{r.num_requests} succeeded, {failed} failed)")
    print(f" Wall time: {r.total_wall_time_s:.2f}s")
    print(f" Throughput: {r.throughput_tps:.1f} tok/s (system)")
    print(f" Mean TTFT: {r.mean_ttft_ms:.0f}ms")
    print(f" P95 TTFT: {r.p95_ttft_ms:.0f}ms")
    print(f" Total tokens: {r.total_completion_tokens:,}")
    if failed > 0:
        # Surface the first failure to aid debugging a bad run.
        first_error = next(res.error for res in r.results if res.error)
        # Plain print() — the previous rich-style [red]...[/red] tags were
        # printed literally because this script never uses a rich console.
        print(f" First error: {first_error}")
def print_comparison_table(results: list[ConcurrencyResult]) -> None:
    """Print a side-by-side table of all runs plus throughput scaling factors."""
    print("\n" + "=" * 70)
    print(" CONTINUOUS BATCHING BENCHMARK — Concurrency Comparison")
    print("=" * 70)
    print(f" {'Concurrency':<14} {'Wall Time':<12} {'Throughput':<14} {'Mean TTFT':<12} {'P95 TTFT':<10}")
    print("-" * 70)
    for r in results:
        print(f" {r.concurrency:<14} {r.total_wall_time_s:>9.2f}s {r.throughput_tps:>10.1f} tok/s "
              f"{r.mean_ttft_ms:>8.0f}ms {r.p95_ttft_ms:>6.0f}ms")
    print("=" * 70)
    # With continuous batching, throughput should grow as concurrency rises;
    # report each run's gain relative to the first (lowest-concurrency) run.
    # (The old dead `speedup` inverse ratio was computed but never used.)
    if len(results) >= 2:
        base = results[0]
        print("\n Scaling vs concurrency 1 (approximate):")
        for r in results[1:]:
            tp_gain = r.throughput_tps / base.throughput_tps if base.throughput_tps > 0 else 0
            print(f" Concurrency {r.concurrency}: {tp_gain:.2f}x throughput vs {base.concurrency}")
async def main() -> None:
    """CLI driver: parse args, check server health, load the model, run every
    requested concurrency level with the same query count, print the
    comparison table, then unload the model."""
    parser = argparse.ArgumentParser(
        description="HTTP concurrency benchmark for bodega_mlx_engine",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--base-url",
        default="http://localhost:44468",
        help="Server base URL (default: http://localhost:44468)",
    )
    parser.add_argument(
        "--model",
        default="srswti/bodega-orion-0.6b",
        help="Model ID (default: srswti/bodega-orion-0.6b)",
    )
    parser.add_argument(
        "--concurrencies",
        default="8,16,32",
        help="Comma-separated concurrency levels (default: 8,16,32)",
    )
    parser.add_argument(
        "--num-queries",
        type=int,
        default=32,
        help="Total number of queries per run (default: 32)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=128,
        help="Max tokens per request (default: 128)",
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Use non-streaming requests (TTFT less accurate)",
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        default=None,
        help="Single concurrency level override (overrides --concurrencies, used for multimodal sequential mode)",
    )
    args = parser.parse_args()
    # --concurrency (singular) takes precedence over the --concurrencies list.
    if args.concurrency is not None:
        concurrencies = [args.concurrency]
    else:
        concurrencies = [int(x.strip()) for x in args.concurrencies.split(",") if x.strip()]
    print("==" * 35)
    print(" bodega_mlx_engine — HTTP Concurrency Benchmark")
    print("==" * 35)
    print(f" Base URL: {args.base_url}")
    print(f" Model: {args.model}")
    print(f" Concurrency: {concurrencies}")
    print(f" Num queries: {args.num_queries} (same for all runs)")
    print(f" Max tokens: {args.max_tokens}")
    print(f" Streaming: {not args.no_stream}")
    print()
    print(" [Telemetry] Opening mactop in a new Terminal window...")
    open_mactop_window()
    print()
    # Health check — exit early when the server is unreachable; a non-200
    # response is only a warning, since /health may be unimplemented.
    try:
        async with httpx.AsyncClient() as client:
            r = await client.get(f"{args.base_url.rstrip('/')}/health", timeout=5.0)
            if r.status_code != 200:
                print("⚠ Health check returned non-200. Proceeding anyway.")
            else:
                print("✓ Server health OK")
    except Exception as e:
        print(f"✗ Cannot reach server at {args.base_url}: {e}")
        sys.exit(1)
    results: list[ConcurrencyResult] = []
    # Dynamically load the model via admin API
    async with httpx.AsyncClient() as client:
        success = await manage_model(client, args.base_url, "load", args.model, args.model)
        if not success:
            print("✗ Failed to load model. Exiting.")
            sys.exit(1)
    # Run each concurrency level with the same total query count so results
    # are directly comparable.
    for concurrency in concurrencies:
        print(f"\n--- Running concurrency {concurrency} ---")
        res = await benchmark_concurrency(
            base_url=args.base_url,
            model=args.model,
            concurrency=concurrency,
            num_queries=args.num_queries,
            max_tokens=args.max_tokens,
            stream=not args.no_stream,
        )
        results.append(res)
        print_result(res)
    print_comparison_table(results)
    # Unload after testing
    async with httpx.AsyncClient() as client:
        await manage_model(client, args.base_url, "unload", args.model, args.model)
# Script entry point: run the async benchmark driver.
if __name__ == "__main__":
    asyncio.run(main())