diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 77d07df4..96f8a987 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -196,4 +196,22 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) { this->compile(); } +std::vector> InferEngine::get_kv_cache() { + std::vector> kv_cache_list; + if (workers_.empty()) { + throw std::runtime_error("InferEngine::get_cache_vec: no workers"); + } + + kv_cache_list.reserve(workers_.size()); + for (auto &worker : workers_) { + kv_cache_list.push_back(std::move(worker->get_kv_cache())); + } + + for (auto &worker : workers_) { + worker->wait(); + } + + return kv_cache_list; +} + } // namespace infinilm::engine diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 70c3c164..a6ef289f 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -49,6 +49,8 @@ class InferEngine { void reset_cache(const cache::CacheConfig *new_config); + std::vector> get_kv_cache(); + ~InferEngine(); const distributed::DistConfig &get_dist_config() const; diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 1fa34e12..111185a4 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -208,6 +208,22 @@ void RankWorker::reset_cache(const cache::CacheConfig *new_config) { cv_.notify_all(); } +//------------------------------------------------------ +// get kv cache +//------------------------------------------------------ +std::vector RankWorker::get_kv_cache() { + std::unique_lock lk(mutex_); + cv_.wait(lk, [&] { return init_done_ || should_exit_; }); + + if (should_exit_) { + throw std::runtime_error("RankWorker stopped; cannot get_cache_vec"); + } + + ASSERT(forward_context_.kv_cache_vec.size() > 0 && "RankWorker::get_kv_cache(): kv_cache_vec is empty"); + + return forward_context_.kv_cache_vec; +} + //------------------------------------------------------ // close -- request shutdown and join thread //------------------------------------------------------ diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index c214536f..48e34681 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -98,6 +98,8 @@ class RankWorker { // Reset the internal cache with a new configuration void reset_cache(const cache::CacheConfig *new_config); + std::vector get_kv_cache(); + // Compile the model graph if enabled. void compile(); diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 6022f25e..74809c3f 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -102,12 +102,17 @@ inline void bind_infer_engine(py::module &m) { .def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)") .def( "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { + // IMPORTANT: Release the GIL before calling forward() to allow other Python threads + // to run concurrently during inference (which may block for a long time). + // Do NOT remove this — without it, the GIL is held throughout inference and will + // deadlock or stall any other Python thread (e.g., request handling, scheduling). py::gil_scoped_release release; return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") .def( "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("get_kv_cache", &InferEngine::get_kv_cache, "Get per-rank kv cache list") .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index 018e17a4..c07beff6 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -99,6 +99,8 @@ def __init__(self): self.port = self.args.port self.endpoint = self.args.endpoint self.ignore_eos = self.args.ignore_eos + # PD separation (KV transfer) + self.kv_transfer_config = self.args.kv_transfer_config # Multimodal parameters self.image = self.args.image @@ -268,6 +270,19 @@ def _add_common_args(self): help="image path for multimodal models", ) + # ---- PD separation arguments ---- + self.parser.add_argument( + "--kv-transfer-config", + type=str, + default=None, + help=( + "JSON object for KVTransferConfig. Allowed keys only: " + "kv_connector, engine_id, kv_role, kv_connector_extra_config (omit any for defaults). " + "Example: " + '\'{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}\'' + ), + ) + def get_device_str(self, device): """Convert device name to backend string (cuda/cpu/musa/mlu)""" DEVICE_STR_MAP = { diff --git a/python/infinilm/config/__init__.py b/python/infinilm/config/__init__.py new file mode 100644 index 00000000..bb53bfd5 --- /dev/null +++ b/python/infinilm/config/__init__.py @@ -0,0 +1,4 @@ +from .engine_config import EngineConfig +from .kv_transfer import KVTransferConfig + +__all__ = ["EngineConfig", "KVTransferConfig"] diff --git a/python/infinilm/config/engine_config.py b/python/infinilm/config/engine_config.py new file mode 100644 index 00000000..5799b5b0 --- /dev/null +++ b/python/infinilm/config/engine_config.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass +from typing import Optional +from infinilm.config.kv_transfer import KVTransferConfig + + +@dataclass +class EngineConfig: + """Configuration for LLM Engine. + + Attributes: + model_path: Path to the model directory. + device: Device type string ('cpu', 'cuda', 'mlu', etc.). + dtype: Data type string ('float16', 'bfloat16', 'float32'). + tensor_parallel_size: Number of devices for tensor parallelism. + cache_type: Cache type ('paged' or 'static'). + max_batch_size: Maximum batch size for inference (only for paged cache). + max_tokens: Default maximum tokens to generate. + num_blocks: Number of KV cache blocks (only for paged cache). + block_size: Size of each KV cache block (only for paged cache). + max_cache_len: Maximum sequence length (only for static cache). + temperature: Default sampling temperature. + top_p: Default top-p sampling parameter. + top_k: Default top-k sampling parameter. + enable_graph: Whether to enable graph compiling. + attn_backend: Attention backend to use ('default', 'flash-attn'). + skip_load: Whether to skip loading model weights (for testing). + """ + + model_path: str + device: str = "cuda" + dtype: str = "float16" + tensor_parallel_size: int = 1 + cache_type: str = "paged" # "paged" or "static" + max_batch_size: int = 16 + max_tokens: int = 4096 + num_blocks: int = 512 + block_size: int = 256 + max_cache_len: int = 4096 + temperature: float = 1.0 + top_p: float = 0.8 + top_k: int = 1 + enable_graph: bool = False + attn_backend: str = "default" + skip_load: bool = False + kv_transfer_config: Optional[KVTransferConfig] = None + + def __post_init__(self) -> None: + if ( + self.kv_transfer_config is not None + and self.kv_transfer_config.kv_connector + and self.cache_type != "paged" + ): + raise ValueError("kv_transfer_config requires cache_type='paged'") diff --git a/python/infinilm/config/kv_transfer.py b/python/infinilm/config/kv_transfer.py new file mode 100644 index 00000000..6b222754 --- /dev/null +++ b/python/infinilm/config/kv_transfer.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright 2026 InfiniLM Contributors + +import uuid +from dataclasses import dataclass, field +from typing import Optional +import os + +KV_ROLE_CHOICES = frozenset({"kv_producer", "kv_consumer"}) + + +@dataclass +class KVTransferConfig: + """Configuration for KV cache transfer in prefill/decode (P/D) separation. + + Attributes: + kv_connector: Name of the KV connector to use (e.g. 'MooncakeConnector'). + None disables KV transfer. + kv_role: Role of this node: 'kv_producer' (prefill) or 'kv_consumer' (decode). + engine_id: Unique identifier for this engine instance used in KV transfers. + Auto-generated (UUID) if not provided. + kv_connector_extra_config: Extra configuration dict passed to the connector backend. + """ + + kv_connector: Optional[str] = None + kv_role: Optional[str] = None + engine_id: Optional[str] = None + kv_connector_extra_config: Optional[dict] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_role when kv_connector is set.") + + if self.kv_role is not None and self.kv_role not in KV_ROLE_CHOICES: + raise ValueError( + f"Unsupported kv_role: {self.kv_role!r}. " + f"Supported roles are {sorted(KV_ROLE_CHOICES)}" + ) + + if self.engine_id is None: + self.engine_id = f"{self.kv_role}_" + str(uuid.uuid4()) + + self.kv_connector_extra_config = dict(self.kv_connector_extra_config or {}) + self.kv_connector_extra_config.setdefault("mooncake_protocol", "rdma") + + allowed_extra_config_keys = frozenset({"mooncake_protocol", "num_workers"}) + unknown_keys = set(self.kv_connector_extra_config.keys()) - allowed_extra_config_keys + if unknown_keys: + raise ValueError( + f"Unsupported kv_connector_extra_config keys: {sorted(unknown_keys)}. " + f"Supported keys are {sorted(allowed_extra_config_keys)}" + ) + + mooncake_protocol = self.kv_connector_extra_config["mooncake_protocol"] + if mooncake_protocol not in ["tcp", "rdma"]: + raise ValueError(f"only support tcp or rdma, but got {mooncake_protocol}") + + if mooncake_protocol == "tcp": + # NOTE: force use tcp for Mooncake + os.environ["MC_FORCE_TCP"] = "true" diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 75890b0f..ea44d2f0 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -389,3 +389,18 @@ def load_state_dict(self, state_dict, strict=None): def process_weights_after_loading(self): super().process_weights_after_loading() + + def get_kv_cache(self) -> list[list[infinicore.Tensor]]: + """ + get per-rank kv cache. + """ + kv_cache_list = super().get_kv_cache() + infinicore.sync_device() + + result = [] + for rank_idx, kv_caches_per_rank in enumerate(kv_cache_list): + result_rank = [] + for layer_idx, layer_kv in enumerate(kv_caches_per_rank): + result_rank.append(infinicore.Tensor(layer_kv)) + result.append(result_rank) + return result diff --git a/python/infinilm/kv_connector/__init__.py b/python/infinilm/kv_connector/__init__.py new file mode 100644 index 00000000..b5b77bb4 --- /dev/null +++ b/python/infinilm/kv_connector/__init__.py @@ -0,0 +1,36 @@ +""" +KV connector package. + +This module: +- Exposes core KV connector abstractions (base, role, metadata) +- Provides the KVConnectorFactory +- Registers built-in connectors (e.g. MooncakeConnector) + +Note: +Importing this module will trigger connector registration. +""" + +from infinilm.kv_connector.base import ( + KVConnectorBase, + KVConnectorRole, + KVConnectorMetadata, + KVConnectorHandshakeMetadata, + KVConnectorWorkerMetadata, +) +from infinilm.kv_connector.factory import KVConnectorFactory + +KVConnectorFactory.register_connector( + "MooncakeConnector", + "infinilm.kv_connector.mooncake.mooncake_connector", + "MooncakeConnector", +) + + +__all__ = [ + "KVConnectorBase", + "KVConnectorRole", + "KVConnectorMetadata", + "KVConnectorHandshakeMetadata", + "KVConnectorWorkerMetadata", + "KVConnectorFactory", +] diff --git a/python/infinilm/kv_connector/base.py b/python/infinilm/kv_connector/base.py new file mode 100644 index 00000000..6126e252 --- /dev/null +++ b/python/infinilm/kv_connector/base.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright 2026 InfiniLM Contributors + +""" +KVConnector abstract base class. + +All KV transfer backends (e.g. LMCache, Mooncake, NIXL) must subclass this. +The scheduler invokes connector hook points; each concrete implementation +determines the transfer behaviour. +""" + +import enum +import logging + +from abc import ABC, abstractmethod +from typing import Any, Optional + +from infinilm.llm.request import InferenceRequest +from infinilm.config.kv_transfer import KVTransferConfig +import infinicore + +logger = logging.getLogger(__name__) + + +class KVConnectorRole(enum.Enum): + SCHEDULER = 0 + WORKER = 1 + + +class KVConnectorHandshakeMetadata: + """ + Metadata used for out of band connector handshake between + P/D workers. This needs to serializable. + """ + + +class KVConnectorMetadata: + """ + Abstract Metadata used to communicate + Scheduler KVConnector -> Worker KVConnector. + """ + + +class KVConnectorWorkerMetadata(ABC): + """ + Abstract Metadata used to communicate back + Worker KVConnector -> Scheduler KVConnector. + + Each worker can output its own metadata. + For a single engine step, all metadata objects returned by workers + will be aggregated using the `aggregate` method below, before + being passed to the Scheduler KVConnector. + """ + + @abstractmethod + def aggregate( + self, other: "KVConnectorWorkerMetadata" + ) -> "KVConnectorWorkerMetadata": + """ + Aggregate metadata with another `KVConnectorWorkerMetadata` object. + """ + pass + + +class KVConnectorBase(ABC): + """ + Base class for KV connectors. + """ + + @property + def prefer_cross_layer_blocks(self) -> bool: + """ + Indicates whether this connector prefers KV blocks that hold KV data for all + layers, which can speed up KV data transfers. Defaults to False. + """ + return False + + def __init__( + self, + role: KVConnectorRole, + kv_transfer_config: KVTransferConfig | None = None, + ): + """ + Args: + role: The role of the connector, either SCHEDULER or WORKER. + kv_transfer_config: KV transfer configuration containing kv_role and connector_config. + """ + self._connector_metadata: KVConnectorMetadata | None = None + cfg = kv_transfer_config or KVTransferConfig() + self._role = role + self._kv_transfer_config = cfg + + @property + def role(self) -> KVConnectorRole: + return self._role + + @property + def kv_transfer_config(self) -> KVTransferConfig: + """PD/KV transfer options (not EngineConfig—use this name, not ``.config``).""" + return self._kv_transfer_config + + # ============================== + # Scheduler-side methods + # ============================== + + @abstractmethod + def get_num_new_matched_tokens( + self, + request: InferenceRequest, + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (InferenceRequest): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - An optional number of tokens that can be loaded from the + external KV cache beyond what is already computed. + If None, it means that the connector needs more time to + determine the number of matched tokens, and the scheduler + should query for this request again later. + - `True` if external KV cache tokens will be loaded + asynchronously (between scheduler steps). Must be + 'False' if the first element is 0. + + Notes: + The connector should only consider the largest prefix of prompt- + tokens for which KV cache is actually available at the time of the + call. If the cache cannot be loaded for some tokens (e.g., due to + connectivity issues or eviction), those tokens must not be taken + into account. + """ + pass + + @abstractmethod + def update_state_after_alloc( + self, + request: InferenceRequest, + block_ids: list[int], + num_external_tokens: int, + block_size: Optional[int] = None, + ) -> None: + """ + Update KVConnector state after block allocation. + + If get_num_new_matched_tokens previously returned True for a + request, this function may be called twice for that same request - + first when blocks are allocated for the connector tokens to be + asynchronously loaded into, and second when any additional blocks + are allocated, after the load/transfer is complete. + + Args: + request (InferenceRequest): the request object. + block_ids (list[int]): the block IDs allocated for the request. + num_external_tokens (int): the number of tokens that will be + loaded from the external KV cache. + block_size (Optional[int]): the size of each block. This is used + to calculate the number of blocks needed for the external tokens. + """ + pass + + @abstractmethod + def build_connector_meta(self) -> KVConnectorMetadata: + """Build the connector metadata for this step.""" + pass + + @abstractmethod + def request_finished( + self, + request: InferenceRequest, + block_ids: list[int], + block_size: Optional[int] = None, + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called exactly once when a request has finished, before its blocks are + freed. + + The connector may assumes responsibility for freeing the blocks + asynchronously by returning True. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + pass + + # ============================== + # Worker-side methods + # ============================== + + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. + + Args: + connector_metadata: the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self._connector_metadata = None + + @abstractmethod + def register_kv_caches(self, kv_caches: dict[str, infinicore.Tensor]) -> None: + """Register KV cache tensors of the connector. + + Args: + kv_caches: Mapping from layer name to KV cache tensor. + """ + pass + + @abstractmethod + def start_load_kv(self, **kwargs: Any) -> None: + """ + Start loading KV cache from the connector to the KV buffer. + This is called before the forward pass to enable async loading + during model execution. + + Args: + **kwargs: additional arguments for the load operation + """ + pass + + def get_block_ids_with_load_errors(self) -> set[int]: + """ + Get the set of block IDs that failed to load. + + """ + return set() + + def get_kv_connector_stats(self): + """ + Get the KV connector stats collected during the last interval. + """ + return None + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None, set[str] | None]: + return None, None, None + + def shutdown(self): + """ + Shutdown the connector. + """ + return None diff --git a/python/infinilm/kv_connector/factory.py b/python/infinilm/kv_connector/factory.py new file mode 100644 index 00000000..96a18160 --- /dev/null +++ b/python/infinilm/kv_connector/factory.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright 2026 InfiniLM Contributors + +""" +KVConnectorFactory — lazy-loading registry for KV connectors. +""" + +import importlib +import logging +from collections.abc import Callable + +from infinilm.kv_connector.base import KVConnectorBase, KVConnectorRole +from infinilm.config.kv_transfer import KVTransferConfig + +logger = logging.getLogger(__name__) + + +class KVConnectorFactory: + """Registry and factory for KV connectors (lazy-loaded).""" + + _registry: dict[str, Callable[..., KVConnectorBase]] = {} + + @classmethod + def register_connector( + cls, + name: str, + module_path: str, + class_name: str, + ) -> None: + """Register a KV connector backend.""" + if name in cls._registry: + logger.warning(f"KVConnector '{name}' already registered, overwriting") + + def _lazy_loader(**kwargs) -> KVConnectorBase: + module = importlib.import_module(module_path) + connector_cls = getattr(module, class_name) + return connector_cls(**kwargs) + + cls._registry[name] = _lazy_loader + logger.debug(f"Registered KVConnector: {name} -> {module_path}.{class_name}") + + @classmethod + def create_connector( + cls, + connector_name: str, + role: KVConnectorRole, + kv_transfer_config: KVTransferConfig | None = None, + ) -> KVConnectorBase: + """Create a registered connector.""" + if connector_name not in cls._registry: + raise ValueError(f"Unknown KVConnector: '{connector_name}'.") + + cfg = kv_transfer_config or KVTransferConfig() + loader = cls._registry[connector_name] + connector = loader( + role=role, + kv_transfer_config=cfg, + ) + logger.info( + f"Created KVConnector: {connector_name} " + f"(role={role.name}, kv_role={cfg.kv_role})" + ) + return connector + + @classmethod + def get_available_connectors(cls) -> list: + return list(cls._registry.keys()) diff --git a/python/infinilm/kv_connector/mooncake/mooncake_connector.py b/python/infinilm/kv_connector/mooncake/mooncake_connector.py new file mode 100644 index 00000000..167c599a --- /dev/null +++ b/python/infinilm/kv_connector/mooncake/mooncake_connector.py @@ -0,0 +1,164 @@ +import logging +from dataclasses import dataclass +from collections import defaultdict +from typing import Any, Optional + +from infinilm.kv_connector import ( + KVConnectorBase, + KVConnectorMetadata, + KVConnectorRole, +) +from infinilm.config.kv_transfer import KVTransferConfig + +import infinicore + +from infinilm.llm import InferenceRequest + +logger = logging.getLogger(__name__) + +ReqId = str +TransferId = str +EngineId = str +WorkerAddr = str + + +@dataclass +class PullReqMeta: + d_req_id: ReqId + transfer_id: TransferId + local_block_ids: list[int] + remote_engine_id: str + remote_bootstrap_addr: str + expire_time: float = float("inf") # not used + pull_tasks_count: int = 0 + + +class MooncakeConnectorMetadata(KVConnectorMetadata): + def __init__(self): + self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict) + self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {} + self.reqs_not_processed: set[TransferId] = set() + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, str], + load_remote_cache: bool = True, + ): + transfer_id = kv_transfer_params["transfer_id"] + if load_remote_cache: + remote_engine_id = kv_transfer_params["remote_engine_id"] + self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta( + d_req_id=request_id, + local_block_ids=local_block_ids, + remote_engine_id=remote_engine_id, + remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"], + transfer_id=transfer_id, + ) + else: + self.reqs_to_send[request_id] = (transfer_id, local_block_ids) + + def __str__(self) -> str: + return ( + f"MooncakeConnectorMetadata(reqs_to_recv={dict(self.reqs_to_recv)}, " + f"reqs_to_send={self.reqs_to_send}, " + f"reqs_not_processed={self.reqs_not_processed})" + ) + + +class MooncakeConnector(KVConnectorBase): + def __init__( + self, + role: KVConnectorRole, + kv_transfer_config: KVTransferConfig, + ): + assert kv_transfer_config is not None + cfg = kv_transfer_config + super().__init__( + role=role, + kv_transfer_config=cfg, + ) + + self.engine_id: EngineId | None = cfg.engine_id + + logger.info( + "MooncakeConnector::__init__ kv_transfer_config=%s role=%s engine_id=%s", + cfg, + role, + self.engine_id, + ) + + if role == KVConnectorRole.SCHEDULER: + from infinilm.kv_connector.mooncake.mooncake_connector_scheduler import ( + MooncakeConnectorScheduler, + ) + + self.connector_scheduler: "MooncakeConnectorScheduler | None" = ( + MooncakeConnectorScheduler(cfg, engine_id=self.engine_id) + ) + self.connector_worker: "MooncakeConnectorWorker | None" = None + else: + from infinilm.kv_connector.mooncake.mooncake_connector_worker import ( + MooncakeConnectorWorker, + ) + + self.connector_scheduler = None + self.connector_worker = MooncakeConnectorWorker( + cfg, engine_id=self.engine_id + ) + + def get_num_new_matched_tokens( + self, request: InferenceRequest, num_computed_tokens: int + ) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens + ) + + def update_state_after_alloc( + self, + request: InferenceRequest, + block_ids: list[int], + num_external_tokens: int, + block_size: Optional[int] = None, + ) -> None: + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, block_ids, num_external_tokens, block_size + ) + + def build_connector_meta( + self, + ) -> KVConnectorMetadata | None: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta() + + def request_finished( + self, + request: InferenceRequest, + block_ids: list[int], + block_size: Optional[int] = None, + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids, block_size) + + def register_kv_caches(self, kv_caches: dict[str, infinicore.Tensor]) -> None: + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv( + self, + **kwargs, + ) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def get_finished( + self, + finished_req_ids: set[str], # noqa: ARG002 + ) -> tuple[set[str] | None, set[str] | None]: + """Return finished receive and send request id sets, if any.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() diff --git a/python/infinilm/kv_connector/mooncake/mooncake_connector_scheduler.py b/python/infinilm/kv_connector/mooncake/mooncake_connector_scheduler.py new file mode 100644 index 00000000..a9b376f3 --- /dev/null +++ b/python/infinilm/kv_connector/mooncake/mooncake_connector_scheduler.py @@ -0,0 +1,226 @@ +import logging +from typing import Any, Optional + +from infinilm.kv_connector import ( + KVConnectorMetadata, +) +from infinilm.config.kv_transfer import KVTransferConfig +from infinilm.llm import InferenceRequest, RequestStatus + +logger = logging.getLogger(__name__) + +ReqId = str +TransferId = str + + +class MooncakeConnectorScheduler: + def __init__(self, kv_transfer_config: KVTransferConfig, engine_id: str): + assert kv_transfer_config is not None + assert engine_id is not None + + self.kv_transfer_config = kv_transfer_config + self.engine_id = engine_id + + self.is_kv_producer = kv_transfer_config.kv_role == "kv_producer" + self.is_kv_consumer = kv_transfer_config.kv_role == "kv_consumer" + + logger.info("Initializing MooncakeConnector Scheduler %s", engine_id) + + # Requests that need to start recv/send. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[ReqId, tuple[InferenceRequest, list[int]]] = {} + self._reqs_need_send: dict[ReqId, tuple[InferenceRequest, list[int]]] = {} + # Reqs to remove from processed set because they're not to send after + # remote prefill or aborted. + self._reqs_not_processed: set[TransferId] = set() + + def get_num_new_matched_tokens( + self, request: InferenceRequest, num_computed_tokens: int + ) -> tuple[int, bool]: + """ + Args: + request(InferenceRequest): the request object. + num_computed_tokens(int): the number of locally computed tokens for this request + Returns: + * the number of tokens that can be loaded from the external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded asynchronously (between scheduler steps). + """ + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, + params, + ) + + if not params: + return 0, False + + if params.get("do_remote_prefill"): + assert not self.is_kv_producer + token_ids = request.prompt_token_ids or [] + count = len(token_ids) - num_computed_tokens + if count > 0: + return count, True + + return 0, False + + def update_state_after_alloc( + self, + request: InferenceRequest, + block_ids: list[int], + num_external_tokens: int, + block_size: Optional[int] = None, + ) -> None: + """ + Args: + request: the request object. + block_ids: the list of block IDs allocated for the request, + specifically for the tokens that will be loaded from the external KV cache. + num_external_tokens: the number of tokens that will be loaded + from the external KV cache. + """ + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector update_state_after_alloc: " + "req_id=%s num_external_tokens=%s, kv_transfer_params=%s", + request.request_id, + num_external_tokens, + params, + ) + + if not params: + return + + if params.get("do_remote_prefill"): + assert not self.is_kv_producer + if all( + p in params + for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id") + ): + # If remote_blocks and num_external_tokens = 0, we have + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + # remote_block_ids = block_ids if num_external_tokens > 0 else [] + if num_external_tokens > 0: + assert ( + block_size is not None + ), "block_size must be provided when num_external_tokens > 0" + prompt_len = request.get_prompt_length() + local_computed_tokens = prompt_len - num_external_tokens + assert ( + local_computed_tokens % block_size == 0 + ), "local_computed_tokens must be divisible by block_size" + start_idx = local_computed_tokens // block_size + remote_block_ids = block_ids[start_idx:] + else: + remote_block_ids = [] + self._reqs_need_recv[request.request_id] = (request, remote_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", + params, + ) + # Only trigger 1 KV transfer per request. + params["do_remote_prefill"] = False + + elif params.get("do_remote_decode"): + assert not self.is_kv_consumer + if not params.get("transfer_id"): + logger.warning("Missing transfer_id in kv_transfer_params from router!") + else: + self._reqs_need_send[request.request_id] = (request, []) + + def build_connector_meta( + self, + ) -> KVConnectorMetadata | None: + from .mooncake_connector import MooncakeConnectorMetadata + + meta = MooncakeConnectorMetadata() + + if not self.is_kv_producer: + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + self._reqs_need_recv.clear() + + if not self.is_kv_consumer: + for req_id, (req, block_ids) in self._reqs_need_send.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + load_remote_cache=False, + ) + self._reqs_need_send.clear() + meta.reqs_not_processed = self._reqs_not_processed + self._reqs_not_processed = set() + + return meta + + def request_finished( + self, + request: InferenceRequest, + block_ids: list[int], + block_size: Optional[int] = None, + ) -> tuple[bool, dict[str, Any] | None]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + + Returns: + A tuple of (delay_free_blocks, extra_info) + - delay_free_blocks: whether to delay freeing blocks until async transfer is done. + - extra_info: additional info for the caller, currently unused. + """ + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector request_finished: req_id=%s, request_status=%s, " + "kv_transfer_params=%s", + request.request_id, + request.status, + params, + ) + + if not params or not params.get("transfer_id"): + return False, None + + # Consumer-side error handling. + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + assert not self.is_kv_producer + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if not params.get("do_remote_decode"): + return False, None + + # Producer-side error and normal handling. + assert not self.is_kv_consumer + if request.status != RequestStatus.FINISHED: + self._reqs_not_processed.add(params["transfer_id"]) + return False, None + + delay_free_blocks = len(block_ids) > 0 + if delay_free_blocks: + assert ( + block_size is not None + ), "block_size must be provided when delay_free_blocks is True" + self._reqs_need_send[request.request_id] = ( + request, + block_ids, + ) + return delay_free_blocks, None diff --git a/python/infinilm/kv_connector/mooncake/mooncake_connector_worker.py b/python/infinilm/kv_connector/mooncake/mooncake_connector_worker.py new file mode 100644 index 00000000..2640aaf3 --- /dev/null +++ b/python/infinilm/kv_connector/mooncake/mooncake_connector_worker.py @@ -0,0 +1,1157 @@ +try: + from mooncake.engine import TransferEngine +except ImportError as e: + raise ImportError("Please pip install mooncake-transfer-engine") from e + +import asyncio +import logging +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +import httpx +import infinicore +import numpy as np +import msgspec +import zmq +import zmq.asyncio +from dataclasses import dataclass + +from infinilm.kv_connector.mooncake.mooncake_connector import ( + MooncakeConnectorMetadata, + PullReqMeta, +) + +from infinilm.kv_connector.mooncake.mooncake_utils import ( + get_ip, + get_mooncake_bootstrap_addr, + make_zmq_path, + make_zmq_socket, + MooncakeBootstrapServer, + RegisterWorkerPayload, + should_launch_bootstrap_server, +) +from enum import IntEnum +from typing import Any + +logger = logging.getLogger(__name__) + + +EngineId = str +ReqId = str +TransferId = str + + +class MooncakeXferResponseStatus(IntEnum): + # Transfer finished + FINISH = 0 + # Continue to receive + CONTINUE = 1 + # Something wrong, see err_msg + ERROR = 2 + + +class MooncakeXferReqStatus(IntEnum): + SUCCESS = 0 # normal + TIMEOUT = 1 # P node task timeout + ADDR_MISMATCH = 2 # address calculation failed before sending + XFER_FAIL = 3 # mooncake write data failed + + +class MooncakeXferResponse( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] +): + status: MooncakeXferResponseStatus + reqs_ids: list[ReqId] | None = None + reqs_statues: list[MooncakeXferReqStatus] | None = None + msg: str | None = None + + +class MooncakeXferMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] +): + remote_hostname: str + remote_port: int + remote_tp_size: int + remote_tp_rank: int + req_blocks: dict[ReqId, tuple[TransferId, list[int]]] + kv_caches_base_addr: list[int] + kv_flag_addr: list[int] + + +@dataclass +class SendBlockMeta: + p_req_id: ReqId + transfer_id: TransferId + local_block_ids: list[int] + ready: asyncio.Event + expire_time: float = float("inf") + need_send: int = 0 + sent: int = 0 + sending: int = 0 + + +class MooncakeAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): + remote_hostname: str + remote_port: int + request_ids: list[ReqId] + kv_caches_base_addr: list[int] + block_ids: list[list[int]] + + +@dataclass +class MooncakeParallelConfig: + """Configuration for Mooncake distributed execution.""" + + tensor_parallel_size: int = 1 + tensor_parallel_rank: int = 0 + world_size: int = 1 + rank: int = 0 + + +class MooncakeConnectorWorker: + def __init__(self, kv_transfer_config, engine_id: str) -> None: + assert kv_transfer_config is not None + assert engine_id is not None + logger.info("Initializing MooncakeConnector worker %s", engine_id) + + self.parallel_config = MooncakeParallelConfig() + self.is_kv_producer: bool = kv_transfer_config.kv_role == "kv_producer" + self.is_kv_consumer: bool = kv_transfer_config.kv_role == "kv_consumer" + + self.num_sender_workers = kv_transfer_config.kv_connector_extra_config.get( + "num_workers", 10 + ) + # Create more tasks than workers to keep the thread pool saturated. + self.num_sender_tasks = self.num_sender_workers * 2 + protocol = kv_transfer_config.kv_connector_extra_config.get( + "mooncake_protocol", "rdma" + ) + logger.info( + "The Mooncake Transfer Engine is using %s as its protocol.", protocol + ) + + self.engine = TransferEngine() + self.hostname = get_ip() + ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", protocol, "") + if ret_value != 0: + raise RuntimeError("Mooncake Transfer Engine initialization failed.") + + self.rpc_port = self.engine.get_rpc_port() + + logger.debug( + "Mooncake Transfer Engine initialized at %s:%d", + self.hostname, + self.rpc_port, + ) + + self._remote_agents: dict[EngineId, dict[int, dict[int, str]]] = {} + self._pending_bootstrap_queries: dict[str, asyncio.Event] = {} + self.side_channel_port: int = 0 # we will bind it in register_kv_caches() + self.engine_id: EngineId = engine_id + + self.tp_rank = self.parallel_config.tensor_parallel_rank + self.tp_size = self.parallel_config.tensor_parallel_size + + self.num_blocks = 0 + self.kv_caches_base_addr: list[int] = [] + self.device_kv_caches: dict[str, infinicore.Tensor] = {} + + self.async_zmq_ctx = zmq.asyncio.Context() + self._encoder = msgspec.msgpack.Encoder() + self._xfer_meta_decoder = msgspec.msgpack.Decoder(MooncakeXferMetadata) + self._xfer_resp_decoder = msgspec.msgpack.Decoder(MooncakeXferResponse) + + if not self.is_kv_consumer: + # Background threads for sending kvcaches to D. + self._sender_executor = ThreadPoolExecutor( + max_workers=self.num_sender_workers, + thread_name_prefix="infinilm-mooncake-sender", + ) + logger.debug( + "Mooncake Prefiller: use %d workers to send kvcaches", + self.num_sender_workers, + ) + # An asyncio queue to buffer incoming requests for the sender + self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]() + self.sender_loop = asyncio.new_event_loop() + # Background thread for processing new sending requests. + self._sender_listener_t = threading.Thread( + target=self._async_loop, args=(self.sender_loop,), daemon=True + ) + self._sender_listener_t.start() + + # Start bootstrap server on global rank 0. + if should_launch_bootstrap_server(self.parallel_config): + bootstrap_host, bootstrap_port = get_mooncake_bootstrap_addr( + self.parallel_config + ) + self.bootstrap_server = MooncakeBootstrapServer( + bootstrap_host, bootstrap_port + ) + self.bootstrap_server.start() + + if not self.is_kv_producer: + self.receiver_loop = asyncio.new_event_loop() + self._mooncake_receiver_t = threading.Thread( + target=self._async_loop, args=(self.receiver_loop,), daemon=True + ) + self._mooncake_receiver_t.start() + logger.debug("Mooncake Decoder: start receiver thread") + + if self.is_kv_producer: + self.reqs_need_send: dict[TransferId, SendBlockMeta] = {} + self.finished_sending_reqs: set[ReqId] = set() + + self.reqs_need_send_timeout: list[TransferId] = [] + + if self.is_kv_consumer: + self.finished_recving_reqs: set[ReqId] = set() + self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size} + + # collect timeout reqs to recv + self.timeout_reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = {} + # collect xfer failed reqs id + self.xfer_failed_recving_reqs_ids: set[ReqId] = set() + + def _async_loop(self, loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + def __del__(self): + self.shutdown() + + def shutdown(self) -> None: + """Stop ZMQ / threads / bootstrap (idempotent from __del__).""" + self.async_zmq_ctx.term() + if not self.is_kv_consumer: + self._sender_executor.shutdown(wait=False) + if self.sender_loop.is_running(): + self.sender_loop.call_soon_threadsafe(self.sender_loop.stop) + self._sender_listener_t.join() + if ( + should_launch_bootstrap_server(self.parallel_config) + and getattr(self, "bootstrap_server", None) is not None + ): + self.bootstrap_server.shutdown() + if not self.is_kv_producer and self.receiver_loop.is_running(): + self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop) + self._mooncake_receiver_t.join() + + async def _mooncake_sender_listener(self, ready_event: threading.Event): + """ + Background thread that listens for Mooncake requests, dispatches them + to a thread pool, and sends acknowledgments upon completion. + """ + + sock = self.async_zmq_ctx.socket(zmq.ROUTER) + self.side_channel_port = sock.bind_to_random_port(f"tcp://{self.hostname}") + logger.debug( + "Mooncake sender starting listening on path: tcp://%s:%d", + self.hostname, + self.side_channel_port, + ) + + await self.register_worker_with_bootstrap() + + # Create async worker tasks that process items from the queue + sender_tasks = [ + asyncio.create_task(self._sender_worker(sock)) + for _ in range(self.num_sender_tasks) + ] + + ready_event.set() + + try: + while True: + identity, metadata_bytes = await sock.recv_multipart() + logger.debug("ZMQ recv one msg, identity: %s.", identity) + + await self.sender_worker_queue.put((identity, metadata_bytes)) + except zmq.ContextTerminated: + logger.debug("ZMQ context terminated, exiting Mooncake sender thread.") + except Exception as e: + logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e)) + finally: + # Clean up worker tasks + for task in sender_tasks: + task.cancel() + await asyncio.gather(*sender_tasks, return_exceptions=True) + sock.close() + + async def register_worker_with_bootstrap(self): + host, port = get_mooncake_bootstrap_addr(self.parallel_config) + logger.debug( + "Mooncake sender register_worker_with_bootstrap start! host=%s port=%s", + host, + port, + ) + + url = make_zmq_path("http", host, port) + "/register" + worker_addr = make_zmq_path("tcp", self.hostname, self.side_channel_port) + payload = RegisterWorkerPayload( + engine_id=self.engine_id, + dp_rank=0, + tp_rank=self.tp_rank, + pp_rank=0, + addr=worker_addr, + ) + while True: + try: + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload.model_dump()) + response.raise_for_status() + logger.debug("Successfully registered with bootstrap server at %s", url) + break + except httpx.ConnectError: + # Bootstrap server not ready, wait for a while and retry. + logger.debug("Bootstrap server not ready, wait for a while and retry.") + await asyncio.sleep(1) + except Exception as e: + err_msg = ( + e.response.text if isinstance(e, httpx.HTTPStatusError) else str(e) + ) + logger.error( + "Error registering %s with bootstrap server: %s", payload, err_msg + ) + raise + + async def _sender_worker(self, sock: zmq.asyncio.Socket): + while True: + try: + identity, metadata_bytes = await self.sender_worker_queue.get() + try: + metadata = self._xfer_meta_decoder.decode(metadata_bytes) + await self.send_kv_to_decode(identity, sock, metadata) + except Exception as e: + logger.error("Error processing Mooncake xfer request: %s", e) + error_response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.ERROR, msg=str(e) + ) + await sock.send_multipart( + (identity, self._encoder.encode(error_response)) + ) + finally: + self.sender_worker_queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Error in _sender_worker: %s", e) + + async def send_kv_to_decode( + self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata + ): + pending_reqs: dict[ReqId, SendBlockMeta] = {} + remote_tp_ranks = [0] + if self.tp_rank not in remote_tp_ranks: + # This D worker does not pair with the P worker. + raise RuntimeError( + f"MooncakeConnectorWorker: This P tp_rank {self.tp_rank} not match with remote D target ranks {remote_tp_ranks}" + ) + for d_req_id, (transfer_id, _) in meta.req_blocks.items(): + if transfer_id not in self.reqs_need_send: + # This req is not enqueued in P side yet, create it here. + self.reqs_need_send[transfer_id] = SendBlockMeta( + p_req_id="", + transfer_id=transfer_id, + local_block_ids=[], + ready=asyncio.Event(), + ) + send_meta = self.reqs_need_send[transfer_id] + pending_reqs[d_req_id] = send_meta + + async def wait_and_ret( + d_req_id: ReqId, send_meta: SendBlockMeta + ) -> tuple[ReqId, SendBlockMeta]: + await send_meta.ready.wait() + return d_req_id, send_meta + + wait_tasks = [ + asyncio.create_task(wait_and_ret(d_req_id, send_meta)) + for d_req_id, send_meta in pending_reqs.items() + ] + + while wait_tasks: + ABORT_REQUEST_TIMEOUT = 480 + done_tasks, pending_tasks = await asyncio.wait( + wait_tasks, + timeout=ABORT_REQUEST_TIMEOUT, + return_when=asyncio.FIRST_COMPLETED, + ) + + if not done_tasks: + # the tasks of wait_tasks are all timeout + # abort all pending requests. + + for task in pending_tasks: + task.cancel() + + pending_tasks_reqs_ids = list(pending_reqs.keys()) + # self.reqs_need_send_timeout.extend(pending_tasks_reqs_ids) + logger.warning( + "Timeout waiting for P side ready: %s", list(pending_reqs) + ) + + response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.FINISH, + reqs_ids=pending_tasks_reqs_ids, + reqs_statues=[MooncakeXferReqStatus.TIMEOUT] + * len(pending_tasks_reqs_ids), + msg="Timeout waiting for P side ready.", + ) + + await sock.send_multipart((identity, self._encoder.encode(response))) + break + + wait_tasks = list(pending_tasks) + response_status = ( + MooncakeXferResponseStatus.CONTINUE + if wait_tasks + else MooncakeXferResponseStatus.FINISH + ) + ready_reqs: list[tuple[ReqId, SendBlockMeta]] = [] + for task in done_tasks: + d_req_id, send_meta = task.result() + del pending_reqs[d_req_id] + + if send_meta.transfer_id in self.reqs_need_send: + # Mark it sending to avoid expiration. + send_meta.sending += 1 + if not send_meta.need_send: + self.resolve_need_send(send_meta, remote_tp_ranks) + ready_reqs.append((d_req_id, send_meta)) + else: + # Otherwise (expired, very unlikely), just forget it. + logger.warning( + "Request %s expired before sending on P side.", d_req_id + ) + + raise RuntimeError( + f"MooncakeConnectorWorker: Request {d_req_id} expired before sending on P side." + ) + + ( + src_ptrs, + dst_ptrs, + lengths, + mismatch_reqs_ids, + xfer_reqs_ids, + xfer_block_ids, + ) = await self._build_transfer_params(ready_reqs, meta) + + if mismatch_reqs_ids: + response = MooncakeXferResponse( + status=response_status, + reqs_ids=mismatch_reqs_ids, + reqs_statues=[MooncakeXferReqStatus.ADDR_MISMATCH] + * len(mismatch_reqs_ids), + msg="P num blocks less than D", + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + + raise RuntimeError( + f"MooncakeConnectorWorker: Address mismatch for requests {mismatch_reqs_ids}" + ) + + ret_value = 0 + if len(xfer_reqs_ids) > 0 and src_ptrs: + remote_session = f"{meta.remote_hostname}:{meta.remote_port}" + + # wait until return value + kv_flag_addr = meta.kv_flag_addr + ret_value = await self.sender_loop.run_in_executor( + self._sender_executor, + self._send_blocks, + remote_session, + src_ptrs, + dst_ptrs, + lengths, + xfer_reqs_ids, + xfer_block_ids, + kv_flag_addr, + ) + + if ret_value != 0: + # happen error during mooncake transfer + xfer_failed_reqs_ids = [] + for d_req_id, send_meta in ready_reqs: + send_meta.sending -= 1 + xfer_failed_reqs_ids.append(d_req_id) + + # not delete send_meta object in self.reqs_need_send. + # wait until D + + # Do best effort to transfer the remaining reqs. + response = MooncakeXferResponse( + status=response_status, + reqs_ids=xfer_failed_reqs_ids, + reqs_statues=[MooncakeXferReqStatus.XFER_FAIL] + * len(xfer_failed_reqs_ids), + msg=f"Mooncake transfer engine returned {ret_value}", + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + else: + for d_req_id, send_meta in ready_reqs: + # TODO: for heterogeneous TP (one P pairs to multiple D), + # we need to check whether all headers are sent. + # If not, we should set expire_time to normal and skip the below. + send_meta.sending -= 1 + send_meta.sent += 1 + if send_meta.sent == send_meta.need_send: + del self.reqs_need_send[send_meta.transfer_id] + self.finished_sending_reqs.add(send_meta.p_req_id) + + response = MooncakeXferResponse( + status=response_status, + reqs_ids=[d_req_id for d_req_id, _ in ready_reqs], + reqs_statues=[MooncakeXferReqStatus.SUCCESS] * len(ready_reqs), + msg="successfully", + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + + def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int]): + # Prepare for heterogeneous TP (one P pairs to multiple D) + send_meta.need_send = len(remote_tp_ranks) + if send_meta.need_send != 1: + logger.error("Mooncake: Heterogeneous TP is not supported yet.") + raise NotImplementedError( + "Mooncake: Heterogeneous TP is not supported yet." + ) + + async def _build_transfer_params( + self, + ready_reqs: list[tuple[ReqId, SendBlockMeta]], + agent_meta: MooncakeXferMetadata, + ) -> tuple[list[int], list[int], list[int], list[ReqId], list[ReqId], list[int]]: + local_base_addr = self.kv_caches_base_addr + remote_base_addr = agent_meta.kv_caches_base_addr + block_len = self.block_len + remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" + + src_ptrs = [] + dst_ptrs = [] + lengths = [] + + mismatch_reqs_ids: list[ReqId] = [] + xfer_reqs_ids: list[ReqId] = [] + xfer_block_ids: list[int] = [] + + for d_req_id, send_meta in ready_reqs: + _, remote_block_ids = agent_meta.req_blocks[d_req_id] + num_remote_blocks = len(remote_block_ids) + if num_remote_blocks == 0: + xfer_reqs_ids.append(d_req_id) + continue + + local_block_ids = send_meta.local_block_ids + # Partial prefix cache hit: just read uncomputed blocks. + num_local_blocks = len(local_block_ids) + if num_local_blocks < num_remote_blocks: + logger.error( + "req %s: local blocks(%d) less than remote blocks(%d)!", + d_req_id, + num_local_blocks, + num_remote_blocks, + ) + mismatch_reqs_ids.append(d_req_id) + continue + if num_local_blocks > num_remote_blocks: + local_block_ids = local_block_ids[-num_remote_blocks:] + + # Group by indices + group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous( + local_block_ids, remote_block_ids + ) + + for local_layer_addr, remote_layer_addr in zip( + local_base_addr, remote_base_addr + ): + for group_local_block_id, group_remote_block_id in zip( + group_local_block_ids, group_remote_block_ids + ): + src_ptrs.append( + local_layer_addr + group_local_block_id[0] * block_len + ) + dst_ptrs.append( + remote_layer_addr + group_remote_block_id[0] * block_len + ) + lengths.append(block_len * len(group_local_block_id)) + + xfer_reqs_ids.append(d_req_id) + xfer_block_ids.extend(remote_block_ids) + logger.debug( + "Calculate kv_caches ptrs for request %s (%d blocks) to %s", + d_req_id, + num_remote_blocks, + remote_session, + ) + + return ( + src_ptrs, + dst_ptrs, + lengths, + mismatch_reqs_ids, + xfer_reqs_ids, + xfer_block_ids, + ) + + def _send_blocks( + self, + remote_session: str, + src_ptrs: list[int], + dst_ptrs: list[int], + lengths: list[int], + xfer_reqs_ids: list[ReqId], + xfer_block_ids: list[int], + kv_flag_addr: list[int], + ) -> int: + logger.debug( + "mooncake engine batch_transfer_sync_write to %s start,xfer_reqs_ids: %s", + remote_session, + ", ".join(xfer_reqs_ids), + ) + + start_time = time.perf_counter() + ret_value = self.engine.batch_transfer_sync_write( + remote_session, src_ptrs, dst_ptrs, lengths + ) + + if ret_value == 0 and len(kv_flag_addr) > 0: + n = len(xfer_block_ids) + flag_src_ptrs = [self.kv_flag_src_ptrs[b] for b in xfer_block_ids] + flag_dst_ptrs = [kv_flag_addr[b] for b in xfer_block_ids] + flag_lengths = [1] * n + + flag_ret_value = self.engine.batch_transfer_sync_write( + remote_session, flag_src_ptrs, flag_dst_ptrs, flag_lengths + ) + + if flag_ret_value != 0: + ret_value = flag_ret_value + + logger.debug( + "mooncake engine sending flag to %s done, ret_value: %s", + remote_session, + flag_ret_value, + ) + + if ret_value == 0: + logger.debug( + "mooncake engine sending to %s done, took %s, xfer_reqs_ids: %s", + remote_session, + time.perf_counter() - start_time, + ", ".join(xfer_reqs_ids), + ) + else: + logger.debug( + "mooncake engine sending to %s failed, ret_value: %s, xfer_reqs_ids: %s", + remote_session, + ret_value, + ", ".join(xfer_reqs_ids), + ) + return ret_value + + def register_kv_caches(self, kv_caches: dict[str, infinicore.Tensor]): + """Register the KV Cache data in mooncake.""" + + logger.info("Registering KV_Caches.") + + kv_data_ptrs = [] + kv_data_lens = [] + seen_base_addresses = [] + + split_k_and_v = True + tensor_size_bytes = None + for layer_name, cache_or_caches in kv_caches.items(): + logger.debug( + "registering layer %s with shape %s", layer_name, cache_or_caches.shape + ) + + assert split_k_and_v, "split_k_and_v must be True" + cache_list = [ + cache_or_caches.narrow(0, 0, 1).squeeze(0), # k_cache + cache_or_caches.narrow(0, 1, 1).squeeze(0), # v_cache + ] + + for cache in cache_list: + base_addr = cache.data_ptr() + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + + if True: + if cache.dtype == infinicore.bfloat16: + dtype_size = 2 + else: + raise ValueError(f"Unsupported dtype: {cache.dtype}") + + numel = cache.numel() + curr_tensor_size_bytes = numel * dtype_size + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, ( + "All kv cache tensors must have the same size" + ) + + kv_data_ptrs.append(base_addr) + kv_data_lens.append(tensor_size_bytes) + + self.kv_caches_base_addr = seen_base_addresses + + ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens) + if ret_value != 0: + raise RuntimeError("Mooncake batch memory registration failed.") + + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.device_kv_caches = kv_caches + logger.info( + "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len + ) + # logger.info("registered kv_caches_base_addr=%d", len(self.kv_caches_base_addr)) + + if self.is_kv_consumer: + self.kv_flag = np.zeros(self.num_blocks, dtype=np.uint8) + flag_base_ptr = self.kv_flag.ctypes.data + self.kv_flag_ptrs = [flag_base_ptr + i for i in range(self.num_blocks)] + self.kv_flag_lens = [1] * self.num_blocks + + ret = self.engine.batch_register_memory( + self.kv_flag_ptrs, self.kv_flag_lens + ) + + if ret != 0: + raise RuntimeError( + "Mooncake batch memory registration failed for kv block flags." + ) + + if self.is_kv_producer: + self.kv_flag_src = np.ones(self.num_blocks, dtype=np.uint8) + flag_src_base_ptr = self.kv_flag_src.ctypes.data + self.kv_flag_src_ptrs = [ + flag_src_base_ptr + i for i in range(self.num_blocks) + ] + self.kv_flag_src_lens = [1] * self.num_blocks + + ret = self.engine.batch_register_memory( + self.kv_flag_src_ptrs, self.kv_flag_src_lens + ) + if ret != 0: + raise RuntimeError( + "Mooncake batch memory registration failed for kv block src flag." + ) + + # No need to launch server for D node. + if self.is_kv_consumer: + return + + ready_event = threading.Event() + asyncio.run_coroutine_threadsafe( + self._mooncake_sender_listener(ready_event), self.sender_loop + ) + ready_event.wait() # Wait for listener ZMQ socket to be ready. + + async def fetch_finished_recving_reqs(self) -> set[ReqId]: + finished_recving_reqs = self.finished_recving_reqs + self.finished_recving_reqs = set() + return finished_recving_reqs + + async def fetch_xfer_failed_recving_reqs(self) -> set[ReqId]: + xfer_failed_recving_reqs_ids = self.xfer_failed_recving_reqs_ids + self.xfer_failed_recving_reqs_ids = set() + return xfer_failed_recving_reqs_ids + + async def fetch_finished_sending_reqs(self) -> set[ReqId]: + finished_sending_reqs = self.finished_sending_reqs + self.finished_sending_reqs = set() + + # Handle timeout to avoid stranding blocks on remote. + now = time.perf_counter() + + for transfer_id, send_meta in self.reqs_need_send.items(): + if ( + send_meta.p_req_id + and send_meta.expire_time < now + and send_meta.sending == 0 + ): + logger.warning( + "Request %s timed out after %d seconds without " + "being sent. don't freeing its blocks on the producer side.", + send_meta.p_req_id, + 480, + ) + + # reset time + send_meta.expire_time = time.perf_counter() + 480 + + # TODO: mv timeout reqs to finished_sending_reqs set + finished_sending_reqs.add(send_meta.p_req_id) + + return finished_sending_reqs + + def get_finished(self) -> tuple[set[str] | None, set[str] | None, set[str] | None]: + """ + Get requests that are done sending or recving on this specific worker. + The scheduler process will use this output to track which workers are done. + """ + recv_fut = None + failed_recv_fut = None + send_fut = None + if not self.is_kv_producer: + recv_fut = asyncio.run_coroutine_threadsafe( + self.fetch_finished_recving_reqs(), self.receiver_loop + ) + failed_recv_fut = asyncio.run_coroutine_threadsafe( + self.fetch_xfer_failed_recving_reqs(), self.receiver_loop + ) + + if not self.is_kv_consumer: + send_fut = asyncio.run_coroutine_threadsafe( + self.fetch_finished_sending_reqs(), self.sender_loop + ) + + finished_recving_reqs = recv_fut.result() if recv_fut else set() + failed_recving_reqs = failed_recv_fut.result() if failed_recv_fut else set() + finished_sending_reqs = send_fut.result() if send_fut else set() + + if finished_sending_reqs or finished_recving_reqs: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", + self.tp_rank, + len(finished_sending_reqs), + len(finished_recving_reqs), + ) + + return ( + finished_sending_reqs or None, + failed_recving_reqs or None, + finished_recving_reqs or None, + ) + + async def _wait_for_kv_flags_ready(self, block_ids: list[int]) -> None: + """Wait until Mooncake sets all kv_flag entries for the given blocks to 1.""" + if not block_ids: + return + + indices = np.asarray(block_ids, dtype=np.intp) + while not np.all(self.kv_flag[indices] == 1): + await asyncio.sleep(0.012) + + self.kv_flag[indices] = 0 + + async def receive_kv_from_single_worker( + self, + remote_engine_id, + worker_addr: str, + pull_metas: dict[ReqId, PullReqMeta], + ): + req_ids = set(pull_metas) + metadata = MooncakeXferMetadata( + remote_hostname=self.hostname, + remote_port=self.rpc_port, + remote_tp_size=self.tp_size, + remote_tp_rank=self.tp_rank, + req_blocks={ + req_id: (pull_meta.transfer_id, pull_meta.local_block_ids) + for req_id, pull_meta in pull_metas.items() + }, + kv_caches_base_addr=self.kv_caches_base_addr, + kv_flag_addr=self.kv_flag_ptrs, + ) + + encoded_data = self._encoder.encode(metadata) + logger.debug( + "Size of encoded MooncakeXferMetadata: %d bytes", len(encoded_data) + ) + logger.debug( + "Sending kv transfer request for %s on path: %s", req_ids, worker_addr + ) + + # Send query for the request. + try: + with make_zmq_socket( + self.async_zmq_ctx, worker_addr, zmq.DEALER, bind=False, linger=0 + ) as sock: + # If something goes wrong, let P wait timeout first (in asyncio.wait()). + sock.setsockopt(zmq.RCVTIMEO, (480 + 60) * 1000) + await sock.send(encoded_data) + + response_list = [] + while True: + ret_msg = await sock.recv() + response = self._xfer_resp_decoder.decode(ret_msg) + response_list.append(response) + + # zmq exception happens + if response.status == MooncakeXferResponseStatus.ERROR: + logger.error( + "Error happens during transferring kvcache for %s: %s", + req_ids, + response.msg, + ) + raise RuntimeError( + f"MooncakeConnectorWorker: recv response is Error happens during transferring kvcache for {req_ids}: {response.msg}" + ) + + if response.status == MooncakeXferResponseStatus.FINISH: + break + + # process response list + processed_reqs_count = 0 + success_block_ids = [] + finished_recving_reqs = set() + for response in response_list: + reqs_count, finished_reqs, block_ids = self.process_pulling_result( + remote_engine_id, response, pull_metas + ) + processed_reqs_count += reqs_count + success_block_ids.extend(block_ids) + finished_recving_reqs.update(finished_reqs) + + # TODO:check if all reqs are processed + assert processed_reqs_count == len(pull_metas), ( + "processed_reqs_count must be equal to the number of pull_metas" + ) + + await self._wait_for_kv_flags_ready(success_block_ids) + + self.finished_recving_reqs.update(finished_recving_reqs) + + except zmq.ContextTerminated: + logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.") + # TODO: handle this error + except Exception as e: + logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e) + return + + def process_pulling_result( + self, + remote_engine_id: EngineId, + response: MooncakeXferResponse, + pull_metas: dict[ReqId, PullReqMeta], + ): + response_reqs_ids = response.reqs_ids or [] + response_reqs_statues = response.reqs_statues or [] + + reqs_count_of_response = len(response_reqs_ids) + + assert reqs_count_of_response == len(response_reqs_statues), ( + "response_reqs_ids and response_reqs_statues must have the same count" + ) + + success_reqs_ids = [] + timeout_reqs_ids = [] + addr_mismatch_reqs_ids = [] + xfer_failed_reqs_ids = [] + for req_id, status in zip(response_reqs_ids, response_reqs_statues): + match status: + case MooncakeXferReqStatus.SUCCESS: + success_reqs_ids.append(req_id) + case MooncakeXferReqStatus.TIMEOUT: + timeout_reqs_ids.append(req_id) + case MooncakeXferReqStatus.ADDR_MISMATCH: + addr_mismatch_reqs_ids.append(req_id) + case MooncakeXferReqStatus.XFER_FAIL: + xfer_failed_reqs_ids.append(req_id) + case _: + raise ValueError( + f"MooncakeConnectorWorker: Invalid status {status} for request {req_id}" + ) + success_block_ids = [] + if len(success_reqs_ids) > 0: + for req_id in success_reqs_ids: + pull_meta = pull_metas[req_id] + + success_block_ids.extend(pull_meta.local_block_ids) + + # No race because we are in async loop. + pull_meta.pull_tasks_count -= 1 + if pull_meta.pull_tasks_count == 0: + assert req_id == pull_meta.d_req_id + else: + raise RuntimeError( + f"MooncakeConnectorWorker: Pull tasks count is not 0 for request {req_id}" + ) + logger.debug("successfully pulling kv_caches for %s", success_reqs_ids) + + if timeout_reqs_ids: + inner = self.timeout_reqs_to_recv.setdefault(remote_engine_id, {}) + for ( + req_id + ) in timeout_reqs_ids: # D 侧收集超时,下一拍在 _start_load_kv 合并重试 + pull_meta = pull_metas[req_id] + pull_meta.pull_tasks_count = 0 + inner[req_id] = pull_meta + + if len(addr_mismatch_reqs_ids) > 0: + raise RuntimeError( + f"MooncakeConnectorWorker: Address mismatch for requests {addr_mismatch_reqs_ids}" + ) + + if len(xfer_failed_reqs_ids) > 0: + logger.error( + "MooncakeConnectorWorker: pulling kv_caches for %s failed: %s", + xfer_failed_reqs_ids, + response.msg, + ) + self.xfer_failed_recving_reqs_ids.update(xfer_failed_reqs_ids) + + finished_recving_reqs = set(success_reqs_ids) | set(xfer_failed_reqs_ids) + + return reqs_count_of_response, finished_recving_reqs, success_block_ids + + async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str): + url = remote_bootstrap_addr + "/query" + try: + async with httpx.AsyncClient() as client: + response = await client.get(url) + response.raise_for_status() + data: dict = response.json() + for _, dp_entry in data.items(): + remote_engine_id = dp_entry["engine_id"] + self._remote_agents[remote_engine_id] = { + int(tp_rank): { + int(pp_rank): worker_addr + for pp_rank, worker_addr in tp_entry.items() + } + for tp_rank, tp_entry in dp_entry["worker_addr"].items() + } + self._tp_size[remote_engine_id] = len(dp_entry["worker_addr"]) + except Exception as e: + logger.error( + "Failed to connect to bootstrap server %s: %s", + remote_bootstrap_addr, + e, + ) + + # Always notify others regardless of connection success or failure. + self._pending_bootstrap_queries[remote_bootstrap_addr].set() + del self._pending_bootstrap_queries[remote_bootstrap_addr] + + def receive_kv( + self, + remote_engine_id: EngineId, + pull_metas: dict[ReqId, PullReqMeta], + ): + remote_tp_ranks = [0] + count = len(remote_tp_ranks) + if count != 1: + logger.error("Mooncake: Heterogeneous TP is not supported yet.") + raise NotImplementedError( + "Mooncake: Heterogeneous TP is not supported yet." + ) + for pull_meta in pull_metas.values(): + pull_meta.pull_tasks_count = count + for remote_tp_rank in remote_tp_ranks: + worker_addr = self._remote_agents[remote_engine_id][remote_tp_rank][0] + asyncio.create_task( + self.receive_kv_from_single_worker( + remote_engine_id, worker_addr, pull_metas + ) + ) + + async def handle_new_engine_id( + self, + remote_engine_id: EngineId, + pull_metas: dict[ReqId, PullReqMeta], + ): + remote_bootstrap_addr = next(iter(pull_metas.values())).remote_bootstrap_addr + if remote_bootstrap_addr not in self._pending_bootstrap_queries: + self._pending_bootstrap_queries[remote_bootstrap_addr] = asyncio.Event() + await self._connect_to_prefiller_bootstrap(remote_bootstrap_addr) + else: + await self._pending_bootstrap_queries[remote_bootstrap_addr].wait() + + if remote_engine_id not in self._remote_agents: + logger.error( + "Failed to find remote engine_id %s from bootstrap server %s", + remote_engine_id, + remote_bootstrap_addr, + ) + return + + self.receive_kv(remote_engine_id, pull_metas) + + async def _start_load_kv( + self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] + ): + # reprocess timeout reqs (merge per-engine pull maps, do not replace whole inner dicts) + if self.timeout_reqs_to_recv: + for engine_id, timed_out in self.timeout_reqs_to_recv.items(): + inner = reqs_to_recv.setdefault(engine_id, {}) + inner.update(timed_out) + self.timeout_reqs_to_recv.clear() + + for remote_engine_id, pull_metas in reqs_to_recv.items(): + if remote_engine_id not in self._remote_agents: + asyncio.create_task( + self.handle_new_engine_id(remote_engine_id, pull_metas) + ) + else: + self.receive_kv(remote_engine_id, pull_metas) + + async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): + for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items(): + if block_ids: + # Already gone through request_finished() + send_meta = self.reqs_need_send[transfer_id] + send_meta.p_req_id = p_req_id + send_meta.local_block_ids = block_ids + send_meta.expire_time = time.perf_counter() + 480 + send_meta.ready.set() + else: + # From update_state_after_alloc(), + # but not reach request_finished() yet + # This may be already created by send_kv_to_decode() + # when D is sending MooncakeXferMetadata. + if transfer_id not in self.reqs_need_send: + self.reqs_need_send[transfer_id] = SendBlockMeta( + p_req_id=p_req_id, + transfer_id=transfer_id, + local_block_ids=[], + ready=asyncio.Event(), + ) + for transfer_id in metadata.reqs_not_processed: + send_meta = self.reqs_need_send.pop(transfer_id) + if send_meta: + assert not send_meta.ready.is_set() + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + if not self.is_kv_producer and metadata.reqs_to_recv: + asyncio.run_coroutine_threadsafe( + self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop + ) + + if not self.is_kv_consumer and ( + metadata.reqs_to_send or metadata.reqs_not_processed + ): + asyncio.run_coroutine_threadsafe( + self.record_send_reqs(metadata), self.sender_loop + ) + + +def group_concurrent_contiguous( + src_indices: list[int], dst_indices: list[int] +) -> tuple[list[list[int]], list[list[int]]]: + """Group parallel src/dst index lists into contiguous runs (NumPy).""" + if len(src_indices) == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + return [g.tolist() for g in src_groups], [g.tolist() for g in dst_groups] diff --git a/python/infinilm/kv_connector/mooncake/mooncake_utils.py b/python/infinilm/kv_connector/mooncake/mooncake_utils.py new file mode 100644 index 00000000..626efeb0 --- /dev/null +++ b/python/infinilm/kv_connector/mooncake/mooncake_utils.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright 2026 InfiniLM Contributors + +from __future__ import annotations + +import ipaddress +import logging +import os +import threading +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + + +import psutil +import uvicorn +import zmq +import zmq.asyncio +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from urllib3.util import parse_url + +if TYPE_CHECKING: + from infinilm.kv_connector.mooncake.mooncake_connector_worker import ( + MooncakeParallelConfig, + ) + + +logger = logging.getLogger(__name__) + +ReqId = str +TransferId = str +EngineId = str +WorkerAddr = str + + +class RegisterWorkerPayload(BaseModel): + engine_id: EngineId + dp_rank: int + tp_rank: int + pp_rank: int + addr: WorkerAddr + + +@dataclass +class EngineEntry: + engine_id: EngineId + # {tp_rank: {pp_rank: worker_addr}} + worker_addr: dict[int, dict[int, WorkerAddr]] + + +def should_launch_bootstrap_server(parallel_config: "MooncakeParallelConfig") -> bool: + assert parallel_config + return 0 == parallel_config.tensor_parallel_rank + + +def get_mooncake_bootstrap_addr( + parallel_config: "MooncakeParallelConfig", +) -> tuple[str, int]: + """ + Returns the address of the Mooncake bootstrap server. + This is only used by prefillers to register workers. + Decoders should get addr from kv_transfer_params. + """ + assert parallel_config + + # Port and Host used for Mooncake handshake between remote agents. + mooncake_bootstrap_host = str( + os.getenv("INFINILM_MOONCAKE_BOOTSTRAP_HOST", "127.0.0.1") + ) + mooncake_bootstrap_port = int(os.getenv("INFINILM_MOONCAKE_BOOTSTRAP_PORT", "8998")) + + return (mooncake_bootstrap_host, mooncake_bootstrap_port) + + +class MooncakeBootstrapServer: + """ + A centralized server running on the global rank 0 prefiller worker. + Prefiller workers register their connection info (IP, port, ranks) here. + """ + + def __init__(self, host: str, port: int): + self.workers: dict[int, EngineEntry] = {} + self.host = host + self.port = port + self.app = FastAPI() + self._register_routes() + self.server_thread: threading.Thread | None = None + self.server: uvicorn.Server | None = None + + def __del__(self): + self.shutdown() + + def _register_routes(self): + # All methods are async. No need to use lock to protect data. + self.app.post("/register")(self.register_worker) + self.app.get("/query", response_model=dict[int, EngineEntry])(self.query) + + def start(self): + if self.server_thread: + return + + logger.info("Mooncake Bootstrap Server is starting ......") + config = uvicorn.Config(app=self.app, host=self.host, port=self.port) + self.server = uvicorn.Server(config=config) + self.server_thread = threading.Thread( + target=self.server.run, name="mooncake_bootstrap_server", daemon=True + ) + self.server_thread.start() + while not self.server.started: + time.sleep(0.1) # Wait for the server to start + logger.info("Mooncake Bootstrap Server started at %s:%d", self.host, self.port) + + def shutdown(self): + if self.server_thread is None or self.server is None or not self.server.started: + return + + self.server.should_exit = True + self.server_thread.join() + logger.info("Mooncake Bootstrap Server stopped.") + + async def register_worker(self, payload: RegisterWorkerPayload): + """Handles registration of a prefiller worker.""" + if payload.dp_rank not in self.workers: + self.workers[payload.dp_rank] = EngineEntry( + engine_id=payload.engine_id, + worker_addr={}, + ) + + dp_entry = self.workers[payload.dp_rank] + if dp_entry.engine_id != payload.engine_id: + raise HTTPException( + status_code=400, + detail=( + f"Engine ID mismatch for dp_rank={payload.dp_rank}: " + f"expected {dp_entry.engine_id}, got {payload.engine_id}" + ), + ) + if payload.tp_rank not in dp_entry.worker_addr: + dp_entry.worker_addr[payload.tp_rank] = {} + + tp_entry = dp_entry.worker_addr[payload.tp_rank] + if payload.pp_rank in tp_entry: + raise HTTPException( + status_code=400, + detail=( + f"Worker with dp_rank={payload.dp_rank}, " + f"tp_rank={payload.tp_rank}, pp_rank={payload.pp_rank} " + f"is already registered at " + f"{tp_entry[payload.pp_rank]}, " + f"but still want to register at {payload.addr}" + ), + ) + + tp_entry[payload.pp_rank] = payload.addr + logger.debug( + "Registered worker: engine_id=%s, dp_rank=%d, tp_rank=%d, pp_rank=%d at %s", + payload.engine_id, + payload.dp_rank, + payload.tp_rank, + payload.pp_rank, + payload.addr, + ) + + return {"status": "ok"} + + async def query(self) -> dict[int, EngineEntry]: + return self.workers + + +def get_ip() -> str: + host_ip = os.getenv("INFINILM_HOST_IP", "127.0.0.1") + logger.info("INFINILM_HOST_IP is %s", host_ip) + return host_ip + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: + """Make a ZMQ path from its parts. + + Args: + scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc). + host: The host - can be an IPv4 address, IPv6 address, or hostname. + port: Optional port number, only used for TCP sockets. + + Returns: + A properly formatted ZMQ path string. + """ + if port is None: + return f"{scheme}://{host}" + if is_valid_ipv6_address(host): + return f"{scheme}://[{host}]:{port}" + return f"{scheme}://{host}:{port}" + + +def split_zmq_path(path: str) -> tuple[str, str, str]: + """Split a zmq path into its parts.""" + + parsed = parse_url(path) + + if not parsed.scheme: + raise ValueError(f"Invalid zmq path: {path}") + + scheme = parsed.scheme + host = parsed.hostname or "" + port = str(parsed.port or "") + if host.startswith("[") and host.endswith("]"): + host = host[1:-1] # Remove brackets for IPv6 address + + if scheme == "tcp" and not all((host, port)): + # The host and port fields are required for tcp + raise ValueError(f"Invalid zmq path: {path}") + + if scheme != "tcp" and port: + # port only makes sense with tcp + raise ValueError(f"Invalid zmq path: {path}") + + return scheme, host, port + + +def make_zmq_socket( + ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined] + path: str, + socket_type: Any, + bind: bool | None = None, + identity: bytes | None = None, + linger: int | None = None, + router_handover: bool = False, +) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined] + """Make a ZMQ socket with the proper bind/connect semantics.""" + + mem = psutil.virtual_memory() + socket = ctx.socket(socket_type) + + # Calculate buffer size based on system memory + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + # For systems with substantial memory (>32GB total, >16GB available): + # - Set a large 0.5GB buffer to improve throughput + # For systems with less memory: + # - Use system default (-1) to avoid excessive memory consumption + buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1 + + if bind is None: + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) + + if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.RCVHWM, 0) + socket.setsockopt(zmq.RCVBUF, buf_size) + + if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.SNDHWM, 0) + socket.setsockopt(zmq.SNDBUF, buf_size) + + if socket_type == zmq.ROUTER and router_handover: + # Let a new connection take over an identity left behind by a dead one. + socket.setsockopt(zmq.ROUTER_HANDOVER, 1) + + if identity is not None: + socket.setsockopt(zmq.IDENTITY, identity) + + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + + if socket_type == zmq.XPUB: + socket.setsockopt(zmq.XPUB_VERBOSE, True) + + # Determine if the path is a TCP socket with an IPv6 address. + # Enable IPv6 on the zmq socket if so. + scheme, host, _ = split_zmq_path(path) + if scheme == "tcp" and is_valid_ipv6_address(host): + socket.setsockopt(zmq.IPV6, 1) + + if bind: + socket.bind(path) + else: + socket.connect(path) + + return socket diff --git a/python/infinilm/llm/__init__.py b/python/infinilm/llm/__init__.py index e0fd6095..675e72b6 100644 --- a/python/infinilm/llm/__init__.py +++ b/python/infinilm/llm/__init__.py @@ -15,7 +15,6 @@ LLM, LLMEngine, AsyncLLMEngine, - EngineConfig, ) from infinilm.llm.scheduler import Scheduler, SchedulerOutput from infinilm.llm.static_scheduler import StaticScheduler, StaticSchedulerOutput @@ -26,7 +25,6 @@ "LLM", "AsyncLLMEngine", "LLMEngine", - "EngineConfig", # Parameters "SamplingParams", # Request and Output diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py index ea857b02..8e8007cd 100644 --- a/python/infinilm/llm/cache_manager.py +++ b/python/infinilm/llm/cache_manager.py @@ -3,7 +3,6 @@ """ from collections import deque -import queue from typing import List, Dict, Set import xxhash import numpy as np @@ -18,6 +17,9 @@ def __init__(self, block_id: int): self.hash = -1 self.token_ids: List[int] = [] + def __repr__(self) -> str: + return f"Block(id={self.block_id}, ref={self.ref_count}, hash={self.hash})" + def update(self, hash_value: int, token_ids: List[int]) -> None: self.hash = hash_value self.token_ids = token_ids.copy() @@ -32,9 +34,6 @@ def free(self) -> None: self.hash = -1 self.token_ids = [] - def __repr__(self) -> str: - return f"Block(id={self.block_id}, ref={self.ref_count}, hash={self.hash})" - class BlockManager: """Manages Paged KV Cache allocation with prefix caching support. @@ -45,28 +44,6 @@ class BlockManager: - Slot mapping generation for physical-to-logical position mapping """ - def __init__(self, num_blocks: int, block_size: int): - assert num_blocks > 0 and block_size > 0, ( - "num_blocks and block_size must be positive" - ) - self.num_blocks = num_blocks - self.block_size = block_size - - self.blocks: List[Block] = [Block(i) for i in range(num_blocks)] - self.hash_to_block_id: Dict[int, int] = {} - self.free_block_ids: deque = deque(range(num_blocks)) - self.used_block_ids: Set[int] = set() - self.req_block_ids: Set[int] = set() - - def reset_req_blocks(self) -> None: - """Move blocks from prefill stage to used blocks and update hash mappings.""" - for block_id in self.req_block_ids: - self.used_block_ids.add(block_id) - block = self.blocks[block_id] - prefix_hash = block.hash - self.hash_to_block_id[prefix_hash] = block_id - self.req_block_ids.clear() - @classmethod def compute_hash( cls, @@ -84,26 +61,45 @@ def compute_hash( h.update(identifier.encode("utf-8")) return h.intdigest() - def _allocate_partial_block(self, block_id: int) -> Block: - """Allocate an incomplete block and add to used blocks.""" - assert block_id in self.free_block_ids, f"Block {block_id} not in free list" + def __init__(self, num_blocks: int, block_size: int): + assert num_blocks > 0 and block_size > 0, ( + "num_blocks and block_size must be positive" + ) + self.num_blocks = num_blocks + self.block_size = block_size + + self.blocks: List[Block] = [Block(i) for i in range(num_blocks)] + self.hash_to_block_id: Dict[int, int] = {} + self.free_block_ids: deque = deque(range(num_blocks)) + self.used_block_ids: Set[int] = set() + self.pending_block_ids: Set[int] = set() + + def __repr__(self): + return ( + f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, " + f"free={len(self.free_block_ids)}, used={len(self.used_block_ids)})" + ) + + # Private low-level operations + + def _allocate_partial_block(self) -> Block: + """Pop the first free block and add it to used blocks as a partial block.""" + block_id = self.free_block_ids.popleft() block = self.blocks[block_id] assert block.ref_count == 0, f"Block {block_id} ref_count not zero" block.reset() - self.free_block_ids.remove(block_id) self.used_block_ids.add(block_id) return block - def _allocate_full_block(self, block_id: int) -> Block: - """Allocate a complete block and add to request blocks.""" - assert block_id in self.free_block_ids, f"Block {block_id} not in free list" + def _allocate_full_block(self) -> Block: + """Pop the first free block and add it to pending blocks as a full block.""" + block_id = self.free_block_ids.popleft() block = self.blocks[block_id] assert block.ref_count == 0, f"Block {block_id} ref_count not zero" block.reset() - self.free_block_ids.remove(block_id) - self.req_block_ids.add(block_id) + self.pending_block_ids.add(block_id) return block def _deallocate_block(self, block_id: int): @@ -120,43 +116,62 @@ def _deallocate_block(self, block_id: int): self.used_block_ids.remove(block_id) self.free_block_ids.append(block_id) + def _commit_pending_blocks(self) -> None: + """Commit pending prefill blocks into used_block_ids and register their hashes.""" + for block_id in self.pending_block_ids: + self.used_block_ids.add(block_id) + block = self.blocks[block_id] + if block.hash != -1: + self.hash_to_block_id[block.hash] = block_id + self.pending_block_ids.clear() + + # Read-only state queries + def can_allocate(self, num_required_blocks: int) -> bool: return len(self.free_block_ids) >= num_required_blocks - def allocate_blocks( + def get_num_free_blocks(self) -> int: + return len(self.free_block_ids) + + def get_total_usable_blocks(self) -> int: + freeable_used_blocks = sum( + 1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0 + ) + return len(self.free_block_ids) + freeable_used_blocks + + # Core public operations + + def get_computed_blocks( self, token_ids: List[int], - block_table: List[int] = None, mm_token_index_mappings: List[dict] = None, - ) -> tuple[List[int], List[int], int]: - """Allocate cache blocks for new request with prefix caching support. + ) -> tuple[List[int], int, List[dict]]: + """Find locally cached prefix blocks for the given token sequence. + + The last token is never matched, as it must be recomputed to obtain logits. Args: - token_ids: Input token sequence - block_table: Existing block_table (for decode phase) - mm_token_index_mappings: List of multimodal token index mappings + token_ids: Input token sequence. + mm_token_index_mappings: List of multimodal token index mappings. Returns: - Tuple of (block_table, slot_mapping, num_cached_tokens) + A tuple of (cached_block_table, num_local_cached_tokens, blocks_blueprint): + - cached_block_table: List of matched block IDs (each with ref_count incremented). + - num_local_cached_tokens: Number of cached tokens (always a multiple of block_size). + - blocks_blueprint: Per-block cached block id and precomputed prefix hash. """ - if block_table is None: - block_table = [] - - # Static args num_tokens = len(token_ids) num_blocks = (num_tokens + self.block_size - 1) // self.block_size num_full_blocks = num_tokens // self.block_size remain_tokens = num_tokens % self.block_size - num_mm_inputs = ( - 0 if not mm_token_index_mappings else len(mm_token_index_mappings) - ) + mm_token_index_mappings = mm_token_index_mappings or [] + num_mm_inputs = len(mm_token_index_mappings) # Variables - slot_mapping = [] - num_cached_tokens = 0 + cached_block_table = [] prefix_hash = -1 cache_miss = False mm_start_counter = 0 - mm_caching_queue = queue.Queue(maxsize=len(mm_token_index_mappings)) + mm_caching_queue = deque() blocks_blueprint = [] # [{"prefix_hash": int or -1 if not a full block, "block_id": int or -1 if not cached}, ...] max_blocks_to_reuse = num_full_blocks @@ -177,7 +192,7 @@ def allocate_blocks( mm_data_identifiers.append( mm_token_index_mappings[mm_start_counter]["identifier"] ) - mm_caching_queue.put((mm_start_counter)) + mm_caching_queue.append(mm_start_counter) mm_start_counter += 1 prefix_hash = ( @@ -207,19 +222,19 @@ def allocate_blocks( if not cache_miss: # pop fully cached mm_data while ( - not mm_caching_queue.empty() - and mm_token_index_mappings[mm_caching_queue.queue[0]]["end_index"] + mm_caching_queue + and mm_token_index_mappings[mm_caching_queue[0]]["end_index"] < end_idx ): - mm_caching_queue.get() + mm_caching_queue.popleft() blocks_blueprint.append( {"prefix_hash": prefix_hash, "block_id": cached_block_id} ) # If there is one incomplete mm_data, tailing blocks need to fall back until all included mm_data are complete - if not mm_caching_queue.empty(): - incomplete_mm = mm_token_index_mappings[mm_caching_queue.get()] + if mm_caching_queue: + incomplete_mm = mm_token_index_mappings[mm_caching_queue.popleft()] incomplete_mm_start = incomplete_mm[ "start_index" ] # Fall back until this index is no longer included in the block @@ -227,41 +242,100 @@ def allocate_blocks( max_blocks_to_reuse, incomplete_mm_start // self.block_size ) - num_cached_tokens = max_blocks_to_reuse * self.block_size + num_local_cached_tokens = max_blocks_to_reuse * self.block_size + + for block_id in range(max_blocks_to_reuse): + block = self.blocks[blocks_blueprint[block_id]["block_id"]] + block.ref_count += 1 + cached_block_table.append(block.block_id) + + return cached_block_table, num_local_cached_tokens, blocks_blueprint + + def allocate_slots( + self, + token_ids: List[int], + num_new_tokens: int, + num_computed_tokens: int = 0, + cached_block_table: List[int] = None, + blocks_blueprint: List[dict] = None, + delay_cache_blocks: bool = False, + ) -> tuple[List[int], List[int]] | None: + """Allocate KV cache slots for a request (PD-disaggregation aware). + + Note: Requires that the underlying attention kernel writes KV cache before + reading it (write-before-read ordering). + + Args: + token_ids: Complete token sequence for the request. + num_new_tokens: Number of tokens to compute in this step. + num_computed_tokens: Total number of tokens already computed across local and remote workers. + cached_block_table: Already-matched local block IDs. + blocks_blueprint: Per-block precomputed prefix hashes from get_computed_blocks. + delay_cache_blocks: When True (async PD transfer in progress), allocate + blocks but defer hash registration until transfer completes. - for block_id in range(num_blocks): - n_block_tokens = self.block_size + Returns: + A tuple of (block_table, slot_mapping), or None if blocks are insufficient. + - block_table: Full block list. + - slot_mapping: Physical slot IDs for the tokens that need to be computed. + """ + if cached_block_table is None: + cached_block_table = [] + block_table = list(cached_block_table) + slot_mapping = [] - if block_id < max_blocks_to_reuse: - # Reuse block - block = self.blocks[blocks_blueprint[block_id]["block_id"]] - block.ref_count += 1 + total_tokens = num_computed_tokens + num_new_tokens + num_blocks_needed = ( + total_tokens + self.block_size - 1 + ) // self.block_size - len(cached_block_table) + + if not self.can_allocate(num_blocks_needed): + if not self.try_free_blocks(num_blocks_needed): + return None + + start_block_idx = len(cached_block_table) + total_blocks = (total_tokens + self.block_size - 1) // self.block_size + prefix_hash = ( + self.blocks[cached_block_table[-1]].hash if cached_block_table else -1 + ) + + for block_idx in range(start_block_idx, total_blocks): + start_tok = block_idx * self.block_size + end_tok = min(start_tok + self.block_size, len(token_ids)) + block_tokens = token_ids[start_tok:end_tok] + is_full_block = len(block_tokens) == self.block_size + + if not self.free_block_ids: + return None + + if is_full_block: + block_hash = -1 + if blocks_blueprint is not None and block_idx < len(blocks_blueprint): + block_hash = blocks_blueprint[block_idx]["prefix_hash"] + if block_hash == -1: + block_hash = self.compute_hash(block_tokens, prefix_hash) + prefix_hash = block_hash + block = self._allocate_full_block() + block.update(block_hash, block_tokens) else: - new_block_id = self.free_block_ids[0] - if blocks_blueprint[block_id]["prefix_hash"] != -1: - start_idx = block_id * self.block_size - end_idx = start_idx + self.block_size - block_tokens = token_ids[start_idx:end_idx] - block = self._allocate_full_block(new_block_id) - block.update( - blocks_blueprint[block_id]["prefix_hash"], block_tokens - ) - else: - block = self._allocate_partial_block(new_block_id) - n_block_tokens = remain_tokens - slot_mapping.extend( - list( - range( - block.block_id * self.block_size, - block.block_id * self.block_size + n_block_tokens, - ) - ) - ) + block = self._allocate_partial_block() block_table.append(block.block_id) - return block_table, slot_mapping, num_cached_tokens + for tok_idx in range(num_computed_tokens, total_tokens): + blk_idx = tok_idx // self.block_size + blk_offset = tok_idx % self.block_size + slot_mapping.append(block_table[blk_idx] * self.block_size + blk_offset) + + if delay_cache_blocks: + for block_id in list(self.pending_block_ids): + self.used_block_ids.add(block_id) + self.pending_block_ids.clear() + else: + self._commit_pending_blocks() + + return block_table, slot_mapping def append_slot( self, block_table: List[int], num_tokens: int, total_token_ids: List[int] = None @@ -305,9 +379,8 @@ def append_slot( if not self.free_block_ids: if not self.try_free_blocks(1): raise RuntimeError("No available cache blocks") - new_block_id = self.free_block_ids[0] - self._allocate_partial_block(new_block_id) - block_table.append(new_block_id) + new_block = self._allocate_partial_block() + block_table.append(new_block.block_id) # Calculate slot last_block_id = block_table[-1] @@ -316,11 +389,14 @@ def append_slot( return block_table, slot_id + # Reference management + def free_blocks(self, block_table: List[int]): """Decrease reference count for all blocks. Blocks with ref_count=0 are not immediately freed to allow reuse.""" for block_id in reversed(block_table): block = self.blocks[block_id] + assert block.ref_count > 0, "block ref_count must be greater than 0" block.ref_count -= 1 def try_free_blocks(self, num_required: int) -> bool: @@ -336,17 +412,73 @@ def try_free_blocks(self, num_required: int) -> bool: return self.can_allocate(num_required) - def get_num_free_blocks(self) -> int: - return len(self.free_block_ids) + # PD-disaggregation specific - def get_total_usable_blocks(self) -> int: - freeable_used_blocks = sum( - 1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0 - ) - return len(self.free_block_ids) + freeable_used_blocks + def update_blocks_hash(self, block_table: List[int], num_local_cached_tokens: int): + """Register hashes for blocks beyond the locally cached prefix into the lookup table. - def __repr__(self): - return ( - f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, " - f"free={len(self.free_block_ids)}, used={len(self.used_block_ids)})" + Called on the decode node after receiving KV data from the prefill node, + so that subsequent requests can hit these blocks via prefix caching. + Only full blocks (with a valid hash) are registered; partial blocks are skipped. + + Args: + block_table: Block IDs for the current request. + num_local_cached_tokens: Number of locally cached tokens (must be a multiple of + block_size). + """ + assert num_local_cached_tokens % self.block_size == 0, ( + "num_local_cached_tokens must be multiple of block_size" ) + for idx in range(num_local_cached_tokens // self.block_size, len(block_table)): + block_id = block_table[idx] + block = self.blocks[block_id] + if block.hash != -1: + self.hash_to_block_id[block.hash] = block_id + + def update_blocks_slot( + self, block_table: List[int], num_computed_tokens: int, total_tokens: int + ) -> List[int]: + """Build the slot mapping for tokens that still need to be computed. + + Used on the decode node after a partial KV transfer failure to reconstruct + the slot mapping covering [num_computed_tokens, total_tokens). + + Args: + block_table: Block IDs for the current request. + num_computed_tokens: Number of tokens already computed (may not be + a multiple of block_size). + total_tokens: Total token count for this request. + + Returns: + Slot IDs for the range [num_computed_tokens, total_tokens). + """ + bs = self.block_size + new_slot_mapping = [] + + start_block = num_computed_tokens // bs + start_offset = num_computed_tokens % bs + + last_token_idx = total_tokens - 1 + end_block = last_token_idx // bs + end_offset = last_token_idx % bs + 1 + + if start_block == end_block: + block_id = block_table[start_block] + base = block_id * bs + new_slot_mapping.extend(range(base + start_offset, base + end_offset)) + return new_slot_mapping + + block_id = block_table[start_block] + base = block_id * bs + new_slot_mapping.extend(range(base + start_offset, base + bs)) + + for idx in range(start_block + 1, end_block): + block_id = block_table[idx] + base = block_id * bs + new_slot_mapping.extend(range(base, base + bs)) + + block_id = block_table[end_block] + base = block_id * bs + new_slot_mapping.extend(range(base, base + end_offset)) + + return new_slot_mapping diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 0149d715..9177e3fa 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -6,15 +6,14 @@ - AsyncLLM class for asynchronous streaming (server use) """ +import os import asyncio import time import uuid import logging +import janus import threading from typing import List, Optional, Union, AsyncIterator -from dataclasses import dataclass - -import infinicore from infinilm.llm.request import ( InferenceRequest, @@ -22,177 +21,107 @@ TokenOutput, FinishReason, ) + +from infinilm.llm.model_runner.model_runner import ModelRunner from infinilm.llm.sampling_params import SamplingParams from infinilm.llm.scheduler import Scheduler from infinilm.llm.static_scheduler import StaticScheduler -from infinilm.processors import AutoInfinilmProcessor -from infinilm.distributed import DistConfig -from infinilm.infer_engine import InferEngine -from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig -from infinilm.modeling_utils import load_model_state_dict_by_file from infinilm.multimodal.multimodal import resolve_multimodal_inputs +from infinilm.config.kv_transfer import KVTransferConfig +from infinilm.config.engine_config import EngineConfig +from infinilm.kv_connector import KVConnectorRole, KVConnectorFactory logger = logging.getLogger(__name__) -@dataclass -class EngineConfig: - """Configuration for LLM Engine. - - Attributes: - model_path: Path to the model directory. - device: Device type string ('cpu', 'cuda', 'mlu', etc.). - dtype: Data type string ('float16', 'bfloat16', 'float32'). - tensor_parallel_size: Number of devices for tensor parallelism. - cache_type: Cache type ('paged' or 'static'). - max_batch_size: Maximum batch size for inference (only for paged cache). - max_tokens: Default maximum tokens to generate. - num_blocks: Number of KV cache blocks (only for paged cache). - block_size: Size of each KV cache block (only for paged cache). - max_cache_len: Maximum sequence length (only for static cache). - temperature: Default sampling temperature. - top_p: Default top-p sampling parameter. - top_k: Default top-k sampling parameter. - enable_graph: Whether to enable graph compiling. - attn_backend: Attention backend to use ('default', 'flash-attn'). - skip_load: Whether to skip loading model weights (for testing). - """ - - model_path: str - device: str = "cuda" - dtype: str = "float16" - tensor_parallel_size: int = 1 - cache_type: str = "paged" # "paged" or "static" - max_batch_size: int = 16 - max_tokens: int = 4096 - num_blocks: int = 512 - block_size: int = 256 - max_cache_len: int = 4096 - temperature: float = 1.0 - top_p: float = 0.8 - top_k: int = 1 - enable_graph: bool = False - attn_backend: str = "default" - skip_load: bool = False - - class LLMEngine: """Low-level LLM engine that handles inference execution.""" def __init__(self, config: EngineConfig): self.config = config - # Initialize device and dtype - self._init_device() + self.model_runner = ModelRunner(config) - # Initialize model engine - self.model_engine = InferEngine( - model_path=config.model_path, - device=self.device, - distributed_config=DistConfig(config.tensor_parallel_size), - enable_graph_compiling=config.enable_graph, - attention_backend=config.attn_backend, - ) - - # Load model weights - if not self.config.skip_load: - load_model_state_dict_by_file( - self.model_engine, config.model_path, dtype=self.model_engine.dtype - ) + self.device = self.model_runner.device + self.dtype = self.model_runner.dtype - # Initialize processor/tokenizer - self.processor = AutoInfinilmProcessor.from_pretrained(config.model_path) + # Initialize processor + self.processor = self.model_runner.processor self.tokenizer = self.processor.get_tokenizer() # Initialize KV cache based on cache type if config.cache_type == "static": - cache_config = StaticKVCacheConfig( - max_batch_size=1, max_cache_len=config.max_cache_len - ) self.scheduler = StaticScheduler(max_cache_len=config.max_cache_len) logger.info( f"Using Static KV Cache with max_cache_len={config.max_cache_len}" ) elif config.cache_type == "paged": - cache_config = PagedKVCacheConfig( - num_blocks=config.num_blocks, block_size=config.block_size + connector = None + if config.kv_transfer_config and config.kv_transfer_config.kv_connector: + connector = KVConnectorFactory.create_connector( + connector_name=config.kv_transfer_config.kv_connector, + role=KVConnectorRole.SCHEDULER, + kv_transfer_config=config.kv_transfer_config, + ) + logger.info( + f"KV Connector created: {config.kv_transfer_config.kv_connector} " + f"(role={config.kv_transfer_config.kv_role})" + ) + + max_position_embeddings = self.model_runner.model_engine.hf_config[ + "max_position_embeddings" + ] + max_num_batched_tokens = int( + os.getenv("INFINILM_MAX_NUM_BATCHED_TOKENS", max_position_embeddings) ) + assert 1024 <= max_num_batched_tokens <= max_position_embeddings + self.scheduler = Scheduler( max_batch_size=config.max_batch_size, num_blocks=config.num_blocks, block_size=config.block_size, + max_num_batched_tokens=max_num_batched_tokens, + connector=connector, ) logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}") else: raise ValueError(f"Unsupported cache_type: {config.cache_type}") - self.model_engine.reset_cache(cache_config) self.cache_type = config.cache_type # Get EOS token IDs from model config - self.eos_token_ids = self.model_engine.eos_token_id or [] + self.eos_token_ids = self.model_runner.eos_token_id or [] if isinstance(self.eos_token_ids, int): self.eos_token_ids = [self.eos_token_ids] logger.info( f"LLMEngine initialized with model at {config.model_path} " - f"on device {config.device}" + f"on device {config.device}, " f"enable_graph={config.enable_graph}" ) - def _init_device(self): - """Initialize infinicore device and dtype.""" - supported_devices = ["cpu", "cuda", "mlu", "musa", "npu"] - device_str = self.config.device - if device_str not in supported_devices: - raise ValueError( - f"Unsupported device: '{device_str}'. " - f"Supported devices: {supported_devices}" - ) - self.device = infinicore.device(device_str, 0) - - dtype_map = { - "float32": infinicore.float32, - "float16": infinicore.float16, - "bfloat16": infinicore.bfloat16, - } - - if self.config.dtype not in dtype_map: - raise ValueError( - f"Unsupported dtype: '{self.config.dtype}'. " - f"Supported dtypes: {list(dtype_map.keys())}" - ) - - self.dtype = dtype_map[self.config.dtype] - def add_request(self, request: InferenceRequest): """Add a request to the scheduler.""" self.scheduler.add_request(request) - def step(self) -> tuple[list[InferenceRequest], list[tuple]]: + def step(self) -> tuple[bool, list[tuple]]: """Run one inference step. Returns: A tuple of: - - scheduled_requests: Requests that were scheduled and processed in this step. + - did_work - pending: Pending streaming outputs as (async_queue, TokenOutput) pairs. """ - # Schedule requests + # Schedule the next unit of work, which may be model execution, + # connector control metadata, or both. scheduler_output = self.scheduler.schedule() - if scheduler_output is None or not scheduler_output.scheduled_requests: - return [], [] - - # Build model inputs - model_input = self.processor.build_model_inputs( - scheduler_output, - self.config.temperature, - self.config.top_p, - self.config.top_k, - ) + if scheduler_output is None: + return False, [] - # Run inference - sampled_tokens = self.model_engine.forward(**model_input) - sampled_tokens_list = sampled_tokens.to_numpy().tolist() + # Execute model + runner_output = self.model_runner.execute_model(scheduler_output) + sampled_tokens_list = runner_output.sampled_token_ids + self.scheduler.update_from_output(runner_output) # Update request status pending = self._update_requests( @@ -201,7 +130,18 @@ def step(self) -> tuple[list[InferenceRequest], list[tuple]]: sampled_tokens_list, ) - return scheduler_output.scheduled_requests, pending + # Return False (no immediate work) only when no requests were scheduled + # and no KV transfers completed in this step. + if not scheduler_output.scheduled_requests: + if not runner_output.kv_connector_output or ( + not getattr(runner_output.kv_connector_output, "finished_sending", None) + and not getattr( + runner_output.kv_connector_output, "finished_recving", None + ) + ): + return False, pending + + return True, pending def _update_requests( self, @@ -213,7 +153,7 @@ def _update_requests( if is_prefill: match self.cache_type: case "paged": - self.scheduler.cache_manager.reset_req_blocks() + pass case "static": self.scheduler.update_cache() case _: @@ -224,13 +164,14 @@ def _update_requests( logger.info( f"Request {req.request_id} aborted by client, skipping update" ) + # close() may have set _aborted=True without setting a terminal status + # (status still RUNNING). + if not req.is_finished(): + req.mark_canceled() continue - if req.is_prefill: - req.is_prefill = False - req.generated_token_ids.append(token_id) - pending_tokens = req.generated_token_ids[req._pending_token_offset :] + pending_tokens = req.generated_token_ids[req._token_decode_offset :] delta = self.tokenizer.decode(pending_tokens) holds_back = bool(delta) and delta.endswith("\ufffd") @@ -238,7 +179,7 @@ def _update_requests( if not holds_back: req.generated_text = last_committed_text + delta - req._pending_token_offset = len(req.generated_token_ids) + req._token_decode_offset = len(req.generated_token_ids) is_finished = self._check_request_finished(req, token_id) @@ -259,11 +200,9 @@ def _update_requests( ): token_text = "" else: - token_text = req.generated_text[ - req._stream_last_yielded_length : - ] + token_text = req.generated_text[req._text_output_offset :] if token_text: - req._stream_last_yielded_length = len(req.generated_text) + req._text_output_offset = len(req.generated_text) if is_finished: req.mark_finished(req.finish_reason) @@ -553,6 +492,7 @@ def __init__( top_k: int = 1, enable_graph: bool = False, attn_backend: str = "default", + kv_transfer_config: Optional[KVTransferConfig] = None, ): """Initialize AsyncLLMEngine. @@ -572,6 +512,9 @@ def __init__( top_k: Default top-k sampling parameter. enable_graph: Whether to enable graph compiling. attn_backend: Attention backend to use ('default', 'flash-attn'). + kv_connector: KV connector type ('MooncakeConnector'). + kv_role: Role in KV connector ('kv_producer' or 'kv_consumer'). + kv_connector_extra_config: Extra config dict for KV connector. """ config = EngineConfig( model_path=model_path, @@ -589,6 +532,7 @@ def __init__( top_k=top_k, enable_graph=enable_graph, attn_backend=attn_backend, + kv_transfer_config=kv_transfer_config, ) self.engine = LLMEngine(config) self.config = config @@ -597,6 +541,7 @@ def __init__( self._step_thread: Optional[threading.Thread] = None self._loop: Optional[asyncio.AbstractEventLoop] = None self._healthy = True + self._abort_queue: Optional[janus.Queue] = None def is_healthy(self) -> bool: return bool(self._healthy) @@ -608,6 +553,7 @@ def start(self): return self._loop = asyncio.get_running_loop() + self._abort_queue = janus.Queue() self._running = True self._step_thread = threading.Thread( target=self._step_loop, daemon=True, name="AsyncLLMEngineStepThread" @@ -626,13 +572,69 @@ def stop(self): self._step_thread.join(timeout=5) logger.info("AsyncLLMEngine stopped") + def add_aborted_req( + self, + req: InferenceRequest, + reason: FinishReason = FinishReason.CANCELED, + ): + """Submit an abort request from async side to the step thread. + + The step thread processes this in _drain_abort_queue() before each schedule(). + """ + if self._abort_queue is not None: + self._abort_queue.sync_q.put((req, reason)) + + def _drain_abort_queue(self): + """Process all pending abort requests before each schedule() call. + + Runs in the step thread (sync context). Guarantees mark_*() is called + before schedule() so is_finished() checks in waiting/running queue loops + work correctly. Puts a final TokenOutput into the output queue to unblock + stream_request when _stream_chat is still alive after abort. + """ + if self._abort_queue is None: + return + while True: + try: + req, reason = self._abort_queue.sync_q.get_nowait() + except Exception: + break + + if req.is_finished(): + continue + + if reason == FinishReason.CANCELED: + req.mark_canceled() + elif reason == FinishReason.TIMEOUT: + req.mark_timeout() + else: + req.mark_failed(reason) + + # Put a final token to unblock stream_request. + # If Starlette already cancelled _stream_chat, aclose() may have closed + # the queue; put_nowait will raise and we silently ignore it. + if req._output_queue is not None: + final = TokenOutput( + request_id=req.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=req.finish_reason, + generated_text=req.generated_text, + ) + try: + req.output_queue.sync_q.put_nowait(final) + except Exception: + pass + def _step_loop(self): """Background loop that runs inference steps.""" while self._running: try: - requests, pending = self.engine.step() - if not requests: - time.sleep(0.01) + self._drain_abort_queue() + did_work, pending = self.engine.step() + if not did_work: + time.sleep(0.003) elif pending: self._loop.call_soon_threadsafe(self._batch_put, pending) except Exception as e: @@ -663,7 +665,6 @@ def add_request( request_id: Optional[str] = None, # For server use request_data: Optional[dict] = None, - http_request: Optional[any] = None, ) -> InferenceRequest: """Add a request to the engine. @@ -693,7 +694,6 @@ def add_request( sampling_params: Sampling parameters. request_id: Optional request ID. request_data: Optional request data dict (for server use). - http_request: Optional HTTP request object (for server use). Returns: The created InferenceRequest object. @@ -709,13 +709,13 @@ def add_request( elif prompt is not None: prompt_token_ids = self.engine.tokenize(prompt) else: - assert ( - messages is not None - ), "Either messages or prompt/prompt_token_ids must be provided" + assert messages is not None, ( + "Either messages or prompt/prompt_token_ids must be provided" + ) - assert ( - apply_chat_template - ), "apply_chat_template needs to be true for multi-role conversation" + assert apply_chat_template, ( + "apply_chat_template needs to be true for multi-role conversation" + ) prompt = self.engine.apply_chat_template( messages, add_generation_prompt=add_generation_prompt @@ -754,9 +754,12 @@ def add_request( sampling_params=sampling_params, eos_token_ids=self.engine.eos_token_ids, request_data=request_data, - http_request=http_request, ) + if request_data and "kv_transfer_params" in request_data: + kv_params = request_data["kv_transfer_params"] + request.kv_transfer_params = kv_params + # Initialize output queue for streaming _ = request.output_queue @@ -769,7 +772,6 @@ def add_chat_request( sampling_params: Optional[SamplingParams] = None, request_id: Optional[str] = None, request_data: Optional[dict] = None, - http_request: Optional[any] = None, add_generation_prompt: bool = True, **kwargs, ) -> InferenceRequest: @@ -780,7 +782,6 @@ def add_chat_request( sampling_params: Sampling parameters. request_id: Optional request ID. request_data: Optional request data dict. - http_request: Optional HTTP request object. Returns: The created InferenceRequest object. @@ -793,7 +794,6 @@ def add_chat_request( sampling_params=sampling_params, request_id=request_id, request_data=request_data, - http_request=http_request, ) async def stream_request( @@ -814,53 +814,54 @@ async def stream_request( import asyncio start = time.time() - while True: - try: - if request_timeout and time.time() - start > float(request_timeout): - request.mark_timeout() - yield TokenOutput( - request_id=request.request_id, - token_id=-1, - token_text="", - finished=True, - finish_reason=FinishReason.TIMEOUT, - generated_text=request.generated_text, + try: + while True: + try: + if request_timeout and time.time() - start > float(request_timeout): + logger.warning( + f"Request {request.request_id} exceeded request timeout of {request_timeout} seconds" + ) + self.add_aborted_req(request, FinishReason.TIMEOUT) + + token_output = await asyncio.wait_for( + request.output_queue.async_q.get(), timeout=timeout ) - break - token_output = await asyncio.wait_for( - request.output_queue.async_q.get(), timeout=timeout - ) - - request.output_queue.async_q.task_done() + request.output_queue.async_q.task_done() - yield token_output + yield token_output - if token_output.finished: - break - except asyncio.TimeoutError: - logger.warning( - f"Timeout while waiting for token from request {request.request_id}" - ) - if request.is_aborted(): - while not request.output_queue.async_q.empty(): - try: - token_output = request.output_queue.async_q.get_nowait() - request.output_queue.async_q.task_done() - yield token_output - except asyncio.QueueEmpty: - break - - yield TokenOutput( - request_id=request.request_id, - token_id=-1, - token_text="", - finished=True, - finish_reason=request.finish_reason, - generated_text=request.generated_text, + if token_output.finished: + break + except asyncio.TimeoutError: + logger.warning( + f"Timeout while waiting for token from request {request.request_id}" + ) + if request.is_aborted(): + while not request.output_queue.async_q.empty(): + try: + token_output = request.output_queue.async_q.get_nowait() + request.output_queue.async_q.task_done() + yield token_output + except asyncio.QueueEmpty: + break + + yield TokenOutput( + request_id=request.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=request.finish_reason, + generated_text=request.generated_text, + ) + break + continue + except Exception as e: + logger.error( + f"Error while streaming request {request.request_id}: {e}" ) break - continue - except Exception as e: - logger.error(f"Error while streaming request {request.request_id}: {e}") - break + finally: + # Unified cleanup point: runs whether the loop exits normally, + # via exception, or via aclose() (GeneratorExit from Starlette). + await request.close() diff --git a/python/infinilm/llm/model_runner/model_runner.py b/python/infinilm/llm/model_runner/model_runner.py new file mode 100644 index 00000000..e551da0c --- /dev/null +++ b/python/infinilm/llm/model_runner/model_runner.py @@ -0,0 +1,215 @@ +import logging +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Generator + +import infinicore + +from infinilm.distributed import DistConfig +from infinilm.infer_engine import InferEngine +from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig +from infinilm.modeling_utils import load_model_state_dict_by_file +from infinilm.config.engine_config import EngineConfig +from infinilm.kv_connector import ( + KVConnectorRole, + KVConnectorFactory, +) +from infinilm.processors import AutoInfinilmProcessor + +logger = logging.getLogger(__name__) + + +@dataclass +class KVConnectorOutput: + finished_sending: set[str] | None = None + finished_recving: set[str] | None = None + + # consumer failed to recv + failed_recving: set[str] | None = None + + # IDs of externally computed KV blocks that failed to load. + # Requests referencing these blocks should be rescheduled to recompute them + invalid_block_ids: set[int] = field(default_factory=set) # not used + kv_connector_stats = None # not used + + +@dataclass +class ModelRunnerOutput: + # [num_reqs] + req_ids: list[str] = field(default_factory=list) + sampled_token_ids: list[int] = field(default_factory=list) + kv_connector_output: KVConnectorOutput | None = None + + +class ModelRunner: + def __init__(self, config: EngineConfig): + self.config = config + self.kv_transfer_config = config.kv_transfer_config + logger.info(f"kv_transfer_config: {self.kv_transfer_config}") + + self._init_device() + + # Initialize KV cache based on cache type + if config.cache_type == "static": + cache_config = StaticKVCacheConfig( + max_batch_size=1, max_cache_len=config.max_cache_len + ) + logger.info( + f"Using Static KV Cache with max_cache_len={config.max_cache_len}" + ) + elif config.cache_type == "paged": + cache_config = PagedKVCacheConfig( + num_blocks=config.num_blocks, block_size=config.block_size + ) + logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}") + else: + raise ValueError(f"Unsupported cache_type: {config.cache_type}") + + # Initialize model engine + self.model_engine = InferEngine( + model_path=config.model_path, + device=self.device, + distributed_config=DistConfig(config.tensor_parallel_size), + cache_config=cache_config, + enable_graph_compiling=config.enable_graph, + attention_backend=config.attn_backend, + ) + + # Load model weights + if not self.config.skip_load: + load_model_state_dict_by_file( + self.model_engine, config.model_path, dtype=self.model_engine.dtype + ) + + # Initialize processor + self.processor = AutoInfinilmProcessor.from_pretrained(config.model_path) + + # Initialize KV connector + self.kv_connector = None + if ( + self.kv_transfer_config is not None + and self.kv_transfer_config.kv_connector + ): + connector_name = self.kv_transfer_config.kv_connector + self.kv_connector = KVConnectorFactory.create_connector( + connector_name=connector_name, + role=KVConnectorRole.WORKER, + kv_transfer_config=self.kv_transfer_config, + ) + + kv_cache_list = self.model_engine.get_kv_cache() + assert len(kv_cache_list) == self.config.tensor_parallel_size + + kv_caches = {} + for rank_idx, kv_cache_vec in enumerate(kv_cache_list): + for layer_idx, layer_kv_cache in enumerate(kv_cache_vec): + # print(layer_kv.shape) # shape:[2, 8, 8, 256, 128] + key_name = ( + f"rank.{rank_idx}.model.layers.{layer_idx}.self_attn.attn" + ) + kv_caches[key_name] = layer_kv_cache + + self.kv_connector.register_kv_caches(kv_caches) + + @property + def model_type(self): + return self.model_engine.model_type + + @property + def eos_token_id(self): + return self.model_engine.eos_token_id + + def _init_device(self): + """Initialize infinicore device and dtype.""" + supported_devices = ["cpu", "cuda", "mlu", "musa", "npu"] + device_str = self.config.device + if device_str not in supported_devices: + raise ValueError( + f"Unsupported device: '{device_str}'. " + f"Supported devices: {supported_devices}" + ) + self.device = infinicore.device(device_str, 0) + + dtype_map = { + "float32": infinicore.float32, + "float16": infinicore.float16, + "bfloat16": infinicore.bfloat16, + } + + if self.config.dtype not in dtype_map: + raise ValueError( + f"Unsupported dtype: '{self.config.dtype}'. " + f"Supported dtypes: {list(dtype_map.keys())}" + ) + + self.dtype = dtype_map[self.config.dtype] + + def execute_model(self, scheduler_output) -> ModelRunnerOutput: + sampled_tokens_list = [] + kv_connector_output = None + + if self.kv_connector is None: + sampled_tokens_list = self._model_forward(scheduler_output) + else: + with self.maybe_get_kv_connector_output( + scheduler_output, + ) as kv_connector_output: + if scheduler_output.num_requests > 0: + sampled_tokens_list = self._model_forward(scheduler_output) + + # model_runner_output + req_ids = [] + for i in range(scheduler_output.num_requests): + req_ids.append(scheduler_output.scheduled_requests[i].request_id) + + return ModelRunnerOutput( + req_ids=req_ids, + sampled_token_ids=sampled_tokens_list, + kv_connector_output=kv_connector_output, + ) + + def _model_forward(self, scheduler_output): + # Build model inputs + model_input = self.processor.build_model_inputs( + scheduler_output, + self.config.temperature, + self.config.top_p, + self.config.top_k, + ) + + # Run inference + sampled_tokens = self.model_engine.forward(**model_input) + sampled_tokens_list = sampled_tokens.to_numpy().tolist() + + return sampled_tokens_list + + @contextmanager + def maybe_get_kv_connector_output( + self, scheduler_output: Any + ) -> Generator[KVConnectorOutput, None, None]: + """Context manager for KV connector operations around model forward.""" + + output = KVConnectorOutput() + assert scheduler_output.kv_connector_metadata is not None + + self.kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata + ) + + self.kv_connector.start_load_kv() + + try: + yield output + finally: + output.finished_sending, output.failed_recving, output.finished_recving = ( + self.kv_connector.get_finished("finished_req_ids") + ) + output.invalid_block_ids = ( + self.kv_connector.get_block_ids_with_load_errors() + ) + output.kv_connector_stats = self.kv_connector.get_kv_connector_stats() + + def close(self) -> None: + """Release resources held by the KV connector.""" + if self.kv_connector is not None: + self.kv_connector.shutdown() diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index a94c2f68..16f12efe 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -18,23 +18,36 @@ class RequestStatus(Enum): """Status of an inference request.""" + # Pending WAITING = "waiting" + WAITING_FOR_REMOTE_KVS = "waiting_for_remote_kvs" + + # Active RUNNING = "running" + + # Successful terminal FINISHED = "finished" + + # Abnormal terminal CANCELED = "canceled" - FAILED = "failed" TIMEOUT = "timeout" + FAILED = "failed" class FinishReason(Enum): """Reason for finishing generation.""" - STOP = "stop" - LENGTH = "length" + # Normal completion EOS_TOKEN = "eos_token" STOP_STRING = "stop_string" - TIMEOUT = "timeout" + STOP = "stop" + + # Controlled truncation + LENGTH = "length" + + # Abnormal termination CANCELED = "canceled" + TIMEOUT = "timeout" ERROR = "error" @@ -112,7 +125,6 @@ def __init__( arrival_time: Optional[float] = None, # For server use request_data: Optional[dict] = None, - http_request: Optional[Any] = None, ): self.arrival_time: float = arrival_time or time.time() self.finished_time: Optional[float] = None @@ -120,44 +132,50 @@ def __init__( # Request metadata self.request_id: str = request_id self.prompt: Optional[str] = prompt - self.prompt_token_ids: List[int] = prompt_token_ids or [] + self.prompt_token_ids: List[int] = ( + prompt_token_ids if prompt_token_ids is not None else [] + ) self.prompt_length: int = len(self.prompt_token_ids) self.processed_inputs: Optional[dict] = processed_inputs self.mm_token_index_mappings: Optional[List[dict]] = mm_token_index_mappings + self.priority: int = 0 - # Sampling parameters + # Sampling & stopping criteria self.sampling_params: SamplingParams = sampling_params or SamplingParams() - - # EOS token IDs (from model config) - self.eos_token_ids: List[int] = eos_token_ids or [] + self.eos_token_ids: List[int] = ( + eos_token_ids if eos_token_ids is not None else [] + ) # Generation state self.generated_token_ids: List[int] = [] - self.generated_text: str = "" - self.is_prefill: bool = True + self.generated_text: str = ( + "" # generated_text == tokenizer.decode(generated_token_ids[:_token_decode_offset]) + ) self.status: RequestStatus = RequestStatus.WAITING self.finish_reason: Optional[FinishReason] = None - self.priority: int = 0 # KV cache management - self.cache_id: Optional[int] = None self.block_table: List[int] = [] self.slot_mapping: List[int] = [] - self.num_cached_tokens: int = 0 + self.num_local_cached_tokens: int = ( + 0 # Number of locally cached (prefix-hit) tokens + ) + self.num_computed_tokens: int = 0 # Total tokens computed (local + remote) self.num_blocks: int = 0 + # PD disaggregation support + self.kv_transfer_params: Optional[dict] = ( + None # KV transfer parameters from the router + ) + # For server use self.request_data: Optional[dict] = request_data - self.http_request: Optional[Any] = http_request - # Output management (for async streaming) + # Async output & streaming self._output_queue: Optional[janus.Queue] = None - self._aborted = False - - # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer) - # Used by the engine to compute "delta" text chunks from a full decode. - self._stream_last_yielded_length: int = 0 - self._pending_token_offset: int = 0 + self._aborted: bool = False + self._text_output_offset: int = 0 + self._token_decode_offset: int = 0 @property def output_queue(self) -> janus.Queue: @@ -253,7 +271,7 @@ async def close(self): self._output_queue.close() try: - await asyncio.wait_for(self._output_queue.wait_closed(), timeout=0.5) + await asyncio.wait_for(self._output_queue.wait_closed(), timeout=1.0) except asyncio.TimeoutError: logger.warning("wait_closed timeout, force close") diff --git a/python/infinilm/llm/sampling_params.py b/python/infinilm/llm/sampling_params.py index cdde6e93..e9632c17 100644 --- a/python/infinilm/llm/sampling_params.py +++ b/python/infinilm/llm/sampling_params.py @@ -13,7 +13,7 @@ class SamplingParams: temperature: float = 1.0 top_p: float = 0.8 top_k: int = 1 - max_tokens: Optional[int] = None + max_tokens: int = 512 stop: Optional[List[str]] = None stop_token_ids: Optional[List[int]] = ( None # Placeholder for future usage, not currently handled diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index c5f4921a..c99c54f5 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -23,6 +23,7 @@ def __init__( self.scheduled_requests = scheduled_requests self.num_requests = len(scheduled_requests) self.is_prefill = is_prefill + self.kv_connector_metadata = None class Scheduler: @@ -39,26 +40,60 @@ def __init__( max_batch_size: int = 16, num_blocks: int = 512, block_size: int = 256, + max_num_batched_tokens: int = 1024, + connector=None, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() self.max_batch_size = max_batch_size + self.finished_receiving_kv_req_ids: set[str] = set() + self.failed_receiving_kv_req_ids: set[str] = set() + self.pending_free_blocks: dict[str, list[int]] = {} + self.pending_kv_decode_blocks: int = 0 + self.remote_kv_requests: dict[str, InferenceRequest] = {} + self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) self.block_size = block_size + self.max_num_batched_tokens = max_num_batched_tokens + self.connector = connector def add_request(self, request: InferenceRequest): if request is not None: request.status = RequestStatus.WAITING self.waiting_queue.sync_q.put(request) + def _exceeds_token_budget( + self, + current_num_batched_tokens: int, + num_tokens_this_step: int, + num_scheduled_requests: int, + ) -> bool: + """Return True when adding this request should be deferred for token budget. + + A single request is always allowed to make progress, even if it is larger + than max_num_batched_tokens. + """ + if num_scheduled_requests == 0: + return False + return ( + current_num_batched_tokens + num_tokens_this_step + > self.max_num_batched_tokens + ) + def schedule(self) -> Optional[SchedulerOutput]: """Schedule and return batch of requests to execute.""" + deferred_requests = [] scheduled_requests = [] is_prefill = False + current_num_batched_tokens = 0 + current_prefill_extra_blocks = 0 # Process Waiting queue (prefill phase) - while len(scheduled_requests) < self.max_batch_size: + while ( + len(scheduled_requests) < self.max_batch_size + and current_num_batched_tokens < self.max_num_batched_tokens + ): try: req = self.waiting_queue.sync_q.get_nowait() except queue.Empty: @@ -68,40 +103,144 @@ def schedule(self) -> Optional[SchedulerOutput]: self.complete_requests([req]) continue - if not self.can_accept_request(req): - self.waiting_queue.sync_q.put(req) - break + req_tokens = req.get_input_tokens() - # Skip requests that were already finished (e.g., timed out/canceled while waiting) - if req.is_finished(): - self.complete_requests([req]) - continue + if req.num_computed_tokens == 0: + ( + cached_block_table, + num_local_computed_tokens, + blocks_blueprint, + ) = self.cache_manager.get_computed_blocks( + req_tokens, req.get_mm_token_index_mappings() + ) + if self.connector is not None: + ext_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + req, num_local_computed_tokens + ) + ) + num_external_computed_tokens = ext_tokens + else: + load_kv_async = False + num_external_computed_tokens = 0 - req_tokens = req.get_input_tokens() - num_required_blocks = req.get_num_blocks_required(self.block_size) + num_computed_tokens = ( + num_local_computed_tokens + num_external_computed_tokens + ) + if load_kv_async: + num_computed_tokens -= 1 + num_new_tokens = req.get_prompt_length() - num_computed_tokens + + # Early token budget check: skip can_accept_request and allocate_slots + # for requests that would exceed the per-schedule token budget. + if not load_kv_async: + num_tokens_this_step = ( + req.get_prompt_length() - num_local_computed_tokens + ) + if self._exceeds_token_budget( + current_num_batched_tokens, + num_tokens_this_step, + len(scheduled_requests), + ): + if num_local_computed_tokens > 0: + self.cache_manager.free_blocks(cached_block_table) + deferred_requests.append(req) + break + + if not self.can_accept_request( + req, + num_local_computed_tokens, + current_prefill_extra_blocks, + ): + logger.warning( + "Insufficient KV cache blocks for request %s, deferring.", + req.request_id, + ) - if not self.cache_manager.can_allocate(num_required_blocks): - if not self.cache_manager.try_free_blocks(num_required_blocks): - raise RuntimeError("No available cache blocks for new request") + if num_local_computed_tokens > 0: + self.cache_manager.free_blocks(cached_block_table) + deferred_requests.append(req) + break + + req_blocks, slot_mapping = self.cache_manager.allocate_slots( + req_tokens, + num_new_tokens, + num_computed_tokens=num_computed_tokens, + cached_block_table=cached_block_table, + blocks_blueprint=blocks_blueprint, + delay_cache_blocks=load_kv_async, + ) - # Allocate blocks with automatic prefix caching support - req.block_table, req.slot_mapping, req.num_cached_tokens = ( - self.cache_manager.allocate_blocks( - req_tokens, req.block_table, req.get_mm_token_index_mappings() + if req_blocks is None: + logger.warning( + "Failed to allocate KV cache blocks for request: %s", + req.request_id, + ) + if num_local_computed_tokens > 0: + self.cache_manager.free_blocks(cached_block_table) + deferred_requests.append(req) + break + + req.block_table = req_blocks + req.slot_mapping = slot_mapping + req.num_blocks = len(req_blocks) + req.num_local_cached_tokens = num_local_computed_tokens + req.num_computed_tokens = num_computed_tokens + + if self.connector is not None: + self.connector.update_state_after_alloc( + req, + req.block_table, + num_external_computed_tokens, + self.block_size, + ) + else: + load_kv_async = False + num_tokens_this_step = ( + req.get_prompt_length() - req.num_local_cached_tokens + ) + if self._exceeds_token_budget( + current_num_batched_tokens, + num_tokens_this_step, + len(scheduled_requests), + ): + deferred_requests.append(req) + break + self.cache_manager.update_blocks_hash( + req.block_table, req.num_local_cached_tokens ) - ) - req.num_blocks = len(req.block_table) - req.status = RequestStatus.RUNNING + if load_kv_async: + req.status = RequestStatus.WAITING_FOR_REMOTE_KVS + self.remote_kv_requests[req.request_id] = req + self.pending_kv_decode_blocks += ( + req.sampling_params.max_tokens + self.block_size - 1 + ) // self.block_size + continue + + current_prefill_extra_blocks += self._get_prefill_extra_blocks(req) scheduled_requests.append(req) + num_tokens_this_step = req.get_prompt_length() - req.num_local_cached_tokens + current_num_batched_tokens += num_tokens_this_step + + req.status = RequestStatus.RUNNING + + if deferred_requests: + for req in deferred_requests: + self.waiting_queue.sync_q.put(req) + # Return prefill batch if any waiting requests were scheduled if scheduled_requests: is_prefill = True - return SchedulerOutput( + scheduler_output = SchedulerOutput( scheduled_requests=scheduled_requests, is_prefill=is_prefill, ) + if self.connector is not None: + meta = self.connector.build_connector_meta() + scheduler_output.kv_connector_metadata = meta + return scheduler_output # Process Running queue (decode phase) while len(scheduled_requests) < self.max_batch_size: @@ -121,22 +260,90 @@ def schedule(self) -> Optional[SchedulerOutput]: ) req.slot_mapping = [new_slot] req.num_blocks = len(req.block_table) - req.num_cached_tokens = req.get_total_length() - 1 + req.num_local_cached_tokens = req.get_total_length() - 1 scheduled_requests.append(req) except RuntimeError as e: raise RuntimeError("No available cache blocks for new token") from e + # Promote completed remote KV transfers (lower priority than running queue). + # Cleanup (is_finished, failed re-queue) runs unconditionally; batch append only if slots remain. + if self.connector is not None and self.remote_kv_requests: + for req_id in list(self.remote_kv_requests.keys()): + req = self.remote_kv_requests[req_id] + if req.is_finished(): + self.complete_requests([req]) + continue + if req_id in self.failed_receiving_kv_req_ids: + logger.warning( + f"Request {req_id[:8]}... failed receiving KV, re-queuing for prefill." + ) + self.update_waiting_for_remote_kv(req) + req.status = RequestStatus.WAITING + self.waiting_queue.sync_q.put(req) + elif req_id in self.finished_receiving_kv_req_ids: + if len(scheduled_requests) < self.max_batch_size: + logger.info( + f"Request {req_id[:8]}... finished receiving KV, scheduling for decode." + ) + self.update_waiting_for_remote_kv(req) + req.status = RequestStatus.RUNNING + scheduled_requests.append(req) + else: + break # Defer promotion to next schedule() if batch is full + # Return decode batch if any running requests were scheduled if scheduled_requests: is_prefill = False - return SchedulerOutput( + scheduler_output = SchedulerOutput( scheduled_requests=scheduled_requests, is_prefill=is_prefill, ) + if self.connector is not None: + meta = self.connector.build_connector_meta() + scheduler_output.kv_connector_metadata = meta + return scheduler_output + + if self.connector is not None: + scheduler_output = SchedulerOutput(scheduled_requests=[]) + meta = self.connector.build_connector_meta() + scheduler_output.kv_connector_metadata = meta + return scheduler_output + return None + def update_waiting_for_remote_kv(self, request: InferenceRequest): + self.remote_kv_requests.pop(request.request_id, None) + self.pending_kv_decode_blocks -= ( + request.sampling_params.max_tokens + self.block_size - 1 + ) // self.block_size + if request.request_id in self.failed_receiving_kv_req_ids: + if request.num_computed_tokens: + valid_block_count = request.num_computed_tokens // self.block_size + self.cache_manager.update_blocks_hash( + request.block_table[:valid_block_count], + request.num_local_cached_tokens, + ) + request.slot_mapping = self.cache_manager.update_blocks_slot( + request.block_table, + request.num_computed_tokens, + request.get_prompt_length(), + ) + request.num_local_cached_tokens = request.num_computed_tokens + else: + self.cache_manager.free_blocks(request.block_table) + request.block_table = [] + request.slot_mapping = [] + request.num_local_cached_tokens = 0 + self.failed_receiving_kv_req_ids.discard(request.request_id) + else: + self.cache_manager.update_blocks_hash( + request.block_table, request.num_local_cached_tokens + ) + request.num_local_cached_tokens = request.num_computed_tokens + self.finished_receiving_kv_req_ids.discard(request.request_id) + def complete_requests(self, requests: List[InferenceRequest]): """Handle completed requests and free their blocks.""" for req in requests: @@ -146,8 +353,26 @@ def complete_requests(self, requests: List[InferenceRequest]): RequestStatus.FAILED, RequestStatus.TIMEOUT, ]: - if req.block_table: + delay_free_blocks = False + if self.connector is not None: + delay_free_blocks, _ = self.connector.request_finished( + req, req.block_table, self.block_size + ) + + if req.request_id in self.remote_kv_requests: + self.pending_kv_decode_blocks -= ( + req.sampling_params.max_tokens + self.block_size - 1 + ) // self.block_size + self.remote_kv_requests.pop(req.request_id, None) + if req.request_id in self.finished_receiving_kv_req_ids: + self.finished_receiving_kv_req_ids.discard(req.request_id) + self.failed_receiving_kv_req_ids.discard(req.request_id) + else: + delay_free_blocks = True + if req.block_table and not delay_free_blocks: self.cache_manager.free_blocks(req.block_table) + elif req.block_table and delay_free_blocks: + self.pending_free_blocks[req.request_id] = list(req.block_table) if req.status == RequestStatus.CANCELED: logger.info( @@ -165,7 +390,12 @@ def complete_requests(self, requests: List[InferenceRequest]): # Still running, put back in running queue self.running_queue.sync_q.put(req) - def can_accept_request(self, request: InferenceRequest) -> bool: + def can_accept_request( + self, + request: InferenceRequest, + num_local_computed_tokens: int, + current_prefill_extra_blocks: int = 0, + ) -> bool: total_required_blocks = 0 # Calculate blocks needed for running requests @@ -182,20 +412,83 @@ def can_accept_request(self, request: InferenceRequest) -> bool: self.running_queue.sync_q.put(req) # Calculate blocks needed for the new request - total_length = request.get_prompt_length() + total_length = request.get_prompt_length() - num_local_computed_tokens total_length += request.sampling_params.max_tokens num_blocks_needed = (total_length + self.block_size - 1) // self.block_size total_required_blocks += num_blocks_needed + # Include decode headroom for WAITING_FOR_REMOTE_KVS requests, which + # hold prompt blocks but will also need decode blocks once promoted. + total_required_blocks += self.pending_kv_decode_blocks + + # Include decode headroom for requests accepted earlier in this batch. + total_required_blocks += current_prefill_extra_blocks + # Compare with total usable blocks in cache manager return total_required_blocks <= self.cache_manager.get_total_usable_blocks() + def _get_prefill_extra_blocks(self, request: InferenceRequest) -> int: + total_length = request.get_prompt_length() + total_length += request.sampling_params.max_tokens + total_required_blocks = (total_length + self.block_size - 1) // self.block_size + return max(total_required_blocks - len(request.block_table), 0) + + def update_from_output(self, model_output): + if self.connector is None or model_output.kv_connector_output is None: + return + + finished_recving_req_ids = ( + getattr(model_output.kv_connector_output, "finished_recving", None) or [] + ) + finished_sending_req_ids = ( + getattr(model_output.kv_connector_output, "finished_sending", None) or [] + ) + failed_recving_req_ids = ( + getattr(model_output.kv_connector_output, "failed_recving", None) or [] + ) + invalid_block_ids = ( + getattr(model_output.kv_connector_output, "invalid_block_ids", None) or [] + ) + + for req_id in finished_recving_req_ids: + if req_id in self.pending_free_blocks: + # Aborted request: transfer complete, now safe to free blocks. + self.cache_manager.free_blocks(self.pending_free_blocks.pop(req_id)) + elif req_id in self.remote_kv_requests: + # Active request: mark ready for promotion in schedule(). + self.finished_receiving_kv_req_ids.add(req_id) + # else: already processed or unknown, discard to avoid stale entries. + for req_id in finished_sending_req_ids: + self.cache_manager.free_blocks(self.pending_free_blocks.pop(req_id, [])) + for req_id in failed_recving_req_ids: + # Only track failures for active (non-aborted) requests; aborted + # requests are handled via pending_free_blocks in finished_recving. + if req_id in self.remote_kv_requests: + self.failed_receiving_kv_req_ids.add(req_id) + + if invalid_block_ids: + invalid_set = set(invalid_block_ids) + + for req in self.remote_kv_requests.values(): + start_block_idx = req.num_local_cached_tokens // self.block_size + for i, block_id in enumerate( + req.block_table[start_block_idx:], start=start_block_idx + ): + if block_id in invalid_set: + req.num_computed_tokens = i * self.block_size + break + elif self.failed_receiving_kv_req_ids: + for req_id in self.failed_receiving_kv_req_ids: + req = self.remote_kv_requests[req_id] + req.num_computed_tokens = req.num_local_cached_tokens + def get_cache_stats(self) -> dict: """Get cache statistics.""" return { "num_blocks": self.cache_manager.num_blocks, "block_size": self.cache_manager.block_size, "num_free_blocks": self.cache_manager.get_num_free_blocks(), - "num_req_blocks": len(self.cache_manager.req_block_ids), + "usable_blocks": self.cache_manager.get_total_usable_blocks(), + "num_pending_blocks": len(self.cache_manager.pending_block_ids), "num_used_blocks": len(self.cache_manager.used_block_ids), } diff --git a/python/infinilm/llm/static_scheduler.py b/python/infinilm/llm/static_scheduler.py index d2481bba..589398fe 100644 --- a/python/infinilm/llm/static_scheduler.py +++ b/python/infinilm/llm/static_scheduler.py @@ -33,6 +33,7 @@ def __init__( self.num_requests = len(scheduled_requests) self.is_prefill = is_prefill self.prefix_hit_len = prefix_hit_len + self.kv_connector_metadata = None class StaticScheduler: @@ -222,6 +223,10 @@ def update_cache(self): f"update_cache: cached_block_hashes now has {len(self.cached_block_hashes)} blocks" ) + def update_from_output(self, model_output): + """Static cache has no scheduler-side connector state to update.""" + return None + def complete_requests(self, requests: List[InferenceRequest]): """Handle completed requests.""" for req in requests: diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index 3a94b2ca..6948aa41 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -104,7 +104,7 @@ def _build_model_input_from_static_scheduler_output( Decode phase: - input_ids: Only the last generated token [1, 1] - position_ids: [current_position] (position in full sequence) - - past_kv_lengths: [num_cached_tokens] + - past_kv_lengths: [num_local_cached_tokens] - total_kv_lengths: [total_tokens] """ import infinicore @@ -124,7 +124,11 @@ def _build_model_input_from_static_scheduler_output( input_offsets = [0, len(input_tokens)] else: # Decode: send only the last generated token - last_token = req.generated_token_ids[-1] + last_token = ( + req.generated_token_ids[-1] + if req.generated_token_ids + else req.prompt_token_ids[-1] + ) current_position = req.get_total_length() - 1 input_ids = [[last_token]] position_ids = [[current_position]] @@ -199,7 +203,7 @@ def _build_model_input_from_batch_scheduler_output( current_offset = 0 for req in scheduler_output.scheduled_requests: - num_cached = req.num_cached_tokens + num_cached = req.num_local_cached_tokens if scheduler_output.is_prefill: # Prefill phase req_tokens = req.get_input_tokens() @@ -220,7 +224,11 @@ def _build_model_input_from_batch_scheduler_output( else: # Decode phase seq_len = req.get_total_length() - last_token = req.generated_token_ids[-1] + last_token = ( + req.generated_token_ids[-1] + if req.generated_token_ids + else req.prompt_token_ids[-1] + ) tokens.append(last_token) seq_lens.append(seq_len) diff --git a/python/infinilm/processors/minicpmv_processor.py b/python/infinilm/processors/minicpmv_processor.py index 58da1f12..2f09472b 100644 --- a/python/infinilm/processors/minicpmv_processor.py +++ b/python/infinilm/processors/minicpmv_processor.py @@ -131,7 +131,7 @@ def build_model_inputs( current_offset = 0 for req_id, req in enumerate(scheduler_output.scheduled_requests): - num_cached = req.num_cached_tokens + num_cached = req.num_local_cached_tokens if scheduler_output.is_prefill: # Prefill phase req_tokens = req.get_input_tokens() @@ -231,7 +231,11 @@ def append_mm_data(mm_data__: dict, key__: str, value__): else: # Decode phase seq_len = req.get_total_length() - last_token = req.generated_token_ids[-1] + last_token = ( + req.generated_token_ids[-1] + if req.generated_token_ids + else req.prompt_token_ids[-1] + ) tokens.append(last_token) seq_lens.append(seq_len) diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 3d35941c..11708a18 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -12,11 +12,13 @@ import logging import os import asyncio +from typing import Optional from infinilm.base_config import BaseConfig from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse from infinilm.llm import AsyncLLMEngine, SamplingParams, FinishReason +from infinilm.config import KVTransferConfig logger = logging.getLogger(__name__) @@ -110,6 +112,7 @@ def __init__( enable_graph: bool = False, attn_backend: str = "default", ignore_eos: bool = False, + kv_transfer_config: Optional[KVTransferConfig] = None, ): """Initialize inference server. @@ -131,6 +134,8 @@ def __init__( port: Server port number. enable_graph: Whether to enable graph compiling. attn_backend: Attention backend to use ('default', 'flash-attn'). + ignore_eos: Whether to ignore EOS tokens during generation. + kv_transfer_config: Optional configuration for the KV transfer mechanism. """ self.model_path = model_path # vLLM-like served model id: directory name of model_path @@ -152,6 +157,7 @@ def __init__( self.enable_graph = enable_graph self.attn_backend = attn_backend self.ignore_eos = ignore_eos + self.kv_transfer_config = kv_transfer_config self.engine: AsyncLLMEngine = None @@ -183,6 +189,7 @@ async def lifespan(app: FastAPI): top_k=self.top_k, enable_graph=self.enable_graph, attn_backend=self.attn_backend, + kv_transfer_config=self.kv_transfer_config, ) self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") @@ -204,7 +211,7 @@ def _register_routes(self, app: FastAPI): async def chat_completions(request: Request): try: data = await request.json() - logger.debug(f"Received request data: {data}") + # logger.debug(f"Received request data: {data}") except Exception as e: logger.error(f"Failed to parse request JSON: {e}") return JSONResponse(content={"error": "Invalid JSON"}, status_code=400) @@ -218,7 +225,6 @@ async def chat_completions(request: Request): data["messages"] = [{"role": "user", "content": data.get("prompt")}] # Normalize messages to handle multimodal content (list format) - # data["messages"] = self._normalize_messages(data.get("messages", [])) data["messages"] = data.get("messages", []) stream = data.get("stream", False) @@ -340,6 +346,7 @@ def pick(key: str, default): async def _stream_chat(self, request_id: str, data: dict, http_request: Request): """Handle streaming chat request.""" req = None + _abort_reason = FinishReason.CANCELED try: messages = data.get("messages", []) @@ -350,7 +357,6 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) sampling_params=sampling_params, request_id=request_id, request_data=data, - http_request=http_request, add_generation_prompt=bool(data.get("add_generation_prompt", True)), chat_template_kwargs=data.get("chat_template_kwargs") or {}, ) @@ -363,10 +369,8 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) # Check client disconnect if await http_request.is_disconnected(): logger.info(f"Client disconnected for request {request_id}") - req.mark_canceled() break - # If stream_request enforces timeout, we can just surface the state to the client. if token_output.finish_reason == FinishReason.TIMEOUT: logger.warning( f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s" @@ -418,15 +422,14 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) break except asyncio.CancelledError: + # Starlette cancelled us (client disconnected); stream_request will be + # aclose()'d automatically via the async-for destructor. logger.info(f"Request {request_id} was cancelled") - if req: - req.mark_canceled() raise except Exception as e: logger.error(f"Stream error for {request_id}: {e}", exc_info=True) - if req: - req.mark_failed() + _abort_reason = FinishReason.ERROR error_chunk = json.dumps( chunk_json( request_id, @@ -439,15 +442,16 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) yield f"data: {error_chunk}\n\n" finally: + # Unified abort: reason is ERROR if we got here via Exception, else CANCELED. + # req.close() is handled by stream_request.finally. if req and not req.is_finished(): - req.mark_canceled() - if req: - await req.close() - yield "data: [DONE]\n\n" + self.engine.add_aborted_req(req, _abort_reason) + yield "data: [DONE]\n\n" async def _chat(self, request_id: str, data: dict, http_request: Request): """Handle non-streaming chat request.""" req = None + _abort_reason = FinishReason.CANCELED try: messages = data.get("messages", []) @@ -458,7 +462,6 @@ async def _chat(self, request_id: str, data: dict, http_request: Request): sampling_params=sampling_params, request_id=request_id, request_data=data, - http_request=http_request, add_generation_prompt=bool(data.get("add_generation_prompt", True)), chat_template_kwargs=data.get("chat_template_kwargs") or {}, ) @@ -473,7 +476,6 @@ async def _chat(self, request_id: str, data: dict, http_request: Request): # Check client disconnect if await http_request.is_disconnected(): logger.info(f"Client disconnected for request {request_id}") - req.mark_canceled() break # Request-level timeout is handled inside stream_request. @@ -509,21 +511,18 @@ async def _chat(self, request_id: str, data: dict, http_request: Request): except asyncio.CancelledError: logger.info(f"Request {request_id} was cancelled") - if req: - req.mark_canceled() raise except Exception as e: logger.error(f"Chat error for {request_id}: {e}", exc_info=True) - if req: - req.mark_failed() + _abort_reason = FinishReason.ERROR return JSONResponse(content={"error": str(e)}, status_code=500) finally: + # Unified abort: reason is ERROR if we got here via Exception, else CANCELED. + # req.close() is handled by stream_request.finally. if req and not req.is_finished(): - req.mark_canceled() - if req: - await req.close() + self.engine.add_aborted_req(req, _abort_reason) def _convert_finish_reason(self, reason: FinishReason) -> str: """Convert FinishReason enum to string.""" @@ -551,11 +550,29 @@ def setup_logging(log_level: str = "INFO"): ) +def parse_kv_transfer_config(kv_transfer_config_str: str) -> KVTransferConfig: + """Parse JSON string into KVTransferConfig.""" + kv_dict = json.loads(kv_transfer_config_str) + if not isinstance(kv_dict, dict): + raise ValueError("--kv-transfer-config must be a JSON object") + + return KVTransferConfig( + kv_connector=kv_dict.get("kv_connector", None), + engine_id=kv_dict.get("engine_id", None), + kv_role=kv_dict.get("kv_role", None), + kv_connector_extra_config=kv_dict.get("kv_connector_extra_config", None), + ) + + def main(): cfg = BaseConfig() setup_logging(cfg.log_level) device = cfg.get_device_str(cfg.device) + kv_transfer_config = None + if cfg.kv_transfer_config: + kv_transfer_config = parse_kv_transfer_config(cfg.kv_transfer_config) + server = InferenceServer( model_path=cfg.model, device=device, @@ -575,6 +592,7 @@ def main(): enable_graph=cfg.enable_graph, attn_backend=cfg.attn, ignore_eos=cfg.ignore_eos, + kv_transfer_config=kv_transfer_config, ) server.start() diff --git a/python/infinilm/server/mooncake_proxy_server.py b/python/infinilm/server/mooncake_proxy_server.py new file mode 100644 index 00000000..639bf942 --- /dev/null +++ b/python/infinilm/server/mooncake_proxy_server.py @@ -0,0 +1,410 @@ +import argparse +import asyncio +import httpx +import ipaddress +import itertools +import json +import os +import urllib +import uuid +import uvicorn +import logging + +from contextlib import asynccontextmanager +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse + +logger = logging.getLogger(__name__) + + +def maybe_wrap_ipv6_address(hostname): + try: + ipaddress.IPv6Address(hostname) + return f"[{hostname}]" + except ValueError: + return hostname + + +def make_http_address(hostname, port): + return f"http://{hostname}:{port}" + + +async def wait_for_cluster_ready(prefill_clients, decode_clients, ready_event): + for prefill_client in prefill_clients: + while True: + try: + response = await prefill_client["client"].get("/health") + response.raise_for_status() + except Exception as e: + await asyncio.sleep(1) # Wait before retrying + logger.warning( + "waiting for handshake with prefill server: %s", + prefill_client, + ) + continue + + response = await prefill_client["client"].get( + prefill_client["bootstrap_addr"] + "/query" + ) + response.raise_for_status() + data = response.json() + break + + logger.warning("successfully handshake with prefill server!") + + for dp_rank, engine_info in data.items(): + prefill_client["dp_engine_id"][int(dp_rank)] = engine_info["engine_id"] + prefill_client["dp_size"] = len(data) + + for decode_client in decode_clients: + while True: + try: + response = await decode_client["client"].get("/health") + response.raise_for_status() + logger.warning("successfully handshake with decode server!") + break + except Exception as e: + await asyncio.sleep(1) # Wait before retrying + logger.warning( + "waiting for handshake with decode server: %s", + decode_client, + ) + + ready_event.set() # Signal that all prefiller and decoder info has been collected + + +def prefiller_cycle(prefill_clients): + while True: + for prefill_client in prefill_clients: + dp_size = prefill_client["dp_size"] + for dp_rank in range(dp_size): + yield prefill_client, dp_rank + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + app.state.ready = asyncio.Event() + + # Create prefill clients + for prefill_url, bootstrap_port in global_args.prefill: + parsed_url = urllib.parse.urlparse(prefill_url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=prefill_url, + limits=httpx.Limits( + max_connections=None, max_keepalive_connections=None + ), + ), + "url": prefill_url, + "bootstrap_addr": make_http_address(hostname, bootstrap_port or 9600), + "dp_engine_id": {}, + } + ) + + logger.info("global_args: %s", global_args) + # Create decode clients + for decode_url in global_args.decode: + parsed_url = urllib.parse.urlparse(decode_url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.decode_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=decode_url, + limits=httpx.Limits( + max_connections=None, max_keepalive_connections=None + ), + ), + } + ) + + asyncio.create_task( + wait_for_cluster_ready( + app.state.prefill_clients, app.state.decode_clients, app.state.ready + ) + ) + app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients) + app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) + + yield + + for client_info in app.state.prefill_clients: + await client_info["client"].aclose() + + for client_info in app.state.decode_clients: + await client_info["client"].aclose() + + +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Mooncake Proxy Server") + + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + + # Prefill services + parser.add_argument( + "--prefill", + nargs="+", + action="append", + dest="prefill_url_list", + metavar=("URL", "Bootstrap"), + help=( + "Prefill service URL and optional bootstrap file" + "Can be specified multiple times for multiple services " + "(e.g., --prefill http://localhost:9000/ prefill_bootstrap.json)" + ), + ) + + # Decoder services + parser.add_argument( + "--decode", + nargs=1, + action="append", + dest="decode_url_list", + metavar="URL", + help=( + "Decoder service URL. Can be specified multiple times for multiple services " + "(e.g., --decode http://localhost:9001/)" + ), + ) + + args = parser.parse_args() + args.prefill = _parse_prefill_urls(args.prefill_url_list) + args.decode = _parse_decode_urls(args.decode_url_list) + + return args + + +def _parse_prefill_urls(prefill_url_list): + if not prefill_url_list: + return [] + + prefill_urls = [] + + for url in prefill_url_list: + prefill_url = url[0] + + if len(url) > 1: + bootstrap_port_str = url[1] + if bootstrap_port_str.lower() == "none": + bootstrap_port = None + else: + try: + bootstrap_port = int(bootstrap_port_str) + except ValueError as e: + raise ValueError( + f"Invalid bootstrap port value: {bootstrap_port_str}. Must be an integer or 'none'." + ) from e + else: + bootstrap_port = None + + prefill_urls.append((prefill_url, bootstrap_port)) + + return prefill_urls + + +def _parse_decode_urls(decode_url_list): + if not decode_url_list: + return [] + + return [url[0] for url in decode_url_list] + + +def get_next_client(app: FastAPI, service_type: str): + if service_type == "prefill": + return next(app.state.prefill_iterator) + elif service_type == "decode": + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def forward_request(client_info: dict, api: str, req_data: dict, headers: dict): + """Forward a request to a backend node and yield response bytes. + Modeled after disagg_proxy_demo.forward_request, adapted for httpx. + """ + try: + async with client_info["client"].stream( + "POST", api, json=req_data, headers=headers + ) as response: + if 200 <= response.status_code < 300 or 400 <= response.status_code < 500: + async for chunk in response.aiter_bytes(): + yield chunk + else: + error_bytes = await response.aread() + try: + error_body = json.loads(error_bytes) + except Exception: + error_body = error_bytes.decode(errors="replace") + logger.error( + "Backend %s%s returned status %d: %s", + client_info.get("url", ""), + api, + response.status_code, + error_body, + ) + raise HTTPException( + status_code=response.status_code, + detail=f"Backend error {response.status_code}: {error_body}", + ) + except HTTPException: + raise + except (httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError) as e: + logger.error( + "Connection error to %s%s: %s", + client_info.get("url", ""), + api, + e, + ) + raise HTTPException( + status_code=502, detail=f"Backend connection error: {e}" + ) from e + except httpx.HTTPStatusError as e: + logger.error( + "HTTP error from %s%s: %s", + client_info.get("url", ""), + api, + e, + ) + raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e + except Exception as e: + logger.error( + "Unexpected error forwarding to %s%s: %s", + client_info.get("url", ""), + api, + e, + exc_info=True, + ) + raise HTTPException(status_code=500, detail=str(e)) from e + + +async def send_request( + client_info: dict, dp_rank: int, api: str, req_data: dict, request_id: str +): + req_data = req_data.copy() + req_data["kv_transfer_params"] = { + "do_remote_prefill": False, + "do_remote_decode": True, + "transfer_id": f"xfer-{request_id}", + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + "X-data-parallel-rank": str(dp_rank), + # Uncomment in high-concurrency or long-context scenarios (minor performance overhead): + "Connection": "close", + } + # Consume prefill response via forward_request for unified error handling + async for _ in forward_request(client_info, api, req_data, headers): + pass + + +async def stream_response( + prefill_client_info: dict, + prefill_dp_rank: int, + decode_client_info: dict, + api: str, + req_data: dict, + request_id: str, +): + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + # Uncomment in high-concurrency or long-context scenarios (minor performance overhead): + "Connection": "close", + } + + req_data["kv_transfer_params"] = { + "do_remote_prefill": True, + "do_remote_decode": False, + "remote_bootstrap_addr": prefill_client_info["bootstrap_addr"], + "remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank], + "transfer_id": f"xfer-{request_id}", + } + + # Delegate to forward_request for unified error handling + async for chunk in forward_request(decode_client_info, api, req_data, headers): + yield chunk + + +async def _handle_completions(api: str, request: Request): + if not request.app.state.ready.is_set(): + raise HTTPException(status_code=503, detail="Service Unavailable") + + try: + req_data = await request.json() + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") + + request_id = f"cmpl-{uuid.uuid4().hex}" + + prefill_client_info, prefill_dp_rank = get_next_client(request.app, "prefill") + asyncio.create_task( + send_request(prefill_client_info, prefill_dp_rank, api, req_data, request_id) + ) + + decode_client_info = get_next_client(request.app, "decode") + + async def generate_stream(): + try: + async for chunk in stream_response( + prefill_client_info, + prefill_dp_rank, + decode_client_info, + api, + req_data, + request_id=request_id, + ): + yield chunk + except HTTPException as e: + logger.error( + "Decode error for %s: HTTP %d - %s", request_id, e.status_code, e.detail + ) + yield json.dumps( + {"error": {"message": e.detail, "code": e.status_code}} + ).encode() + except Exception as e: + logger.error("Stream error for %s: %s", request_id, e, exc_info=True) + yield json.dumps({"error": {"message": str(e), "code": 500}}).encode() + + return StreamingResponse(generate_stream(), media_type="application/json") + + +@app.post("/v1/completions") +async def handle_v1_completions(request: Request): + return await _handle_completions("/v1/completions", request) + + +@app.post("/chat/completions") +async def handle_chat_completions(request: Request): + return await _handle_completions("/chat/completions", request) + + +@app.post("/v1/chat/completions") +async def handle_v1_chat_completions(request: Request): + return await _handle_completions("/v1/chat/completions", request) + + +if __name__ == "__main__": + global global_args + global_args = parse_args() + + uvicorn.run(app, host=global_args.host, port=global_args.port)