Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions KVCOMM/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ class KVCommConfig:
threshold: float = 0.3
max_anchor_num: int = 20
window_size: int = 5
top_k: int | None = None
thread_pool_workers: int = 8
worker_timeout: float = 30.0
resident_anchor_summary: str | None = None
resident_anchor_top_n: int = 0

@classmethod
def from_env(cls) -> "KVCommConfig":
Expand All @@ -28,8 +31,15 @@ def from_env(cls) -> "KVCommConfig":
threshold=float(os.environ.get("THRESHOLD", cls.threshold)),
max_anchor_num=int(os.environ.get("MAX_ANCHOR_NUM", cls.max_anchor_num)),
window_size=int(os.environ.get("WINDOW_SIZE", cls.window_size)),
top_k=(
int(os.environ.get("KVCOMM_TOP_K", os.environ.get("TOP_K")))
if os.environ.get("KVCOMM_TOP_K", os.environ.get("TOP_K")) is not None
else cls.top_k
),
thread_pool_workers=int(os.environ.get("KVCOMM_THREAD_WORKERS", cls.thread_pool_workers)),
worker_timeout=float(os.environ.get("KVCOMM_WORKER_TIMEOUT", cls.worker_timeout)),
resident_anchor_summary=os.environ.get("KVCOMM_RESIDENT_ANCHOR_SUMMARY"),
resident_anchor_top_n=int(os.environ.get("KVCOMM_RESIDENT_ANCHOR_TOP_N", cls.resident_anchor_top_n)),
).validate()

def apply_overrides(self, **overrides: Any) -> "KVCommConfig":
Expand All @@ -43,8 +53,12 @@ def apply_overrides(self, **overrides: Any) -> "KVCommConfig":

def validate(self) -> "KVCommConfig":
    """Check the range-limited settings and hand back this same config.

    The checks run in a fixed order (``top_k``, ``thread_pool_workers``,
    ``worker_timeout``, ``resident_anchor_top_n``) and only the first
    violation found is reported.

    Returns:
        This instance, so the call can be chained.

    Raises:
        ValueError: If ``top_k`` is set but not positive, if
            ``thread_pool_workers`` or ``worker_timeout`` is not positive,
            or if ``resident_anchor_top_n`` is negative.
    """
    problem = None
    # top_k is optional: a value of None skips its range check entirely.
    if self.top_k is not None and self.top_k <= 0:
        problem = "top_k must be positive when provided"
    elif self.thread_pool_workers <= 0:
        problem = "thread_pool_workers must be positive"
    elif self.worker_timeout <= 0:
        problem = "worker_timeout must be positive"
    elif self.resident_anchor_top_n < 0:
        # Zero is allowed here, unlike the strictly-positive fields above.
        problem = "resident_anchor_top_n must be non-negative"

    if problem is not None:
        raise ValueError(problem)
    return self
160 changes: 126 additions & 34 deletions KVCOMM/llm/gpt_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from KVCOMM.llm.config import KVCommConfig

from KVCOMM.llm.token_ops import *
from KVCOMM.llm.kvcomm_engine import KVCOMMEngine, _RequestState
from KVCOMM.llm.kvcomm_engine import KVCOMMEngine, _RequestState, _move_tensor_tree
from KVCOMM.utils.metrics import GenerationResult
from KVCOMM.utils.log import logger

Expand All @@ -48,6 +48,26 @@ def _escape_loguru_markup(text: Optional[str]) -> str:
return text.replace("<", "\\<")


def _hf_model_load_kwargs(model_name: str) -> Tuple[torch.dtype, Optional[str]]:
"""torch_dtype and device_map for local HF load.

Non-Llama models previously defaulted to float32 on CUDA, which roughly
doubles VRAM versus fp16/bf16 and commonly OOMs 7B-class weights on 24GB
cards during from_pretrained. device_map uses ``auto`` so visible GPUs
(e.g. via CUDA_VISIBLE_DEVICES) are used without hard-coding cuda:0.
"""
if torch.cuda.is_available():
mn = model_name.lower()
if "llama" in mn:
dtype = torch.float16
elif torch.cuda.is_bf16_supported():
dtype = torch.bfloat16
else:
dtype = torch.float16
return dtype, "auto"
return torch.float32, None


_LATENCY_IO_LOCK = threading.Lock()


Expand Down Expand Up @@ -214,6 +234,8 @@ def __init__(self, model_name: str, prefix: str = None, config: KVCommConfig | N
self.model_name = model_name

self.config = (config or KVCommConfig.from_env()).validate()
self.anchor_device = torch.device("cpu")
self.kv_storage_device = torch.device("cpu")
self._ensure_thread_pool(self.config.thread_pool_workers)
self.kv_engine = KVCOMMEngine(self)

Expand Down Expand Up @@ -550,6 +572,13 @@ def _ensure_global_input_buckets(self) -> Dict[str, Dict[str, Any]]:
store.setdefault("input_drop_num", {})
return store

def _offload_kv_payload(self, value: Any) -> Any:
    """Move a shared KV payload onto the storage device.

    The payload is passed through ``_move_tensor_tree`` targeting
    ``self.kv_storage_device`` so shared KV data is kept on CPU; callers
    materialize copies when they need them for compute.
    """
    storage_device = self.kv_storage_device
    return _move_tensor_tree(value, storage_device)

def _materialize_kv_payload(self, value: Any) -> Any:
    """Pass ``value`` through ``_move_tensor_tree`` targeting ``self.model.device``.

    Returns whatever ``_move_tensor_tree`` produces for that device.
    """
    compute_device = self.model.device
    return _move_tensor_tree(value, compute_device)

def has_prefix_initialized(self, agent_id: str) -> bool:
    """Report whether prefix KV has been initialized for ``agent_id``.

    Looks the agent up in the class-level ``LLMChat._initialization``
    mapping; an agent with no entry yields ``False``.
    """
    init_flags = LLMChat._initialization
    return init_flags.get(agent_id, False)
Expand Down Expand Up @@ -627,14 +656,6 @@ def update_condition_anchor(
condition_cache = generated.past_key_values


for key_name, value in (
("condition", condition_cache),
("condition_ids", token_ids),
("condition_drop_num", drop_num),
):
bucket = owner_memory.setdefault(key_name, {})
bucket.setdefault(message, []).append(value)

anchor_store = state.anchors.setdefault(anchor_key, {})
cond_anchor_list = list(anchor_store.values())
cond_len_bucket = state.anchor_len_dict.setdefault(anchor_key, {})
Expand All @@ -654,8 +675,24 @@ def update_condition_anchor(
anchor_kv_cache_list=cond_anchor_list,
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_activated_list,
request_uid=request_uid,
ph_id=anchor_key,
message=message,
anchor_labels=list(anchor_store.keys()),
log_events=True,
)

for key_name, value in (
("condition", condition_cache),
("condition_ids", token_ids),
("condition_drop_num", drop_num),
):
bucket = owner_memory.setdefault(key_name, {})
if key_name.endswith("_drop_num"):
bucket.setdefault(message, []).append(value)
else:
bucket.setdefault(message, []).append(self._offload_kv_payload(value))

cond_flag_bucket = state.anchor_dict.setdefault(anchor_key, {})
cond_flag_bucket[message] = prob

Expand Down Expand Up @@ -771,13 +808,6 @@ def update_input_anchor(
)
input_cache = output.past_key_values

global_buckets = self._ensure_global_input_buckets()
global_buckets["input"].setdefault(message, []).append(
input_cache.copy().slice_(start=0, end=token_ids["input_ids"].shape[-1])
)
global_buckets["input_ids"].setdefault(message, []).append(token_ids)
global_buckets["input_drop_num"].setdefault(message, []).append(drop_num)

anchor_store = state.anchors.setdefault(anchor_namespace, {})
input_anchor_list = list(anchor_store.values())
uq_len_bucket = state.anchor_len_dict.setdefault(anchor_namespace, {})
Expand All @@ -798,11 +828,25 @@ def update_input_anchor(
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_activated_list,
test_time=test_time,
request_uid=request_uid,
ph_id=anchor_namespace,
message=message,
anchor_labels=list(anchor_store.keys()),
log_events=True,
)
logger.opt(colors=True).debug(
f"<magenta>Anchor prediction for input '{safe_message}'</magenta>: {prob}"
)

global_buckets = self._ensure_global_input_buckets()
global_buckets["input"].setdefault(message, []).append(
self._offload_kv_payload(
input_cache.copy().slice_(start=0, end=token_ids["input_ids"].shape[-1])
)
)
global_buckets["input_ids"].setdefault(message, []).append(self._offload_kv_payload(token_ids))
global_buckets["input_drop_num"].setdefault(message, []).append(drop_num)

state.anchor_dict.setdefault(anchor_namespace, {})[message] = prob
global_bucket = state.global_anchor_info.setdefault(anchor_namespace, {})
if not prob:
Expand Down Expand Up @@ -836,6 +880,7 @@ async def generate_for_agent(
) -> GenerationResult:
"""Generate a response using the requested strategy with sensible fallbacks."""
latency_target = output_dir or kwargs.get("output_dir")
self.kv_engine.configure_anchor_event_logging(latency_target)
if preferred_mode == "dense_prefill":
mode = "dense_prefill"
elif self.has_active_anchor(request_uid, message):
Expand Down Expand Up @@ -918,8 +963,8 @@ async def prepare_prefix_kv_segments(self, node_id: str, prefix: str, user_promp

if type_ == "text":
seg_kv = base_kv.slice(start=s, end=e)
segment_kv_list.append(seg_kv)
token_id_list.append(token_id)
segment_kv_list.append(self._offload_kv_payload(seg_kv))
token_id_list.append(self._offload_kv_payload(token_id))
self._shared_kv_cache_memory[node_id]["prefix"] = LLMChat._shared_kv_cache_memory[node_id]["prefix"] = segment_kv_list
self._shared_kv_cache_memory[node_id]["placeholder_info"] = LLMChat._shared_kv_cache_memory[node_id]["placeholder_info"] = placeholder_info
self._shared_kv_cache_memory[node_id]["token_ids"] = LLMChat._shared_kv_cache_memory[node_id]["token_ids"] = token_id_list
Expand All @@ -931,11 +976,12 @@ def _initialize_shared_resources(self):
with LLMChat._model_lock:
if LLMChat._shared_model is None:
LLMChat._shared_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
load_dtype, device_map = _hf_model_load_kwargs(self.model_name)
LLMChat._shared_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if 'llama' in self.model_name else torch.float32,
torch_dtype=load_dtype,
low_cpu_mem_usage=True,
device_map="cuda:0",
device_map=device_map,
trust_remote_code=True
)
logger.info("Model {} loaded and shared across instances.", self.model_name)
Expand Down Expand Up @@ -1194,8 +1240,14 @@ async def agen_kvcomm(
message = messages

prefix_store = self._shared_kv_cache_memory[self.node_id]
prefix_kv_list: List[DynamicCache] = prefix_store.get("prefix", [])
prefix_token_ids: List[Dict[str, torch.Tensor]] = prefix_store.get("token_ids", [])
prefix_kv_list: List[DynamicCache] = [
self._materialize_kv_payload(cache)
for cache in prefix_store.get("prefix", [])
]
prefix_token_ids: List[Dict[str, torch.Tensor]] = [
self._materialize_kv_payload(token_ids)
for token_ids in prefix_store.get("token_ids", [])
]
placeholder_info_map = prefix_store.get("placeholder_info")
if not prefix_kv_list:
raise RuntimeError(
Expand Down Expand Up @@ -1370,12 +1422,14 @@ async def agen_kvcomm(
]
anchor_active_list: List[int] = list(anchor_info_bucket.values())

resp.setdefault(message, []).append(response_kv_cache)
resp.setdefault(message, []).append(self._offload_kv_payload(response_kv_cache))
resp_ids.setdefault(message, []).append(
{
"input_ids": response_tokens,
"attention_mask": response_mask,
}
self._offload_kv_payload(
{
"input_ids": response_tokens,
"attention_mask": response_mask,
}
)
)
resp_drop.setdefault(message, []).append(0)

Expand All @@ -1390,6 +1444,11 @@ async def agen_kvcomm(
anchor_kv_cache_list=response_anchor_list,
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_active_list,
request_uid=request_uid,
ph_id=current_key,
message=message,
anchor_labels=list(anchor_bucket.keys()),
log_events=True,
)
safe_message = _escape_loguru_markup(message)
logger.opt(colors=True).debug(
Expand Down Expand Up @@ -1502,8 +1561,14 @@ async def agen_kvcomm_time_test(
message = messages

prefix_store = self._shared_kv_cache_memory[self.node_id]
prefix_kv_list: List[DynamicCache] = prefix_store.get("prefix", [])
prefix_token_ids: List[Dict[str, torch.Tensor]] = prefix_store.get("token_ids", [])
prefix_kv_list: List[DynamicCache] = [
self._materialize_kv_payload(cache)
for cache in prefix_store.get("prefix", [])
]
prefix_token_ids: List[Dict[str, torch.Tensor]] = [
self._materialize_kv_payload(token_ids)
for token_ids in prefix_store.get("token_ids", [])
]
placeholder_info_map = prefix_store.get("placeholder_info")
if not prefix_kv_list:
raise RuntimeError(
Expand Down Expand Up @@ -1649,6 +1714,26 @@ async def agen_kvcomm_time_test(
ttft_value = kvcomm_ttft_value
else:
ttft_value = dense_prefill_ttft
if mode == "dense_prefill":
base_cache = merged_prefix_kv
real_cache = full_kv_cache.slice(start=0, end=prefix_token_length)
real_placeholder_cache, real_prefix_cache = real_cache.split_cache_by_placeholders(
placeholder_indices
)
base_placeholder_cache, base_prefix_cache = base_cache.split_cache_by_placeholders(
placeholder_indices
)
self.kv_engine.set_anchor(
request_uid,
message,
ph_id_list,
real_placeholder_cache,
real_prefix_cache,
base_placeholder_cache,
base_prefix_cache,
max_anchor_num=max_anchor_num,
window_length=window_length,
)
response_kv_cache = full_kv_cache.slice_(start=prefix_token_length)
response_kv_cache = self.kv_engine.apply_rotary_pos_emb(
response_kv_cache,
Expand Down Expand Up @@ -1676,12 +1761,14 @@ async def agen_kvcomm_time_test(
]
anchor_active_list: List[int] = list(anchor_info_bucket.values())

resp.setdefault(message, []).append(response_kv_cache)
resp.setdefault(message, []).append(self._offload_kv_payload(response_kv_cache))
resp_ids.setdefault(message, []).append(
{
"input_ids": response_tokens,
"attention_mask": response_mask,
}
self._offload_kv_payload(
{
"input_ids": response_tokens,
"attention_mask": response_mask,
}
)
)
resp_drop.setdefault(message, []).append(0)

Expand All @@ -1697,6 +1784,11 @@ async def agen_kvcomm_time_test(
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_active_list,
test_time=True,
request_uid=request_uid,
ph_id=current_key,
message=message,
anchor_labels=list(anchor_bucket.keys()),
log_events=True,
)
safe_message = _escape_loguru_markup(message)
logger.opt(colors=True).debug(
Expand Down
Loading