From 67ee43f67b9a74940d83be54ec89bcfd2554708e Mon Sep 17 00:00:00 2001 From: Cam Quilici Date: Mon, 20 Apr 2026 13:26:43 -0500 Subject: [PATCH 1/2] improve benchmark serving --- utils/bench_serving/backend_request_func.py | 619 ++++++++++---------- utils/bench_serving/benchmark_serving.py | 426 ++++++++++---- utils/bench_serving/benchmark_utils.py | 15 +- 3 files changed, 626 insertions(+), 434 deletions(-) diff --git a/utils/bench_serving/backend_request_func.py b/utils/bench_serving/backend_request_func.py index af030720e..3501f01d2 100644 --- a/utils/bench_serving/backend_request_func.py +++ b/utils/bench_serving/backend_request_func.py @@ -10,7 +10,6 @@ import aiohttp import huggingface_hub.constants -from tqdm.asyncio import tqdm from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) @@ -48,381 +47,361 @@ class RequestFuncOutput: async def async_request_tgi( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + session: aiohttp.ClientSession, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: - params = { - "best_of": request_func_input.best_of, - "max_new_tokens": request_func_input.output_len, - "do_sample": True, - "temperature": 0.01, # TGI does not accept 0.0 temperature. - "top_p": 0.99, # TGI does not accept 1.0 top_p. - "truncate": request_func_input.prompt_len, - # TGI does not accept ignore_eos flag. - } - payload = { - "inputs": request_func_input.prompt, - "parameters": params, - } - output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len - - ttft = 0.0 - st = time.perf_counter() - most_recent_timestamp = st - try: - async with session.post(url=api_url, json=payload) as response: - if response.status == 200: - async for chunk_bytes in response.content: - chunk_bytes = chunk_bytes.strip() - if not chunk_bytes: - continue - chunk_bytes = chunk_bytes.decode("utf-8") - - # NOTE: Sometimes TGI returns a ping response without - # any data, we should skip it. - if chunk_bytes.startswith(":"): - continue - chunk = chunk_bytes.removeprefix("data:") - - data = json.loads(chunk) - timestamp = time.perf_counter() - # First token - if ttft == 0.0: - ttft = time.perf_counter() - st - output.ttft = ttft - - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + params = { + "best_of": request_func_input.best_of, + "max_new_tokens": request_func_input.output_len, + "do_sample": True, + "temperature": 0.01, # TGI does not accept 0.0 temperature. + "top_p": 0.99, # TGI does not accept 1.0 top_p. + "truncate": request_func_input.prompt_len, + # TGI does not accept ignore_eos flag. + } + payload = { + "inputs": request_func_input.prompt, + "parameters": params, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk_bytes = chunk_bytes.decode("utf-8") + + # NOTE: Sometimes TGI returns a ping response without + # any data, we should skip it. + if chunk_bytes.startswith(":"): + continue + chunk = chunk_bytes.removeprefix("data:") + + data = json.loads(chunk) + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) - most_recent_timestamp = timestamp + most_recent_timestamp = timestamp - output.latency = most_recent_timestamp - st - output.success = True - output.generated_text = data["generated_text"] - else: - output.error = response.reason or "" - output.success = False - except Exception: - output.success = False - exc_info = sys.exc_info() - output.error = "".join(traceback.format_exception(*exc_info)) + output.latency = most_recent_timestamp - st + output.success = True + output.generated_text = data["generated_text"] + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) - if pbar: - pbar.update(1) - return output + return output async def async_request_trt_llm( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + session: aiohttp.ClientSession, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: - assert request_func_input.best_of == 1 - payload = { - "accumulate_tokens": True, - "text_input": request_func_input.prompt, - "temperature": 0.0, - "top_p": 1.0, - "max_tokens": request_func_input.output_len, - "stream": True, - } - if request_func_input.ignore_eos: - payload["min_length"] = request_func_input.output_len - output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len - - ttft = 0.0 - st = time.perf_counter() - most_recent_timestamp = st - try: - async with session.post(url=api_url, json=payload) as response: - if response.status == 200: - async for chunk_bytes in response.content: - chunk_bytes = chunk_bytes.strip() - if not chunk_bytes: - continue - - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data:") - - data = json.loads(chunk) - output.generated_text += data["text_output"] - timestamp = time.perf_counter() - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft - - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + assert request_func_input.best_of == 1 + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.0, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + } + if request_func_input.ignore_eos: + payload["min_length"] = request_func_input.output_len + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) - most_recent_timestamp = timestamp + most_recent_timestamp = timestamp - output.latency = most_recent_timestamp - st - output.success = True + output.latency = most_recent_timestamp - st + output.success = True - else: - output.error = response.reason or "" - output.success = False - except Exception: - output.success = False - exc_info = sys.exc_info() - output.error = "".join(traceback.format_exception(*exc_info)) + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) - if pbar: - pbar.update(1) - return output + return output async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + session: aiohttp.ClientSession, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: - assert request_func_input.best_of == 1 - - payload = { - "prompt": request_func_input.prompt, - "max_tokens": request_func_input.output_len, - "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. - "top_p": 1.0, - } - output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len - - # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024, - # will use 0 as placeholder. - # See https://github.com/microsoft/DeepSpeed-MII/pull/311 - output.ttft = 0 - - st = time.perf_counter() - try: - async with session.post(url=request_func_input.api_url, - json=payload) as response: - if response.status == 200: - parsed_resp = await response.json() - output.latency = time.perf_counter() - st - output.generated_text = parsed_resp["text"][0] - output.success = True - else: - output.error = response.reason or "" - output.success = False - except Exception: - output.success = False - exc_info = sys.exc_info() - output.error = "".join(traceback.format_exception(*exc_info)) + assert request_func_input.best_of == 1 + + payload = { + "prompt": request_func_input.prompt, + "max_tokens": request_func_input.output_len, + "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. + "top_p": 1.0, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024, + # will use 0 as placeholder. + # See https://github.com/microsoft/DeepSpeed-MII/pull/311 + output.ttft = 0 + + st = time.perf_counter() + try: + async with session.post(url=request_func_input.api_url, + json=payload) as response: + if response.status == 200: + parsed_resp = await response.json() + output.latency = time.perf_counter() - st + output.generated_text = parsed_resp["text"][0] + output.success = True + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) - if pbar: - pbar.update(1) - return output + return output async def async_request_openai_completions( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + session: aiohttp.ClientSession, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( ("completions", "profile") ), "OpenAI Completions API URL must end with 'completions' or 'profile'." - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: - payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, - "prompt": request_func_input.prompt, - "temperature": 0.0, - "best_of": request_func_input.best_of, - "max_tokens": request_func_input.output_len, - "logprobs": request_func_input.logprobs, - "stream": True, - "stream_options": { - "include_usage": True, - }, - } - if request_func_input.ignore_eos: - payload["ignore_eos"] = request_func_input.ignore_eos - if request_func_input.extra_body: - payload.update(request_func_input.extra_body) - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } - - output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len - - generated_text = "" - st = time.perf_counter() - most_recent_timestamp = st - try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: - if response.status == 200: - first_chunk_received = False - async for chunk_bytes in response.content: - chunk_bytes = chunk_bytes.strip() - if not chunk_bytes: - continue - - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") - if chunk != "[DONE]": - data = json.loads(chunk) - - # NOTE: Some completion API might have a last - # usage summary response without a token so we - # want to check a token was generated - if choices := data.get("choices"): - # Note that text could be empty here - # e.g. for special tokens - text = choices[0].get("text") - timestamp = time.perf_counter() - # First token - if not first_chunk_received: - first_chunk_received = True - ttft = time.perf_counter() - st - output.ttft = ttft - - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) - - most_recent_timestamp = timestamp - generated_text += text or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") - if first_chunk_received: - output.success = True - else: - output.success = False - output.error = ( - "Never received a valid chunk to calculate TTFT." - "This response will be marked as failed!") - output.generated_text = generated_text - output.latency = most_recent_timestamp - st + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": request_func_input.best_of, + "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True else: - output.error = response.reason or "" output.success = False - except Exception: - output.success = False - exc_info = sys.exc_info() - output.error = "".join(traceback.format_exception(*exc_info)) + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!") + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) - if pbar: - pbar.update(1) return output async def async_request_openai_chat_completions( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + session: aiohttp.ClientSession, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( "chat/completions" ), "OpenAI Chat Completions API URL must end with 'chat/completions'." - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: - content = [{"type": "text", "text": request_func_input.prompt}] - if request_func_input.multi_modal_content: - content.append(request_func_input.multi_modal_content) - payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, - "messages": [ - { - "role": "user", - "content": content - }, - ], - "temperature": 0.0, - "max_completion_tokens": request_func_input.output_len, - "stream": True, - "stream_options": { - "include_usage": True, + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "messages": [ + { + "role": "user", + "content": content }, - } - if request_func_input.ignore_eos: - payload["ignore_eos"] = request_func_input.ignore_eos - if request_func_input.extra_body: - payload.update(request_func_input.extra_body) - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - } - - output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len - - generated_text = "" - ttft = 0.0 - st = time.perf_counter() - most_recent_timestamp = st - try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: - if response.status == 200: - async for chunk_bytes in response.content: - chunk_bytes = chunk_bytes.strip() - if not chunk_bytes: - continue - - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + ], + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) - if choices := data.get("choices"): - content = choices[0]["delta"].get("content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") - most_recent_timestamp = timestamp + most_recent_timestamp = timestamp - output.generated_text = generated_text - output.success = True - output.latency = most_recent_timestamp - st - else: - output.error = response.reason or "" - output.success = False - except Exception: - output.success = False - exc_info = sys.exc_info() - output.error = "".join(traceback.format_exception(*exc_info)) + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) - if pbar: - pbar.update(1) return output diff --git a/utils/bench_serving/benchmark_serving.py b/utils/bench_serving/benchmark_serving.py index 38365dbfc..0a27d08b5 100644 --- a/utils/bench_serving/benchmark_serving.py +++ b/utils/bench_serving/benchmark_serving.py @@ -29,18 +29,22 @@ import gc import io import json +import math +import multiprocessing as mp import os import random import time +import traceback import warnings from dataclasses import dataclass from datetime import datetime from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple +import aiohttp import numpy as np -from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, - RequestFuncOutput) -from tqdm.asyncio import tqdm +from backend_request_func import (AIOHTTP_TIMEOUT, ASYNC_REQUEST_FUNCS, + RequestFuncInput, RequestFuncOutput) +from tqdm import tqdm from transformers import PreTrainedTokenizerBase try: @@ -53,7 +57,8 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from benchmark_utils import convert_to_pytorch_benchmark_format +from benchmark_utils import (convert_to_pytorch_benchmark_format, + shard_round_robin) MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -245,7 +250,10 @@ def calculate_metrics( tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids) actual_output_lens.append(output_len) - total_input += input_requests[i][1] + # Use outputs[i].prompt_len rather than input_requests[i][1] so + # metrics don't depend on output order matching input order — + # workers return outputs as they complete, not in dispatch order. + total_input += outputs[i].prompt_len tpot = 0 if output_len > 1: latency_minus_ttft = outputs[i].latency - outputs[i].ttft @@ -321,14 +329,153 @@ def calculate_metrics( return metrics, actual_output_lens -async def benchmark( +# Per-worker batch size for streaming RequestFuncOutput back to main via mp.Queue. +# Batching amortizes pickling/lock overhead so the queue isn't the bottleneck +# at high QPS (at ~10k req/s, per-request puts contend on the queue's lock). +_WORKER_QUEUE_BATCH_SIZE = 64 + + +def _build_client_session(connector_limit: int) -> aiohttp.ClientSession: + connector = aiohttp.TCPConnector( + limit=connector_limit, + limit_per_host=connector_limit, + keepalive_timeout=300, + enable_cleanup_closed=True, + ) + return aiohttp.ClientSession(connector=connector, + trust_env=True, + timeout=AIOHTTP_TIMEOUT) + + +async def _run_warmup(request_func, test_input: RequestFuncInput, + num_warmups: int, max_concurrency: Optional[int], + disable_tqdm: bool): + pbar = None if disable_tqdm else tqdm(total=num_warmups, desc="warmup") + sem = asyncio.Semaphore(max_concurrency) if max_concurrency else None + limit = max_concurrency or 256 + async with _build_client_session(connector_limit=limit) as session: + async def _one(): + if sem is None: + out = await request_func(request_func_input=test_input, + session=session) + else: + async with sem: + out = await request_func(request_func_input=test_input, + session=session) + if pbar is not None: + pbar.update(1) + return out + await asyncio.gather(*[_one() for _ in range(num_warmups)]) + if pbar is not None: + pbar.close() + + +async def _one_off_request(request_func, req_input: RequestFuncInput + ) -> RequestFuncOutput: + async with _build_client_session(connector_limit=4) as session: + return await request_func(request_func_input=req_input, session=session) + + +def _worker_entry(worker_index: int, shard: List[Tuple], config: Dict[str, Any], + barrier: Any, result_queue: Any, seed: int) -> None: + """Subprocess entrypoint. Runs an asyncio loop over a shard of requests.""" + try: + random.seed(seed + worker_index) + np.random.seed(seed + worker_index) + asyncio.run(_worker_async(shard, config, barrier, result_queue)) + except BaseException: + traceback.print_exc() + finally: + # Sentinel: main drains until it has received one sentinel per worker. + # Must run even on exception so main doesn't hang. + try: + result_queue.put(None) + except Exception: + pass + + +async def _worker_async(shard: List[Tuple], config: Dict[str, Any], + barrier: Any, result_queue: Any) -> None: + request_func = ASYNC_REQUEST_FUNCS[config["backend"]] + request_rate = config["rate_per_worker"] + burstiness = config["burstiness"] + sem_size = config["max_concurrency_per_worker"] + theta = (1.0 / (request_rate * burstiness) + if request_rate != float("inf") else 0.0) + + batch_buffer: List[RequestFuncOutput] = [] + + def flush_batch(force: bool = False) -> None: + if not batch_buffer: + return + if force or len(batch_buffer) >= _WORKER_QUEUE_BATCH_SIZE: + result_queue.put(batch_buffer.copy()) + batch_buffer.clear() + + async with _build_client_session( + connector_limit=config["connector_limit"]) as session: + # Barrier synchronizes worker start across processes so aggregate QPS + # is measured from a single wall clock. Blocks the event loop briefly + # (only at startup, once), which is fine. + barrier.wait() + + sem = asyncio.Semaphore(sem_size) if sem_size else None + in_flight: set = set() + + async def _fire(req_input: RequestFuncInput) -> None: + try: + out = await request_func(request_func_input=req_input, + session=session) + batch_buffer.append(out) + flush_batch() + finally: + if sem is not None: + sem.release() + + for prompt, prompt_len, output_len, mm_content, lora_module in shard: + # Semaphore on the DISPATCH side (not inside the task) bounds + # in-flight tasks. Acquiring here prevents task pileup when + # requests complete slower than they arrive. + if sem is not None: + await sem.acquire() + + if request_rate != float("inf"): + interval = np.random.gamma(shape=burstiness, scale=theta) + await asyncio.sleep(interval) + + model_id = lora_module or config["model_id"] + model_name = lora_module or config["model_name"] + + req_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=prompt, + api_url=config["api_url"], + prompt_len=prompt_len, + output_len=output_len, + logprobs=config["logprobs"], + best_of=config["best_of"], + multi_modal_content=mm_content, + ignore_eos=config["ignore_eos"], + ) + + task = asyncio.create_task(_fire(req_input)) + in_flight.add(task) + task.add_done_callback(in_flight.discard) + + if in_flight: + await asyncio.gather(*in_flight) + flush_batch(force=True) + + +def run_benchmark( backend: str, api_url: str, base_url: str, model_id: str, model_name: str, tokenizer: PreTrainedTokenizerBase, - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[str, int, int, Any]], logprobs: Optional[int], best_of: int, request_rate: float, @@ -342,17 +489,18 @@ async def benchmark( goodput_config_dict: Dict[str, float], max_concurrency: Optional[int], lora_modules: Optional[List[str]], + num_client_workers: int, + client_connector_limit: int, + seed: int, ): - if backend in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[backend] - else: + if backend not in ASYNC_REQUEST_FUNCS: raise ValueError(f"Unknown backend: {backend}") + request_func = ASYNC_REQUEST_FUNCS[backend] print("Starting initial single prompt test run...") test_prompt, test_prompt_len, test_output_len, test_mm_content = ( input_requests[0]) if backend != "openai-chat" and test_mm_content is not None: - # multi-modal benchmark is only available on OpenAI Chat backend. raise ValueError( "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( @@ -370,42 +518,27 @@ async def benchmark( if num_warmups > 0: print(f"Warming up with {num_warmups} requests...") - warmup_pbar = None if disable_tqdm else tqdm(total=num_warmups) - warmup_semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else contextlib.nullcontext() - - async def warmup_limited_req_fn(): - async with warmup_semaphore: - return await request_func(request_func_input=test_input, pbar=warmup_pbar) - - warmup_tasks = [] - for _ in range(num_warmups): - task = asyncio.create_task(warmup_limited_req_fn()) - warmup_tasks.append(task) - _ = await asyncio.gather(*warmup_tasks) - - if warmup_pbar is not None: - warmup_pbar.close() + asyncio.run( + _run_warmup(request_func, test_input, num_warmups, max_concurrency, + disable_tqdm)) print("Warmup completed.") - if lora_modules: - # For each input request, choose a LoRA module at random. - lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))]) - if profile: print("Starting profiler...") - profile_input = RequestFuncInput(model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - extra_body={"num_steps": 1, "merge_profiles": True, "profile_by_stage": True}, - logprobs=logprobs, - best_of=best_of, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos) - profile_output = await request_func(request_func_input=profile_input) + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_body={"num_steps": 1, "merge_profiles": True, "profile_by_stage": True}, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + ) + profile_output = asyncio.run(_one_off_request(request_func, profile_input)) if profile_output.success: print("Profiler started") @@ -417,54 +550,107 @@ async def warmup_limited_req_fn(): print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Maximum request concurrency: {max_concurrency}") + print(f"Client worker processes: {num_client_workers}") - pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + # Pre-resolve per-request LoRA module so workers don't need to share RNG. + if lora_modules: + lora_per_prompt = [random.choice(lora_modules) + for _ in range(len(input_requests))] + else: + lora_per_prompt = [None] * len(input_requests) - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. - # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + expanded: List[Tuple[str, int, int, Any, Optional[str]]] = [ + (prompt, prompt_len, output_len, mm_content, lora) + for (prompt, prompt_len, output_len, mm_content), lora + in zip(input_requests, lora_per_prompt) + ] - async def limited_request_func(request_func_input, pbar): - if semaphore is None: - return await request_func(request_func_input=request_func_input, - pbar=pbar) - async with semaphore: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + shards = shard_round_robin(expanded, num_client_workers) + + rate_per_worker = (request_rate / num_client_workers + if request_rate != float("inf") else float("inf")) + mc_per_worker = (math.ceil(max_concurrency / num_client_workers) + if max_concurrency else None) + # Auto-size the per-worker aiohttp connector: must be at least as large as + # the in-flight cap, with a floor so unconstrained runs still get pooling. + if client_connector_limit and client_connector_limit > 0: + connector_limit = client_connector_limit + else: + connector_limit = max(256, mc_per_worker or 256) + + config = { + "backend": backend, + "api_url": api_url, + "model_id": model_id, + "model_name": model_name, + "logprobs": logprobs, + "best_of": best_of, + "ignore_eos": ignore_eos, + "rate_per_worker": rate_per_worker, + "burstiness": burstiness, + "max_concurrency_per_worker": mc_per_worker, + "connector_limit": connector_limit, + } + + ctx = mp.get_context("spawn") + # +1 for main so workers don't start timing the run before main's pbar + # and stopwatch are ready. + barrier = ctx.Barrier(num_client_workers + 1) + # Bounded queue provides back-pressure: if main can't drain fast enough, + # workers block on put rather than buffering unbounded results in memory. + result_queue = ctx.Queue(maxsize=num_client_workers * 32) + + processes = [] + for i, shard in enumerate(shards): + p = ctx.Process( + target=_worker_entry, + args=(i, shard, config, barrier, result_queue, seed), + ) + p.start() + processes.append(p) print("Starting main benchmark run...") + total = len(expanded) + pbar = None if disable_tqdm else tqdm(total=total, desc="bench") + # Wait on the barrier too so our stopwatch starts in sync with workers. + barrier.wait() benchmark_start_time = time.perf_counter() - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate, burstiness): - prompt, prompt_len, output_len, mm_content = request - req_model_id, req_model_name = model_id, model_name - if lora_modules: - req_lora_module = next(lora_modules) - req_model_id, req_model_name = req_lora_module, req_lora_module - - request_func_input = RequestFuncInput(model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - best_of=best_of, - multi_modal_content=mm_content, - ignore_eos=ignore_eos) - tasks.append( - asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - pbar=pbar))) - outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + outputs: List[RequestFuncOutput] = [] + sentinels = 0 + while sentinels < num_client_workers: + item = result_queue.get() + if item is None: + sentinels += 1 + continue + outputs.extend(item) + if pbar is not None: + pbar.update(len(item)) + + benchmark_duration = time.perf_counter() - benchmark_start_time + + if pbar is not None: + pbar.close() + + for p in processes: + p.join(timeout=30) + if p.is_alive(): + warnings.warn( + f"Worker {p.pid} did not exit cleanly; terminating.", + stacklevel=2) + p.terminate() + + for p in processes: + if p.exitcode not in (0, None): + warnings.warn( + f"Worker {p.pid} exited with code {p.exitcode}; " + "aggregated results may be incomplete.", + stacklevel=2) if profile: print("Stopping profiler...") - profile_input = RequestFuncInput( + stop_profile_input = RequestFuncInput( model=model_id, prompt=test_prompt, api_url=base_url + "/stop_profile", @@ -473,15 +659,10 @@ async def limited_request_func(request_func_input, pbar): logprobs=logprobs, best_of=best_of, ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: + stop_output = asyncio.run(_one_off_request(request_func, stop_profile_input)) + if stop_output.success: print("Profiler stopped") - if pbar is not None: - pbar.close() - - benchmark_duration = time.perf_counter() - benchmark_start_time - metrics, actual_output_lens = calculate_metrics( input_requests=input_requests, outputs=outputs, @@ -674,31 +855,33 @@ def main(args: argparse.Namespace): gc.collect() gc.freeze() - benchmark_result = asyncio.run( - benchmark( - backend=backend, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - best_of=args.best_of, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - num_warmups=args.num_warmups, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - )) + benchmark_result = run_benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + best_of=args.best_of, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + num_warmups=args.num_warmups, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + num_client_workers=args.num_client_workers, + client_connector_limit=args.client_connector_limit, + seed=args.seed, + ) # Save config and results to json if args.save_result: @@ -1068,5 +1251,22 @@ def main(args: argparse.Namespace): parser.add_argument('--num-warmups', type=int, default=0) + parser.add_argument( + "--num-client-workers", + type=int, + default=min(os.cpu_count() or 1, 8), + help="Number of client worker processes. Each runs its own asyncio " + "loop and one shared aiohttp session. Defaults to min(cpu_count, 8). " + "Raise to drive higher QPS; single-process Python maxes out around " + "a few hundred QPS due to GIL/event-loop contention.") + + parser.add_argument( + "--client-connector-limit", + type=int, + default=0, + help="Per-worker aiohttp TCPConnector limit. 0 = auto (max(256, " + "max_concurrency/num_workers)). Raise if the client runs out of " + "sockets before the server is saturated.") + args = parser.parse_args() main(args) diff --git a/utils/bench_serving/benchmark_utils.py b/utils/bench_serving/benchmark_utils.py index dc6d31f6f..2ac4bc0b7 100644 --- a/utils/bench_serving/benchmark_utils.py +++ b/utils/bench_serving/benchmark_utils.py @@ -2,7 +2,20 @@ import argparse import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Sequence, TypeVar + +T = TypeVar("T") + + +def shard_round_robin(items: Sequence[T], n: int) -> List[List[T]]: + """Split `items` into `n` shards using round-robin assignment. + + Round-robin (vs. contiguous chunks) spreads short/long prompts evenly + across worker shards so per-worker load is balanced. + """ + if n <= 0: + raise ValueError(f"n must be positive, got {n}") + return [list(items[i::n]) for i in range(n)] def convert_to_pytorch_benchmark_format(args: argparse.Namespace, From ae1296d635d4bd7a84fb40dd716374a54860cbae Mon Sep 17 00:00:00 2001 From: Cam Quilici Date: Mon, 20 Apr 2026 13:30:57 -0500 Subject: [PATCH 2/2] improve benchmark serving pt 2 --- utils/bench_serving/benchmark_serving.py | 131 +++++------------------ 1 file changed, 28 insertions(+), 103 deletions(-) diff --git a/utils/bench_serving/benchmark_serving.py b/utils/bench_serving/benchmark_serving.py index 0a27d08b5..8434ce4e1 100644 --- a/utils/bench_serving/benchmark_serving.py +++ b/utils/bench_serving/benchmark_serving.py @@ -14,8 +14,7 @@ python benchmarks/benchmark_serving.py \ --backend \ --model \ - --dataset-name sharegpt \ - --dataset-path \ + --dataset-name random \ --request-rate \ # By default is inf --num-prompts # By default is 1000 @@ -432,7 +431,7 @@ async def _fire(req_input: RequestFuncInput) -> None: if sem is not None: sem.release() - for prompt, prompt_len, output_len, mm_content, lora_module in shard: + for prompt, prompt_len, output_len, mm_content in shard: # Semaphore on the DISPATCH side (not inside the task) bounds # in-flight tasks. Acquiring here prevents task pileup when # requests complete slower than they arrive. @@ -443,12 +442,9 @@ async def _fire(req_input: RequestFuncInput) -> None: interval = np.random.gamma(shape=burstiness, scale=theta) await asyncio.sleep(interval) - model_id = lora_module or config["model_id"] - model_name = lora_module or config["model_name"] - req_input = RequestFuncInput( - model=model_id, - model_name=model_name, + model=config["model_id"], + model_name=config["model_name"], prompt=prompt, api_url=config["api_url"], prompt_len=prompt_len, @@ -488,7 +484,6 @@ def run_benchmark( ignore_eos: bool, goodput_config_dict: Dict[str, float], max_concurrency: Optional[int], - lora_modules: Optional[List[str]], num_client_workers: int, client_connector_limit: int, seed: int, @@ -552,20 +547,7 @@ def run_benchmark( print(f"Maximum request concurrency: {max_concurrency}") print(f"Client worker processes: {num_client_workers}") - # Pre-resolve per-request LoRA module so workers don't need to share RNG. - if lora_modules: - lora_per_prompt = [random.choice(lora_modules) - for _ in range(len(input_requests))] - else: - lora_per_prompt = [None] * len(input_requests) - - expanded: List[Tuple[str, int, int, Any, Optional[str]]] = [ - (prompt, prompt_len, output_len, mm_content, lora) - for (prompt, prompt_len, output_len, mm_content), lora - in zip(input_requests, lora_per_prompt) - ] - - shards = shard_round_robin(expanded, num_client_workers) + shards = shard_round_robin(input_requests, num_client_workers) rate_per_worker = (request_rate / num_client_workers if request_rate != float("inf") else float("inf")) @@ -610,7 +592,7 @@ def run_benchmark( processes.append(p) print("Starting main benchmark run...") - total = len(expanded) + total = len(input_requests) pbar = None if disable_tqdm else tqdm(total=total, desc="bench") # Wait on the barrier too so our stopwatch starts in sync with workers. @@ -835,20 +817,19 @@ def main(args: argparse.Namespace): trust_remote_code=args.trust_remote_code) - if args.dataset_name == "random": - input_requests = sample_random_requests( - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - num_prompts=args.num_prompts, - range_ratio=args.random_range_ratio, - tokenizer=tokenizer, - use_chat_template=args.use_chat_template, - ) - - else: + if args.dataset_name != "random": raise ValueError(f"Unknown dataset: {args.dataset_name}") + input_requests = sample_random_requests( + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + use_chat_template=args.use_chat_template, + ) + goodput_config_dict = check_goodput_args(args) # Avoid GC processing "static" data - reduce pause times. @@ -877,7 +858,6 @@ def main(args: argparse.Namespace): ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, num_client_workers=args.num_client_workers, client_connector_limit=args.client_connector_limit, seed=args.seed, @@ -972,15 +952,10 @@ def main(args: argparse.Namespace): parser.add_argument( "--dataset-name", type=str, - default="sharegpt", + default="random", choices=["random"], help="Name of the dataset to benchmark on.", ) - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.") parser.add_argument( "--max-concurrency", type=int, @@ -1139,38 +1114,6 @@ def main(args: argparse.Namespace): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve") - # group for dataset specific arguments - sonnet_group = parser.add_argument_group("sonnet dataset options") - sonnet_group.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - - sharegpt_group = parser.add_argument_group("sharegpt dataset options") - sharegpt_group.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.") - random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", @@ -1207,23 +1150,6 @@ def main(args: argparse.Namespace): help="Use chat template to format the prompt.", ) - hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - hf_group.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") - hf_group.add_argument( - "--hf-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output lengths " - "from the sampled HF dataset.", - ) - parser.add_argument( '--tokenizer-mode', type=str, @@ -1242,23 +1168,22 @@ def main(args: argparse.Namespace): "If not specified, the model name will be the " "same as the ``--model`` argument. ") - parser.add_argument("--lora-modules", - nargs='+', - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.") - parser.add_argument('--num-warmups', type=int, default=0) + # Cap the auto-detected default so a 128-vCPU host doesn't spawn 128 + # workers by accident. Override via BENCH_CLIENT_WORKERS_CAP env var when + # you actually want more (or fewer) by default. + _default_workers_cap = int(os.environ.get("BENCH_CLIENT_WORKERS_CAP", "8")) parser.add_argument( "--num-client-workers", type=int, - default=min(os.cpu_count() or 1, 8), + default=min(os.cpu_count() or 1, _default_workers_cap), help="Number of client worker processes. Each runs its own asyncio " - "loop and one shared aiohttp session. Defaults to min(cpu_count, 8). " - "Raise to drive higher QPS; single-process Python maxes out around " - "a few hundred QPS due to GIL/event-loop contention.") + f"loop and one shared aiohttp session. Defaults to min(cpu_count, " + f"{_default_workers_cap}) — the cap is set via " + "BENCH_CLIENT_WORKERS_CAP (default 8). Raise to drive higher QPS; " + "single-process Python maxes out around a few hundred QPS due to " + "GIL/event-loop contention.") parser.add_argument( "--client-connector-limit",