Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,22 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
this->compile();
}

// Collects the KV cache tensors held by every rank worker.
//
// Returns one inner vector per entry of `workers_`, in worker order; each
// inner vector is whatever that worker's `get_kv_cache()` returns.
//
// Throws std::runtime_error when the engine has no workers.
std::vector<std::vector<infinicore::Tensor>> InferEngine::get_kv_cache() {
    if (workers_.empty()) {
        throw std::runtime_error("InferEngine::get_kv_cache: no workers");
    }

    std::vector<std::vector<infinicore::Tensor>> kv_cache_list;
    kv_cache_list.reserve(workers_.size());
    for (auto &worker : workers_) {
        // The call result is already a temporary, so it binds to the rvalue
        // overload of push_back without an explicit std::move.
        kv_cache_list.push_back(worker->get_kv_cache());
    }

    // Wait on every worker before handing the collected tensors back.
    for (auto &worker : workers_) {
        worker->wait();
    }

    return kv_cache_list;
}

} // namespace infinilm::engine
2 changes: 2 additions & 0 deletions csrc/engine/infer_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class InferEngine {

void reset_cache(const cache::CacheConfig *new_config);

// Get the KV cache tensors of all workers: one inner vector per worker.
std::vector<std::vector<infinicore::Tensor>> get_kv_cache();

~InferEngine();

const distributed::DistConfig &get_dist_config() const;
Expand Down
16 changes: 16 additions & 0 deletions csrc/engine/rank_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,22 @@ void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
cv_.notify_all();
}

//------------------------------------------------------
// get kv cache
//------------------------------------------------------
// Returns a copy of this worker's `forward_context_.kv_cache_vec`.
//
// Blocks on `cv_` until either `init_done_` or `should_exit_` is set.
// Throws std::runtime_error if `should_exit_` is set once the wait returns.
// Asserts that the cache vector is non-empty.
std::vector<infinicore::Tensor> RankWorker::get_kv_cache() {
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_ || should_exit_; });

    if (should_exit_) {
        throw std::runtime_error("RankWorker stopped; cannot get_kv_cache");
    }

    ASSERT(forward_context_.kv_cache_vec.size() > 0 && "RankWorker::get_kv_cache(): kv_cache_vec is empty");

    // Returned by value while `mutex_` is still held.
    return forward_context_.kv_cache_vec;
}

//------------------------------------------------------
// close -- request shutdown and join thread
//------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions csrc/engine/rank_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class RankWorker {
// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig *new_config);

// Get the KV cache tensors held by this worker.
std::vector<infinicore::Tensor> get_kv_cache();

// Compile the model graph if enabled.
void compile();

Expand Down
5 changes: 5 additions & 0 deletions csrc/pybind11/engine/engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,17 @@ inline void bind_infer_engine(py::module &m) {
.def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)")
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
// IMPORTANT: Release the GIL before calling forward() to allow other Python threads
// to run concurrently during inference (which may block for a long time).
// Do NOT remove this — without it, the GIL is held throughout inference and will
// deadlock or stall any other Python thread (e.g., request handling, scheduling).
py::gil_scoped_release release;
return self.forward(input);
},
"Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_kv_cache", &InferEngine::get_kv_cache, "Get per-rank kv cache list")
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
auto cfg = self.get_cache_config();
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })
Expand Down
15 changes: 15 additions & 0 deletions python/infinilm/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def __init__(self):
self.port = self.args.port
self.endpoint = self.args.endpoint
self.ignore_eos = self.args.ignore_eos
# PD separation (KV transfer)
self.kv_transfer_config = self.args.kv_transfer_config

# Multimodal parameters
self.image = self.args.image
Expand Down Expand Up @@ -268,6 +270,19 @@ def _add_common_args(self):
help="image path for multimodal models",
)

# ---- PD separation arguments ----
self.parser.add_argument(
"--kv-transfer-config",
type=str,
default=None,
help=(
"JSON object for KVTransferConfig. Allowed keys only: "
"kv_connector, engine_id, kv_role, kv_connector_extra_config (omit any for defaults). "
"Example: "
'\'{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}\''
),
)

def get_device_str(self, device):
"""Convert device name to backend string (cuda/cpu/musa/mlu)"""
DEVICE_STR_MAP = {
Expand Down
4 changes: 4 additions & 0 deletions python/infinilm/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .engine_config import EngineConfig
from .kv_transfer import KVTransferConfig

# Names exported by `from infinilm.config import *`.
__all__ = ["EngineConfig", "KVTransferConfig"]
53 changes: 53 additions & 0 deletions python/infinilm/config/engine_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from dataclasses import dataclass
from typing import Optional
from infinilm.config.kv_transfer import KVTransferConfig


@dataclass
class EngineConfig:
    """Settings used to construct an LLM engine.

    Attributes:
        model_path: Directory containing the model.
        device: Device type name, e.g. 'cpu', 'cuda', 'mlu'.
        dtype: Data type name, e.g. 'float16', 'bfloat16', 'float32'.
        tensor_parallel_size: How many devices take part in tensor parallelism.
        cache_type: Either 'paged' or 'static'.
        max_batch_size: Upper bound on the inference batch size (paged cache only).
        max_tokens: Default cap on the number of generated tokens.
        num_blocks: Count of KV cache blocks (paged cache only).
        block_size: Size of one KV cache block (paged cache only).
        max_cache_len: Longest supported sequence (static cache only).
        temperature: Default sampling temperature.
        top_p: Default top-p sampling value.
        top_k: Default top-k sampling value.
        enable_graph: Turns graph compiling on or off.
        attn_backend: Attention backend name ('default', 'flash-attn').
        skip_load: When true, model weights are not loaded (for testing).
        kv_transfer_config: Optional KV transfer settings. If it names a
            connector, ``cache_type`` must be 'paged'.

    Raises:
        ValueError: If ``kv_transfer_config`` names a connector while
            ``cache_type`` is not 'paged'.
    """

    model_path: str
    device: str = "cuda"
    dtype: str = "float16"
    tensor_parallel_size: int = 1
    cache_type: str = "paged"  # "paged" or "static"
    max_batch_size: int = 16
    max_tokens: int = 4096
    num_blocks: int = 512
    block_size: int = 256
    max_cache_len: int = 4096
    temperature: float = 1.0
    top_p: float = 0.8
    top_k: int = 1
    enable_graph: bool = False
    attn_backend: str = "default"
    skip_load: bool = False
    kv_transfer_config: Optional[KVTransferConfig] = None

    def __post_init__(self) -> None:
        """Reject a KV connector combined with a non-paged cache."""
        transfer_cfg = self.kv_transfer_config
        if transfer_cfg is None:
            return
        if not transfer_cfg.kv_connector:
            # No connector selected, so there is nothing to validate.
            return
        if self.cache_type == "paged":
            return
        raise ValueError("kv_transfer_config requires cache_type='paged'")
61 changes: 61 additions & 0 deletions python/infinilm/config/kv_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright 2026 InfiniLM Contributors

import uuid
from dataclasses import dataclass, field
from typing import Optional
import os

# Roles a node may take in KV transfer.
KV_ROLE_CHOICES = frozenset({"kv_producer", "kv_consumer"})


@dataclass
class KVTransferConfig:
    """Configuration for KV cache transfer in prefill/decode (P/D) separation.

    Attributes:
        kv_connector: Name of the KV connector to use (e.g. 'MooncakeConnector').
            None disables KV transfer.
        kv_role: Role of this node: 'kv_producer' (prefill) or 'kv_consumer' (decode).
            Required whenever ``kv_connector`` is set.
        engine_id: Unique identifier for this engine instance used in KV transfers.
            When omitted it is generated as ``"<kv_role>_<uuid4>"``, or as a bare
            UUID4 string when no role is set.
        kv_connector_extra_config: Extra configuration dict passed to the connector
            backend. Only the keys 'mooncake_protocol' and 'num_workers' are
            accepted; 'mooncake_protocol' defaults to 'rdma' and must be 'tcp'
            or 'rdma'.

    Raises:
        ValueError: If ``kv_connector`` is set without ``kv_role``, if ``kv_role``
            is not one of ``KV_ROLE_CHOICES``, if ``kv_connector_extra_config``
            contains an unsupported key, or if 'mooncake_protocol' is neither
            'tcp' nor 'rdma'.

    Side effects:
        Selecting the 'tcp' protocol sets the ``MC_FORCE_TCP`` environment
        variable to ``"true"`` for the current process.
    """

    kv_connector: Optional[str] = None
    kv_role: Optional[str] = None
    engine_id: Optional[str] = None
    kv_connector_extra_config: Optional[dict] = field(default_factory=dict)

    def __post_init__(self) -> None:
        if self.kv_connector is not None and self.kv_role is None:
            raise ValueError("Please specify kv_role when kv_connector is set.")

        if self.kv_role is not None and self.kv_role not in KV_ROLE_CHOICES:
            raise ValueError(
                f"Unsupported kv_role: {self.kv_role!r}. "
                f"Supported roles are {sorted(KV_ROLE_CHOICES)}"
            )

        if self.engine_id is None:
            unique_suffix = str(uuid.uuid4())
            # Prefix with the role only when one is set; formatting a missing
            # role would otherwise produce ids starting with the literal "None_".
            if self.kv_role is not None:
                self.engine_id = f"{self.kv_role}_{unique_suffix}"
            else:
                self.engine_id = unique_suffix

        # Work on a private copy so the caller's dict is never mutated; an
        # explicit None is treated like an empty dict.
        self.kv_connector_extra_config = dict(self.kv_connector_extra_config or {})
        self.kv_connector_extra_config.setdefault("mooncake_protocol", "rdma")

        allowed_extra_config_keys = frozenset({"mooncake_protocol", "num_workers"})
        unknown_keys = set(self.kv_connector_extra_config.keys()) - allowed_extra_config_keys
        if unknown_keys:
            raise ValueError(
                f"Unsupported kv_connector_extra_config keys: {sorted(unknown_keys)}. "
                f"Supported keys are {sorted(allowed_extra_config_keys)}"
            )

        mooncake_protocol = self.kv_connector_extra_config["mooncake_protocol"]
        if mooncake_protocol not in ["tcp", "rdma"]:
            raise ValueError(f"only support tcp or rdma, but got {mooncake_protocol}")

        if mooncake_protocol == "tcp":
            # NOTE: force use tcp for Mooncake
            os.environ["MC_FORCE_TCP"] = "true"
15 changes: 15 additions & 0 deletions python/infinilm/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,18 @@ def load_state_dict(self, state_dict, strict=None):

# Forwards to the parent class's process_weights_after_loading().
def process_weights_after_loading(self):
super().process_weights_after_loading()

def get_kv_cache(self) -> list[list[infinicore.Tensor]]:
    """Return the KV cache of every rank as nested lists of ``infinicore.Tensor``.

    The per-rank lists are fetched from the parent class, the device is
    synchronized, and each entry is then wrapped in ``infinicore.Tensor``.
    The outer list has one element per rank, in the order the parent
    returned them.
    """
    raw_caches = super().get_kv_cache()
    infinicore.sync_device()

    return [
        [infinicore.Tensor(raw_tensor) for raw_tensor in rank_caches]
        for rank_caches in raw_caches
    ]
36 changes: 36 additions & 0 deletions python/infinilm/kv_connector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
KV connector package.

This module:
- Exposes core KV connector abstractions (base, role, metadata)
- Provides the KVConnectorFactory
- Registers built-in connectors (e.g. MooncakeConnector)

Note:
Importing this module will trigger connector registration.
"""

from infinilm.kv_connector.base import (
KVConnectorBase,
KVConnectorRole,
KVConnectorMetadata,
KVConnectorHandshakeMetadata,
KVConnectorWorkerMetadata,
)
from infinilm.kv_connector.factory import KVConnectorFactory

# Register the built-in Mooncake connector at import time. The connector is
# identified by its module path and class name as strings; this package does
# not import the Mooncake module itself here.
KVConnectorFactory.register_connector(
"MooncakeConnector",
"infinilm.kv_connector.mooncake.mooncake_connector",
"MooncakeConnector",
)


# Names exported by `from infinilm.kv_connector import *`.
__all__ = [
"KVConnectorBase",
"KVConnectorRole",
"KVConnectorMetadata",
"KVConnectorHandshakeMetadata",
"KVConnectorWorkerMetadata",
"KVConnectorFactory",
]
Loading