diff --git a/KVCOMM/llm/config.py b/KVCOMM/llm/config.py
index 1bf6ed8..2e4345c 100644
--- a/KVCOMM/llm/config.py
+++ b/KVCOMM/llm/config.py
@@ -18,8 +18,11 @@ class KVCommConfig:
threshold: float = 0.3
max_anchor_num: int = 20
window_size: int = 5
+ top_k: int | None = None
thread_pool_workers: int = 8
worker_timeout: float = 30.0
+ resident_anchor_summary: str | None = None
+ resident_anchor_top_n: int = 0
@classmethod
def from_env(cls) -> "KVCommConfig":
@@ -28,8 +31,15 @@ def from_env(cls) -> "KVCommConfig":
threshold=float(os.environ.get("THRESHOLD", cls.threshold)),
max_anchor_num=int(os.environ.get("MAX_ANCHOR_NUM", cls.max_anchor_num)),
window_size=int(os.environ.get("WINDOW_SIZE", cls.window_size)),
+ top_k=(
+ int(os.environ.get("KVCOMM_TOP_K", os.environ.get("TOP_K")))
+ if os.environ.get("KVCOMM_TOP_K", os.environ.get("TOP_K")) is not None
+ else cls.top_k
+ ),
thread_pool_workers=int(os.environ.get("KVCOMM_THREAD_WORKERS", cls.thread_pool_workers)),
worker_timeout=float(os.environ.get("KVCOMM_WORKER_TIMEOUT", cls.worker_timeout)),
+ resident_anchor_summary=os.environ.get("KVCOMM_RESIDENT_ANCHOR_SUMMARY"),
+ resident_anchor_top_n=int(os.environ.get("KVCOMM_RESIDENT_ANCHOR_TOP_N", cls.resident_anchor_top_n)),
).validate()
def apply_overrides(self, **overrides: Any) -> "KVCommConfig":
@@ -43,8 +53,12 @@ def apply_overrides(self, **overrides: Any) -> "KVCommConfig":
def validate(self) -> "KVCommConfig":
"""Validate value ranges and return self."""
+ if self.top_k is not None and self.top_k <= 0:
+ raise ValueError("top_k must be positive when provided")
if self.thread_pool_workers <= 0:
raise ValueError("thread_pool_workers must be positive")
if self.worker_timeout <= 0:
raise ValueError("worker_timeout must be positive")
+ if self.resident_anchor_top_n < 0:
+ raise ValueError("resident_anchor_top_n must be non-negative")
return self
diff --git a/KVCOMM/llm/gpt_chat.py b/KVCOMM/llm/gpt_chat.py
index a368379..913928a 100644
--- a/KVCOMM/llm/gpt_chat.py
+++ b/KVCOMM/llm/gpt_chat.py
@@ -34,7 +34,7 @@
from KVCOMM.llm.config import KVCommConfig
from KVCOMM.llm.token_ops import *
-from KVCOMM.llm.kvcomm_engine import KVCOMMEngine, _RequestState
+from KVCOMM.llm.kvcomm_engine import KVCOMMEngine, _RequestState, _move_tensor_tree
from KVCOMM.utils.metrics import GenerationResult
from KVCOMM.utils.log import logger
@@ -48,6 +48,26 @@ def _escape_loguru_markup(text: Optional[str]) -> str:
return text.replace("<", "\\<")
+def _hf_model_load_kwargs(model_name: str) -> Tuple[torch.dtype, Optional[str]]:
+    """Choose the ``torch_dtype`` and ``device_map`` for a local HF model load.
+
+    Without CUDA the result is ``(torch.float32, None)``. With CUDA the
+    ``device_map`` is ``"auto"`` and the dtype is half precision: float16 when
+    ``model_name`` contains "llama" (case-insensitive) or bf16 is not
+    supported, bfloat16 otherwise.
+
+    Rationale carried over from the original note: loading non-Llama models
+    in float32 on CUDA roughly doubles VRAM versus fp16/bf16 and commonly
+    OOMs 7B-class weights on 24GB cards during ``from_pretrained``, and
+    ``"auto"`` lets the visible GPUs (e.g. via CUDA_VISIBLE_DEVICES) be used
+    without hard-coding cuda:0.
+    """
+    if not torch.cuda.is_available():
+        return torch.float32, None
+    is_llama = "llama" in model_name.lower()
+    # bf16 support is only probed for non-Llama models.
+    wants_bf16 = not is_llama and torch.cuda.is_bf16_supported()
+    half_dtype = torch.bfloat16 if wants_bf16 else torch.float16
+    return half_dtype, "auto"
+
+
_LATENCY_IO_LOCK = threading.Lock()
@@ -214,6 +234,8 @@ def __init__(self, model_name: str, prefix: str = None, config: KVCommConfig | N
self.model_name = model_name
self.config = (config or KVCommConfig.from_env()).validate()
+ self.anchor_device = torch.device("cpu")
+ self.kv_storage_device = torch.device("cpu")
self._ensure_thread_pool(self.config.thread_pool_workers)
self.kv_engine = KVCOMMEngine(self)
@@ -550,6 +572,13 @@ def _ensure_global_input_buckets(self) -> Dict[str, Dict[str, Any]]:
store.setdefault("input_drop_num", {})
return store
+    def _offload_kv_payload(self, value: Any) -> Any:
+        """Return ``value`` relocated onto ``self.kv_storage_device``.
+
+        Shared KV payloads are kept on the storage device (CPU as set in
+        ``__init__``); callers materialize copies when they need to compute.
+        """
+        storage_device = self.kv_storage_device
+        return _move_tensor_tree(value, storage_device)
+
+    def _materialize_kv_payload(self, value: Any) -> Any:
+        """Return ``value`` relocated onto the model's device (``self.model.device``)."""
+        compute_device = self.model.device
+        return _move_tensor_tree(value, compute_device)
+
def has_prefix_initialized(self, agent_id: str) -> bool:
"""Check if prefix KV has been initialized for an agent."""
return LLMChat._initialization.get(agent_id, False)
@@ -627,14 +656,6 @@ def update_condition_anchor(
condition_cache = generated.past_key_values
- for key_name, value in (
- ("condition", condition_cache),
- ("condition_ids", token_ids),
- ("condition_drop_num", drop_num),
- ):
- bucket = owner_memory.setdefault(key_name, {})
- bucket.setdefault(message, []).append(value)
-
anchor_store = state.anchors.setdefault(anchor_key, {})
cond_anchor_list = list(anchor_store.values())
cond_len_bucket = state.anchor_len_dict.setdefault(anchor_key, {})
@@ -654,8 +675,24 @@ def update_condition_anchor(
anchor_kv_cache_list=cond_anchor_list,
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_activated_list,
+ request_uid=request_uid,
+ ph_id=anchor_key,
+ message=message,
+ anchor_labels=list(anchor_store.keys()),
+ log_events=True,
)
+ for key_name, value in (
+ ("condition", condition_cache),
+ ("condition_ids", token_ids),
+ ("condition_drop_num", drop_num),
+ ):
+ bucket = owner_memory.setdefault(key_name, {})
+ if key_name.endswith("_drop_num"):
+ bucket.setdefault(message, []).append(value)
+ else:
+ bucket.setdefault(message, []).append(self._offload_kv_payload(value))
+
cond_flag_bucket = state.anchor_dict.setdefault(anchor_key, {})
cond_flag_bucket[message] = prob
@@ -771,13 +808,6 @@ def update_input_anchor(
)
input_cache = output.past_key_values
- global_buckets = self._ensure_global_input_buckets()
- global_buckets["input"].setdefault(message, []).append(
- input_cache.copy().slice_(start=0, end=token_ids["input_ids"].shape[-1])
- )
- global_buckets["input_ids"].setdefault(message, []).append(token_ids)
- global_buckets["input_drop_num"].setdefault(message, []).append(drop_num)
-
anchor_store = state.anchors.setdefault(anchor_namespace, {})
input_anchor_list = list(anchor_store.values())
uq_len_bucket = state.anchor_len_dict.setdefault(anchor_namespace, {})
@@ -798,11 +828,25 @@ def update_input_anchor(
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_activated_list,
test_time=test_time,
+ request_uid=request_uid,
+ ph_id=anchor_namespace,
+ message=message,
+ anchor_labels=list(anchor_store.keys()),
+ log_events=True,
)
logger.opt(colors=True).debug(
f"Anchor prediction for input '{safe_message}': {prob}"
)
+ global_buckets = self._ensure_global_input_buckets()
+ global_buckets["input"].setdefault(message, []).append(
+ self._offload_kv_payload(
+ input_cache.copy().slice_(start=0, end=token_ids["input_ids"].shape[-1])
+ )
+ )
+ global_buckets["input_ids"].setdefault(message, []).append(self._offload_kv_payload(token_ids))
+ global_buckets["input_drop_num"].setdefault(message, []).append(drop_num)
+
state.anchor_dict.setdefault(anchor_namespace, {})[message] = prob
global_bucket = state.global_anchor_info.setdefault(anchor_namespace, {})
if not prob:
@@ -836,6 +880,7 @@ async def generate_for_agent(
) -> GenerationResult:
"""Generate a response using the requested strategy with sensible fallbacks."""
latency_target = output_dir or kwargs.get("output_dir")
+ self.kv_engine.configure_anchor_event_logging(latency_target)
if preferred_mode == "dense_prefill":
mode = "dense_prefill"
elif self.has_active_anchor(request_uid, message):
@@ -918,8 +963,8 @@ async def prepare_prefix_kv_segments(self, node_id: str, prefix: str, user_promp
if type_ == "text":
seg_kv = base_kv.slice(start=s, end=e)
- segment_kv_list.append(seg_kv)
- token_id_list.append(token_id)
+ segment_kv_list.append(self._offload_kv_payload(seg_kv))
+ token_id_list.append(self._offload_kv_payload(token_id))
self._shared_kv_cache_memory[node_id]["prefix"] = LLMChat._shared_kv_cache_memory[node_id]["prefix"] = segment_kv_list
self._shared_kv_cache_memory[node_id]["placeholder_info"] = LLMChat._shared_kv_cache_memory[node_id]["placeholder_info"] = placeholder_info
self._shared_kv_cache_memory[node_id]["token_ids"] = LLMChat._shared_kv_cache_memory[node_id]["token_ids"] = token_id_list
@@ -931,11 +976,12 @@ def _initialize_shared_resources(self):
with LLMChat._model_lock:
if LLMChat._shared_model is None:
LLMChat._shared_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+ load_dtype, device_map = _hf_model_load_kwargs(self.model_name)
LLMChat._shared_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
- torch_dtype=torch.float16 if 'llama' in self.model_name else torch.float32,
+ torch_dtype=load_dtype,
low_cpu_mem_usage=True,
- device_map="cuda:0",
+ device_map=device_map,
trust_remote_code=True
)
logger.info("Model {} loaded and shared across instances.", self.model_name)
@@ -1194,8 +1240,14 @@ async def agen_kvcomm(
message = messages
prefix_store = self._shared_kv_cache_memory[self.node_id]
- prefix_kv_list: List[DynamicCache] = prefix_store.get("prefix", [])
- prefix_token_ids: List[Dict[str, torch.Tensor]] = prefix_store.get("token_ids", [])
+ prefix_kv_list: List[DynamicCache] = [
+ self._materialize_kv_payload(cache)
+ for cache in prefix_store.get("prefix", [])
+ ]
+ prefix_token_ids: List[Dict[str, torch.Tensor]] = [
+ self._materialize_kv_payload(token_ids)
+ for token_ids in prefix_store.get("token_ids", [])
+ ]
placeholder_info_map = prefix_store.get("placeholder_info")
if not prefix_kv_list:
raise RuntimeError(
@@ -1370,12 +1422,14 @@ async def agen_kvcomm(
]
anchor_active_list: List[int] = list(anchor_info_bucket.values())
- resp.setdefault(message, []).append(response_kv_cache)
+ resp.setdefault(message, []).append(self._offload_kv_payload(response_kv_cache))
resp_ids.setdefault(message, []).append(
- {
- "input_ids": response_tokens,
- "attention_mask": response_mask,
- }
+ self._offload_kv_payload(
+ {
+ "input_ids": response_tokens,
+ "attention_mask": response_mask,
+ }
+ )
)
resp_drop.setdefault(message, []).append(0)
@@ -1390,6 +1444,11 @@ async def agen_kvcomm(
anchor_kv_cache_list=response_anchor_list,
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_active_list,
+ request_uid=request_uid,
+ ph_id=current_key,
+ message=message,
+ anchor_labels=list(anchor_bucket.keys()),
+ log_events=True,
)
safe_message = _escape_loguru_markup(message)
logger.opt(colors=True).debug(
@@ -1502,8 +1561,14 @@ async def agen_kvcomm_time_test(
message = messages
prefix_store = self._shared_kv_cache_memory[self.node_id]
- prefix_kv_list: List[DynamicCache] = prefix_store.get("prefix", [])
- prefix_token_ids: List[Dict[str, torch.Tensor]] = prefix_store.get("token_ids", [])
+ prefix_kv_list: List[DynamicCache] = [
+ self._materialize_kv_payload(cache)
+ for cache in prefix_store.get("prefix", [])
+ ]
+ prefix_token_ids: List[Dict[str, torch.Tensor]] = [
+ self._materialize_kv_payload(token_ids)
+ for token_ids in prefix_store.get("token_ids", [])
+ ]
placeholder_info_map = prefix_store.get("placeholder_info")
if not prefix_kv_list:
raise RuntimeError(
@@ -1649,6 +1714,26 @@ async def agen_kvcomm_time_test(
ttft_value = kvcomm_ttft_value
else:
ttft_value = dense_prefill_ttft
+ if mode == "dense_prefill":
+ base_cache = merged_prefix_kv
+ real_cache = full_kv_cache.slice(start=0, end=prefix_token_length)
+ real_placeholder_cache, real_prefix_cache = real_cache.split_cache_by_placeholders(
+ placeholder_indices
+ )
+ base_placeholder_cache, base_prefix_cache = base_cache.split_cache_by_placeholders(
+ placeholder_indices
+ )
+ self.kv_engine.set_anchor(
+ request_uid,
+ message,
+ ph_id_list,
+ real_placeholder_cache,
+ real_prefix_cache,
+ base_placeholder_cache,
+ base_prefix_cache,
+ max_anchor_num=max_anchor_num,
+ window_length=window_length,
+ )
response_kv_cache = full_kv_cache.slice_(start=prefix_token_length)
response_kv_cache = self.kv_engine.apply_rotary_pos_emb(
response_kv_cache,
@@ -1676,12 +1761,14 @@ async def agen_kvcomm_time_test(
]
anchor_active_list: List[int] = list(anchor_info_bucket.values())
- resp.setdefault(message, []).append(response_kv_cache)
+ resp.setdefault(message, []).append(self._offload_kv_payload(response_kv_cache))
resp_ids.setdefault(message, []).append(
- {
- "input_ids": response_tokens,
- "attention_mask": response_mask,
- }
+ self._offload_kv_payload(
+ {
+ "input_ids": response_tokens,
+ "attention_mask": response_mask,
+ }
+ )
)
resp_drop.setdefault(message, []).append(0)
@@ -1697,6 +1784,11 @@ async def agen_kvcomm_time_test(
anchor_len_list=anchor_len_list,
anchor_activated_list=anchor_active_list,
test_time=True,
+ request_uid=request_uid,
+ ph_id=current_key,
+ message=message,
+ anchor_labels=list(anchor_bucket.keys()),
+ log_events=True,
)
safe_message = _escape_loguru_markup(message)
logger.opt(colors=True).debug(
diff --git a/KVCOMM/llm/kvcomm_engine.py b/KVCOMM/llm/kvcomm_engine.py
index 4678620..92edcc3 100644
--- a/KVCOMM/llm/kvcomm_engine.py
+++ b/KVCOMM/llm/kvcomm_engine.py
@@ -9,9 +9,11 @@
from __future__ import annotations
import copy
+import csv
import threading
from collections.abc import MutableMapping
from collections.abc import Sequence
+from pathlib import Path
from time import perf_counter
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -407,6 +409,21 @@ def _to_device(cache: DynamicCache, device: Union[str, torch.device]) -> Dynamic
return cache
+def _move_tensor_tree(value: Any, device: Union[str, torch.device]) -> Any:
+    """Recursively relocate tensors and cache payloads onto ``device``.
+
+    Tensors are detached before the move. ``DynamicCache`` objects are copied
+    and the copy is moved. Dicts, lists and tuples are rebuilt with their
+    items processed recursively. Sets are deep-copied; any other value is
+    returned unchanged.
+
+    NOTE(review): relies on ``DynamicCache`` exposing ``copy()`` and
+    ``to(device)`` -- presumably provided elsewhere in this module; confirm.
+    """
+    if isinstance(value, torch.Tensor):
+        detached = value.detach()
+        return detached.to(device)
+    if isinstance(value, DynamicCache):
+        cache_copy = value.copy()
+        return cache_copy.to(device)
+    if isinstance(value, dict):
+        relocated = {}
+        for key, item in value.items():
+            relocated[key] = _move_tensor_tree(item, device)
+        return relocated
+    if isinstance(value, (list, tuple)):
+        moved_items = [_move_tensor_tree(item, device) for item in value]
+        return moved_items if isinstance(value, list) else tuple(moved_items)
+    if isinstance(value, set):
+        return copy.deepcopy(value)
+    return value
+
+
def _split_cache_by_placeholders(
cache: DynamicCache,
placeholder_dict: Dict[str, Tuple[int, int]],
@@ -535,6 +552,21 @@ def _clone_default(value: Any) -> Any:
return copy.copy(value) if hasattr(value, "__copy__") else value
+def _scoped_copy(value: Any) -> Any:
+    """Copy request state deeply, but share dicts flagged as resident anchors.
+
+    A dict whose ``"_kvcomm_resident_anchor"`` entry is exactly ``True`` is
+    returned as-is (same object). Other dicts, lists, tuples and sets are
+    rebuilt with their items copied recursively; everything else goes
+    through ``copy.deepcopy``.
+    """
+    if isinstance(value, dict):
+        resident_flag = value.get("_kvcomm_resident_anchor")
+        if resident_flag is True:
+            return value
+        duplicate = {}
+        for key, item in value.items():
+            duplicate[key] = _scoped_copy(item)
+        return duplicate
+    for container_type in (list, tuple, set):
+        if isinstance(value, container_type):
+            return container_type(_scoped_copy(item) for item in value)
+    return copy.deepcopy(value)
+
+
class _ScopedDict(MutableMapping):
"""Request-scoped view over a shared dictionary with deferred commits."""
@@ -546,7 +578,7 @@ def _ensure_local(self, key: str) -> None:
if key in self._local:
return
if key in self._base:
- self._local[key] = copy.deepcopy(self._base[key])
+ self._local[key] = _scoped_copy(self._base[key])
def __getitem__(self, key: str) -> Any:
if key in self._local:
@@ -555,7 +587,7 @@ def __getitem__(self, key: str) -> Any:
raise KeyError(key)
return value
if key in self._base:
- value = copy.deepcopy(self._base[key])
+ value = _scoped_copy(self._base[key])
self._local[key] = value
if value is _DELETED:
raise KeyError(key)
@@ -605,7 +637,7 @@ def setdefault(self, key: str, default: Any = None):
return new_value
return value
if key in self._base:
- value = copy.deepcopy(self._base[key])
+ value = _scoped_copy(self._base[key])
self._local[key] = value
return value
new_value = _clone_default(default)
@@ -686,14 +718,196 @@ class KVCOMMEngine:
_request_states: Dict[str, _RequestState] = {}
_active_requests: set[str] = set()
_staged_commits: List[_RequestState] = []
+ _anchor_event_lock = threading.Lock()
+ _anchor_event_step = 0
+ _anchor_event_csv_path: Optional[Path] = None
+ _anchor_lifecycle_step = 0
+ _anchor_lifecycle_csv_path: Optional[Path] = None
+ _resident_anchor_keys: set[Tuple[str, str]] = set()
+ _resident_anchor_source: Optional[Tuple[str, int]] = None
def __init__(self, llm: "LLMChat"):
self.llm = llm
self._warning_prefix = "[KVCOMMEngine]"
+ self.configure_resident_anchors_from_config()
def _log_warning(self, message: str) -> None:
logger.opt(colors=True).warning("{} {}", self._warning_prefix, message)
+    @classmethod
+    def configure_anchor_event_logging(cls, output_dir: Optional[Union[str, Path]]) -> None:
+        """Point the KVReuse anchor diagnostics at CSV files under ``output_dir``.
+
+        A ``None`` directory leaves the current configuration untouched.
+        Otherwise the directory is created if needed and the class-level
+        event and lifecycle CSV paths are set inside it.
+        """
+        if output_dir is not None:
+            target_dir = Path(output_dir).expanduser()
+            target_dir.mkdir(parents=True, exist_ok=True)
+            events_csv = target_dir / "kvreuse_anchor_events.csv"
+            lifecycle_csv = target_dir / "kvreuse_anchor_lifecycle.csv"
+            with cls._anchor_event_lock:
+                cls._anchor_event_csv_path = events_csv
+                cls._anchor_lifecycle_csv_path = lifecycle_csv
+
+    @classmethod
+    def _next_anchor_event_step(cls) -> int:
+        """Advance the class-wide anchor-event counter under the lock and return it."""
+        with cls._anchor_event_lock:
+            step = cls._anchor_event_step + 1
+            cls._anchor_event_step = step
+        return step
+
+    @classmethod
+    def _write_anchor_event_rows(cls, rows: List[Dict[str, Any]]) -> None:
+        """Append anchor-selection event rows to the configured events CSV.
+
+        Does nothing when no events path has been configured or ``rows`` is
+        empty. Each row is projected onto the fixed column list below: missing
+        columns are written as empty strings and unknown keys are dropped.
+        The header is written whenever the file is missing *or empty*, so a
+        pre-created zero-byte file does not end up headerless.
+
+        Args:
+            rows: Event records keyed by column name.
+        """
+        csv_path = cls._anchor_event_csv_path
+        if csv_path is None or not rows:
+            return
+        fieldnames = [
+            "step",
+            "request_uid",
+            "ph_id",
+            "message",
+            "anchor_msg",
+            "is_candidate",
+            "is_selected",
+            "sim_score",
+            "weight",
+            "skip_reason",
+            "placeholder_len",
+            "available_anchor_num",
+            "selected_anchor_num",
+        ]
+        with cls._anchor_event_lock:
+            try:
+                needs_header = csv_path.stat().st_size == 0
+            except FileNotFoundError:
+                needs_header = True
+            with csv_path.open("a", encoding="utf-8", newline="") as f:
+                writer = csv.DictWriter(
+                    f,
+                    fieldnames=fieldnames,
+                    restval="",
+                    extrasaction="ignore",
+                )
+                if needs_header:
+                    writer.writeheader()
+                writer.writerows(rows)
+
+    @classmethod
+    def _next_anchor_lifecycle_step(cls) -> int:
+        """Advance the class-wide anchor-lifecycle counter under the lock and return it."""
+        with cls._anchor_event_lock:
+            step = cls._anchor_lifecycle_step + 1
+            cls._anchor_lifecycle_step = step
+        return step
+
+    @classmethod
+    def _write_anchor_lifecycle_row(cls, row: Dict[str, Any]) -> None:
+        """Append one anchor add/remove record to the configured lifecycle CSV.
+
+        Does nothing when no lifecycle path has been configured. The row is
+        projected onto the fixed column list below: missing columns are
+        written as empty strings and unknown keys are dropped. The header is
+        written whenever the file is missing *or empty*, so a pre-created
+        zero-byte file does not end up headerless.
+
+        Args:
+            row: Lifecycle record keyed by column name.
+        """
+        csv_path = cls._anchor_lifecycle_csv_path
+        if csv_path is None:
+            return
+        fieldnames = [
+            "lifecycle_step",
+            "event",
+            "request_uid",
+            "ph_id",
+            "message",
+            "node_id",
+            "role",
+            "frequency",
+            "placeholder_len",
+            "accumulate_len",
+            "anchor_count_before",
+            "anchor_count_after",
+            "reason",
+        ]
+        with cls._anchor_event_lock:
+            try:
+                needs_header = csv_path.stat().st_size == 0
+            except FileNotFoundError:
+                needs_header = True
+            with csv_path.open("a", encoding="utf-8", newline="") as f:
+                writer = csv.DictWriter(
+                    f,
+                    fieldnames=fieldnames,
+                    restval="",
+                    extrasaction="ignore",
+                )
+                if needs_header:
+                    writer.writeheader()
+                writer.writerow(row)
+
+    def _anchor_storage_device(self) -> torch.device:
+        """Return the device anchors are stored on.
+
+        Uses ``self.llm.anchor_device`` when that attribute exists and is not
+        ``None``; falls back to CPU otherwise.
+        """
+        configured = getattr(self.llm, "anchor_device", None)
+        return torch.device("cpu" if configured is None else configured)
+
+    def _compute_device(self) -> torch.device:
+        """Return the device of ``self.llm.model`` as a ``torch.device``."""
+        model_device = self.llm.model.device
+        return torch.device(model_device)
+
+    def _materialize_anchor_entries(
+        self,
+        anchor_entries: List[Dict[str, Any]],
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Return copies of ``anchor_entries`` relocated onto a target device.
+
+        Args:
+            anchor_entries: Anchor records to relocate.
+            device: Explicit destination; when ``None`` the compute device
+                reported by ``_compute_device`` is used.
+        """
+        if device is None:
+            target_device = self._compute_device()
+        else:
+            target_device = torch.device(device)
+        materialized: List[Dict[str, Any]] = []
+        for entry in anchor_entries:
+            materialized.append(_move_tensor_tree(entry, target_device))
+        return materialized
+
+    def _offload_anchor_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
+        """Return ``entry`` relocated onto the anchor storage device."""
+        storage_device = self._anchor_storage_device()
+        return _move_tensor_tree(entry, storage_device)
+
+    @classmethod
+    def configure_resident_anchors(
+        cls,
+        summary_path: Optional[Union[str, Path]],
+        top_n: int,
+    ) -> None:
+        """Load the fixed set of hot-anchor keys that are treated as resident.
+
+        The summary is read as a CSV with ``ph_id``, ``anchor_msg`` and
+        ``selected_count`` columns. Rows are ranked by ``selected_count``
+        (unparseable counts rank as 0), the first ``top_n`` are taken, and the
+        ``(ph_id, anchor_msg)`` pairs of those rows that have both fields are
+        stored in ``cls._resident_anchor_keys``.
+
+        Loading is best-effort: an unreadable, undecodable or malformed file
+        logs a warning and leaves the resident set empty instead of raising.
+        The ``(path, top_n)`` pair is remembered either way, so the same
+        source is not re-read on the next call.
+
+        Args:
+            summary_path: CSV location; a falsy value clears the resident set.
+            top_n: Number of top rows to keep; ``<= 0`` clears the resident set.
+        """
+        if not summary_path or top_n <= 0:
+            cls._resident_anchor_keys = set()
+            cls._resident_anchor_source = None
+            return
+
+        path = Path(summary_path).expanduser()
+        source = (str(path), int(top_n))
+        if cls._resident_anchor_source == source:
+            return
+
+        try:
+            with path.open("r", encoding="utf-8", newline="") as handle:
+                rows: List[Dict[str, Any]] = list(csv.DictReader(handle))
+        except (OSError, UnicodeError, csv.Error) as exc:
+            # Bad encoding or malformed CSV is handled like a missing file:
+            # warn and continue without resident anchors.
+            logger.opt(colors=True).warning(
+                "[KVCOMMEngine] Failed to load resident anchor summary {}: {}",
+                path,
+                exc,
+            )
+            cls._resident_anchor_keys = set()
+            cls._resident_anchor_source = source
+            return
+
+        def _selected_count(row: Dict[str, Any]) -> int:
+            try:
+                return int(float(row.get("selected_count", 0) or 0))
+            except (TypeError, ValueError):
+                return 0
+
+        # Stable descending sort: ties keep their order from the file.
+        selected = sorted(rows, key=_selected_count, reverse=True)[:top_n]
+        cls._resident_anchor_keys = {
+            (row.get("ph_id", ""), row.get("anchor_msg", ""))
+            for row in selected
+            if row.get("ph_id") and row.get("anchor_msg")
+        }
+        cls._resident_anchor_source = source
+        logger.opt(colors=True).info(
+            "[KVCOMMEngine] Loaded {} resident hot anchors from {} (top_n={})",
+            len(cls._resident_anchor_keys),
+            path,
+            top_n,
+        )
+
+    def configure_resident_anchors_from_config(self) -> None:
+        """Apply the resident-anchor settings found on ``self.llm.config``.
+
+        Does nothing when the LLM has no ``config``. Missing settings fall
+        back to ``None`` for the summary path and ``0`` for the top-n count.
+        """
+        config = getattr(self.llm, "config", None)
+        if config is None:
+            return
+        summary_path = getattr(config, "resident_anchor_summary", None)
+        raw_top_n = getattr(config, "resident_anchor_top_n", 0)
+        self.configure_resident_anchors(summary_path, int(raw_top_n or 0))
+
+    def _is_resident_anchor(self, ph_id: str, message: str) -> bool:
+        """Return whether ``(ph_id, message)`` is in the resident anchor key set."""
+        anchor_key = (ph_id, message)
+        return anchor_key in self._resident_anchor_keys
+
+    def _store_anchor_entry(self, ph_id: str, message: str, entry: Dict[str, Any]) -> Dict[str, Any]:
+        """Return ``entry`` placed on the device where it should be stored.
+
+        Non-resident anchors are offloaded to the anchor storage device.
+        Resident anchors are moved to the compute device and tagged with
+        ``"_kvcomm_resident_anchor": True``.
+        """
+        if not self._is_resident_anchor(ph_id, message):
+            return self._offload_anchor_entry(entry)
+        pinned_entry = _move_tensor_tree(entry, self._compute_device())
+        pinned_entry["_kvcomm_resident_anchor"] = True
+        return pinned_entry
+
@staticmethod
def _stack_cache_tensors(cache: DynamicCache) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.stack(cache.key_cache), torch.stack(cache.value_cache)
@@ -788,7 +1002,7 @@ def _get_cached_anchor_weights(
return None
if entry.get("anchor_signature") != signature:
return None
- return entry
+ return _move_tensor_tree(entry, self._compute_device())
def _set_cached_anchor_weights(
self,
@@ -800,7 +1014,7 @@ def _set_cached_anchor_weights(
"""Store computed anchor weights for reuse within the same request."""
state = self.resolve_request_state(request_uid)
bucket = state.weight_dict.setdefault(ph_id, {})
- bucket[message] = entry
+ bucket[message] = _move_tensor_tree(entry, self._anchor_storage_device())
@staticmethod
def _select_anchor_indices(anchor_list: List[Dict[str, Any]], placeholder_len: int) -> List[int]:
@@ -884,6 +1098,7 @@ def offset_kv_cache_pair(
real_key_embedding, real_value_embedding = self._stack_cache_tensors(base_placeholder_cache)
anchor_signature = self.anchor_signature(anchor_list)
+ anchor_list_on_device = self._materialize_anchor_entries(anchor_list)
cache_entry = self._get_cached_anchor_weights(
request_uid,
@@ -893,7 +1108,7 @@ def offset_kv_cache_pair(
)
if cache_entry is None:
- anchor_index = self._select_anchor_indices(anchor_list, placeholder_len)
+ anchor_index = self._select_anchor_indices(anchor_list_on_device, placeholder_len)
if not anchor_index:
if anchor_list:
self._log_warning(
@@ -902,7 +1117,7 @@ def offset_kv_cache_pair(
return base_placeholder_cache.copy(), base_prefix_cache.copy()
cache_entry = self._compute_anchor_weight_entry(
- anchor_list,
+ anchor_list_on_device,
anchor_index,
real_key_embedding,
real_value_embedding,
@@ -925,14 +1140,14 @@ def offset_kv_cache_pair(
weights_value_for_placeholder = cache_entry["weights_value_for_placeholder"]
prefix_key_delta_stack = torch.stack(
- [anchor_list[i][f"{self.llm.node_id}_pf_key_delta"] for i in anchor_index]
+ [anchor_list_on_device[i][f"{self.llm.node_id}_pf_key_delta"] for i in anchor_index]
)
layer_total_delta_key_for_prefix = (
weights_key_for_prefix * prefix_key_delta_stack
).sum(0)
prefix_value_delta_stack = torch.stack(
- [anchor_list[i][f"{self.llm.node_id}_pf_value_delta"] for i in anchor_index]
+ [anchor_list_on_device[i][f"{self.llm.node_id}_pf_value_delta"] for i in anchor_index]
)
layer_total_value_delta_for_prefix = (
weights_value_for_prefix * prefix_value_delta_stack
@@ -940,7 +1155,7 @@ def offset_kv_cache_pair(
placeholder_key_delta_stack = torch.stack(
[
- anchor_list[i][f"{self.llm.node_id}_ph_key_delta"][..., :placeholder_len, :]
+ anchor_list_on_device[i][f"{self.llm.node_id}_ph_key_delta"][..., :placeholder_len, :]
for i in anchor_index
]
)
@@ -950,7 +1165,7 @@ def offset_kv_cache_pair(
placeholder_value_delta_stack = torch.stack(
[
- anchor_list[i][f"{self.llm.node_id}_ph_value_delta"][..., :placeholder_len, :]
+ anchor_list_on_device[i][f"{self.llm.node_id}_ph_value_delta"][..., :placeholder_len, :]
for i in anchor_index
]
)
@@ -995,11 +1210,86 @@ def predict_as_anchor(
anchor_len_list: List[Tuple[int, int]],
anchor_activated_list: List[int],
top_p: float = 0.9,
+ top_k: Optional[int] = None,
entropy_eps: float = 1e-40,
test_time: bool = False,
+ request_uid: str = "",
+ ph_id: str = "",
+ message: str = "",
+ anchor_labels: Optional[List[str]] = None,
+ log_events: bool = False,
) -> Tuple[bool, List[int]]:
+ step = self._next_anchor_event_step() if log_events else -1
+ label_list = anchor_labels if anchor_labels is not None else [""] * len(anchor_kv_cache_list)
+
+        def _emit_events(
+            *,
+            available: List[int],
+            selected: List[int],
+            sim_values: Optional[torch.Tensor],
+            skip_reason: str = "none",
+            placeholder_len: int = 0,
+        ) -> None:
+            """Write the diagnostics rows for this prediction call.
+
+            No-op unless ``log_events`` is set. With no available anchors a
+            single ``"__none__"`` row is written; otherwise one row is written
+            per available anchor, marking which ones were selected.
+
+            Args:
+                available: Anchor indices that were candidates.
+                selected: Anchor indices that were chosen.
+                sim_values: Scores aligned with ``available`` by position, or
+                    ``None`` when no scores exist.
+                skip_reason: Why selection was skipped (``"none"`` if it was not).
+                placeholder_len: Placeholder length recorded on every row.
+            """
+            if not log_events:
+                return
+            # Fields that are identical on every row of this call.
+            shared_fields: Dict[str, Any] = {
+                "step": step,
+                "request_uid": request_uid,
+                "ph_id": ph_id,
+                "message": message,
+                "skip_reason": skip_reason,
+                "placeholder_len": placeholder_len,
+            }
+            if not available:
+                self._write_anchor_event_rows(
+                    [
+                        {
+                            **shared_fields,
+                            "anchor_msg": "__none__",
+                            "is_candidate": 0,
+                            "is_selected": 0,
+                            "sim_score": "",
+                            "weight": "",
+                            "available_anchor_num": 0,
+                            "selected_anchor_num": 0,
+                        }
+                    ]
+                )
+                return
+
+            selected_set = set(selected)
+            # Copy all scores to the host once instead of one sync per row.
+            sim_list: List[float] = (
+                [float(score) for score in sim_values.detach().cpu().tolist()]
+                if sim_values is not None
+                else []
+            )
+            rows: List[Dict[str, Any]] = []
+            for idx, anchor_idx in enumerate(available):
+                sim_val = sim_list[idx] if idx < len(sim_list) else ""
+                is_selected = anchor_idx in selected_set
+                rows.append(
+                    {
+                        **shared_fields,
+                        "anchor_msg": label_list[anchor_idx] if anchor_idx < len(label_list) else "",
+                        "is_candidate": 1,
+                        "is_selected": 1 if is_selected else 0,
+                        "sim_score": sim_val,
+                        "weight": sim_val if is_selected else "",
+                        "available_anchor_num": len(available),
+                        "selected_anchor_num": len(selected),
+                    }
+                )
+            self._write_anchor_event_rows(rows)
+
if len(anchor_kv_cache_list) in [0, 1]:
+ _emit_events(
+ available=[],
+ selected=[],
+ sim_values=None,
+ skip_reason="insufficient_anchors",
+ )
return True, anchor_activated_list
+ anchor_kv_cache_list = self._materialize_anchor_entries(anchor_kv_cache_list)
if test_time:
torch.cuda.synchronize()
@@ -1012,6 +1302,13 @@ def predict_as_anchor(
"The length of anchor_len_list is not equal to the length of anchor_available, "
f"with {len(anchor_len_list)} and {len(anchor_available)}."
)
+ _emit_events(
+ available=[],
+ selected=[],
+ sim_values=None,
+ skip_reason="len_mismatch",
+ placeholder_len=k,
+ )
return True, anchor_activated_list
if len(anchor_available) > 1:
@@ -1028,6 +1325,13 @@ def predict_as_anchor(
f"Entropy {entropy:.4f} exceeds threshold {threshold * torch.log2(torch.tensor(sim.shape[0])):.4f}, "
"skip activating anchors."
)
+ _emit_events(
+ available=anchor_available,
+ selected=[],
+ sim_values=sim,
+ skip_reason="entropy_skip",
+ placeholder_len=k,
+ )
if test_time:
torch.cuda.synchronize()
end_time = perf_counter()
@@ -1036,10 +1340,19 @@ def predict_as_anchor(
)
return True, anchor_activated_list
sorted_sim, sorted_indices = torch.sort(sim, descending=True)
- cumulative_sum = torch.cumsum(sorted_sim, dim=0)
- cutoff_index_candidates = (cumulative_sum < top_p).nonzero(as_tuple=True)[0]
- cutoff_index = cutoff_index_candidates[-1] if len(cutoff_index_candidates) > 0 else len(sorted_sim) - 1
- selected_indices = sorted_indices[:cutoff_index + 1]
+ effective_top_k = self.llm.config.top_k if top_k is None else top_k
+ if effective_top_k is not None:
+ selected_indices = sorted_indices[: min(effective_top_k, len(sorted_indices))].tolist()
+ else:
+ cumulative_sum = torch.cumsum(sorted_sim, dim=0)
+ cutoff_index_candidates = (cumulative_sum < top_p).nonzero(as_tuple=True)[0]
+ cutoff_index = cutoff_index_candidates[-1] if len(cutoff_index_candidates) > 0 else len(sorted_sim) - 1
+ selected_indices = sorted_indices[:cutoff_index + 1].tolist()
+ selected_anchor_indices = [
+ anchor_available[i]
+ for i in selected_indices
+ if i < len(anchor_available)
+ ]
for i in selected_indices:
if anchor_available[i] >= len(anchor_activated_list):
self._log_warning(
@@ -1049,6 +1362,13 @@ def predict_as_anchor(
)
continue
anchor_activated_list[anchor_available[i]] += 1
+ _emit_events(
+ available=anchor_available,
+ selected=selected_anchor_indices,
+ sim_values=sim,
+ skip_reason="none",
+ placeholder_len=k,
+ )
if test_time:
torch.cuda.synchronize()
end_time = perf_counter()
@@ -1057,6 +1377,13 @@ def predict_as_anchor(
)
return False, anchor_activated_list
logger.opt(colors=True).debug("No available anchors to activate.")
+ _emit_events(
+ available=anchor_available,
+ selected=[],
+ sim_values=None,
+ skip_reason="no_available",
+ placeholder_len=k,
+ )
return True, anchor_activated_list
def update_anchor(self, request_uid: str, ph_id: str, window_length: int = 5) -> None:
@@ -1066,15 +1393,35 @@ def update_anchor(self, request_uid: str, ph_id: str, window_length: int = 5) ->
state = self.resolve_request_state(request_uid)
anchor_store = state.anchors.setdefault(ph_id, {})
anchor_info_dict = state.anchor_info_dict.setdefault(ph_id, {})
- info_list = list(anchor_info_dict.values())[:window_length]
- if not info_list:
+ info_items = list(anchor_info_dict.items())[:window_length]
+ removable_items = [
+ item
+ for item in info_items
+ if not self._is_resident_anchor(ph_id, item[0])
+ ]
+ if not removable_items:
return
- min_idx = info_list.index(min(info_list))
- message = list(anchor_info_dict.keys())[min_idx]
+ message, _ = min(removable_items, key=lambda item: item[1])
+ anchor_count_before = len(anchor_store)
anchor_store.pop(message, None)
state.anchor_len_dict.setdefault(ph_id, {}).pop(message, None)
freq = anchor_info_dict.pop(message, None)
state.global_anchor_info.setdefault(ph_id, {}).pop(message, None)
+ self._write_anchor_lifecycle_row(
+ {
+ "lifecycle_step": self._next_anchor_lifecycle_step(),
+ "event": "remove",
+ "request_uid": request_uid,
+ "ph_id": ph_id,
+ "message": message,
+ "node_id": self.llm.node_id,
+ "role": self.llm.role,
+ "frequency": freq,
+ "anchor_count_before": anchor_count_before,
+ "anchor_count_after": len(anchor_store),
+ "reason": f"low_frequency_in_oldest_window_{window_length}",
+ }
+ )
self._log_warning(
f"Removed anchor for message '{message}' in {self.llm.node_id} ({self.llm.role}) due to low frequency: {freq}"
)
@@ -1143,12 +1490,13 @@ def _make_anchor(i, ph_id, real_ph, base_ph, real_pf, base_pf):
if i not in anchor_dict:
accumulate_len += placeholder_len
continue
- entry = anchor_dict[i]
+ entry = self._store_anchor_entry(ph_id_list[i], message, anchor_dict[i])
over_store = len(anchor_store.setdefault(ph_id_list[i], {})) > max_anchor_num
if over_store:
self.update_anchor(request_uid, ph_id_list[i], window_length)
if message not in anchor_store[ph_id_list[i]]:
+ anchor_count_before = len(anchor_store[ph_id_list[i]])
anchor_store[ph_id_list[i]][message] = entry
info_bucket = state.anchor_info_dict.setdefault(ph_id_list[i], {})
info_bucket[message] = 0
@@ -1163,6 +1511,23 @@ def _make_anchor(i, ph_id, real_ph, base_ph, real_pf, base_pf):
message,
[0, placeholder_len],
)
+ self._write_anchor_lifecycle_row(
+ {
+ "lifecycle_step": self._next_anchor_lifecycle_step(),
+ "event": "add",
+ "request_uid": request_uid,
+ "ph_id": ph_id_list[i],
+ "message": message,
+ "node_id": self.llm.node_id,
+ "role": self.llm.role,
+ "frequency": 0,
+ "placeholder_len": placeholder_len,
+ "accumulate_len": accumulate_len,
+ "anchor_count_before": anchor_count_before,
+ "anchor_count_after": len(anchor_store[ph_id_list[i]]),
+ "reason": "new_anchor",
+ }
+ )
else:
anchor_store[ph_id_list[i]][message].update(entry)
accumulate_len += placeholder_len
@@ -1222,8 +1587,8 @@ def fetch_shared_cache(
if "user_question" in ph_id:
return (
- shared_memory["input"][message][-1],
- shared_memory["input_ids"][message][-1],
+ _move_tensor_tree(shared_memory["input"][message][-1], self._compute_device()),
+ _move_tensor_tree(shared_memory["input_ids"][message][-1], self._compute_device()),
shared_memory["input_drop_num"][message][-1],
)
@@ -1254,7 +1619,11 @@ def _get_slot(bucket_key: str):
f"fetch_shared_cache: placeholder {ph_id} for message='{message}' not found."
)
- return ph_cache, ph_cache_ids, drop_num
+ return (
+ _move_tensor_tree(ph_cache, self._compute_device()),
+ _move_tensor_tree(ph_cache_ids, self._compute_device()),
+ drop_num,
+ )
@staticmethod
def trim_token_ids(ids_dict: Dict[str, torch.Tensor], drop_num: int) -> Dict[str, torch.Tensor]:
diff --git a/KVCOMM/utils/gpu_debug.py b/KVCOMM/utils/gpu_debug.py
new file mode 100644
index 0000000..4b929c3
--- /dev/null
+++ b/KVCOMM/utils/gpu_debug.py
@@ -0,0 +1,371 @@
+from __future__ import annotations
+
+import csv
+import gc
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set
+
+import torch
+from transformers.cache_utils import DynamicCache
+
+DEFAULT_TARGET_DEVICE = 0
+
+
+def _tensor_nbytes(tensor: torch.Tensor) -> int:
+ return tensor.numel() * tensor.element_size()
+
+
+def _is_scoped_dict_like(obj: Any) -> bool:
+ return hasattr(obj, "_base") and hasattr(obj, "_local")
+
+
def _walk_cuda_tensors(
    obj: Any,
    *,
    path: str,
    target_device: Optional[int],
    seen_objects: Set[int],
    seen_tensors: Set[int],
    out: List[Dict[str, Any]],
) -> None:
    """Walk *obj* depth-first and append one record per CUDA tensor to *out*.

    Args:
        obj: Root of the object graph to inspect.
        path: Human-readable location of *obj*; child paths are derived from it.
        target_device: CUDA device index to report, or ``None`` to report
            tensors on any CUDA device.
        seen_objects: ``id()`` values already visited; an object seen before is
            not walked again. Mutated in place.
        seen_tensors: ``data_ptr()`` values already reported, so tensors that
            start at the same address are listed once. Mutated in place.
        out: Receives dicts with ``path``, ``shape``, ``dtype``, ``device`` and
            ``mb`` keys. Mutated in place.
    """
    identity = id(obj)
    if identity in seen_objects:
        return
    seen_objects.add(identity)

    def visit(child: Any, child_path: str) -> None:
        # Every recursive step shares the same device filter, dedup sets and sink.
        _walk_cuda_tensors(
            child,
            path=child_path,
            target_device=target_device,
            seen_objects=seen_objects,
            seen_tensors=seen_tensors,
            out=out,
        )

    if isinstance(obj, torch.Tensor):
        wanted = obj.is_cuda and (
            target_device is None or obj.device.index == target_device
        )
        if wanted:
            address = obj.data_ptr()
            if address not in seen_tensors:
                seen_tensors.add(address)
                record = {
                    "path": path,
                    "shape": tuple(obj.shape),
                    "dtype": str(obj.dtype),
                    "device": str(obj.device),
                    "mb": _tensor_nbytes(obj) / 1024**2,
                }
                out.append(record)
        return

    if isinstance(obj, DynamicCache):
        if hasattr(obj, "layers"):
            # Layer-based cache layout: each layer carries ``keys`` and ``values``.
            for idx, layer in enumerate(getattr(obj, "layers", [])):
                visit(getattr(layer, "keys", None), f"{path}.layers[{idx}].keys")
                visit(getattr(layer, "values", None), f"{path}.layers[{idx}].values")
        else:
            # List-based cache layout: parallel ``key_cache`` / ``value_cache``.
            for idx, key in enumerate(getattr(obj, "key_cache", [])):
                visit(key, f"{path}.key_cache[{idx}]")
            for idx, value in enumerate(getattr(obj, "value_cache", [])):
                visit(value, f"{path}.value_cache[{idx}]")
        return

    if isinstance(obj, dict):
        for key, value in obj.items():
            visit(value, f"{path}[{key!r}]")
        return

    if isinstance(obj, (list, tuple, set)):
        for idx, value in enumerate(obj):
            visit(value, f"{path}[{idx}]")
        return

    if _is_scoped_dict_like(obj):
        visit(getattr(obj, "_base", None), f"{path}._base")
        visit(getattr(obj, "_local", None), f"{path}._local")
        return

    if hasattr(obj, "__dict__"):
        # Fallback: descend into the instance attributes of arbitrary objects.
        for key, value in vars(obj).items():
            visit(value, f"{path}.{key}")
+
+
def collect_cuda_tensor_info(
    name: str,
    obj: Any,
    *,
    target_device: Optional[int] = DEFAULT_TARGET_DEVICE,
) -> Dict[str, Any]:
    """Summarize the CUDA tensors reachable from *obj*.

    Args:
        name: Label for the section; also used as the root of every tensor path.
        obj: Object graph to scan.
        target_device: CUDA device index to report, or ``None`` for any device.

    Returns:
        A dict with ``name``, ``count`` (number of tensors found), ``total_mb``
        (their summed size) and ``entries`` (per-tensor records, largest first).
    """
    found: List[Dict[str, Any]] = []
    _walk_cuda_tensors(
        obj,
        path=name,
        target_device=target_device,
        seen_objects=set(),
        seen_tensors=set(),
        out=found,
    )
    largest_first = sorted(found, key=lambda record: record["mb"], reverse=True)
    summary: Dict[str, Any] = {
        "name": name,
        "count": len(largest_first),
        "total_mb": sum(record["mb"] for record in largest_first),
        "entries": largest_first,
    }
    return summary
+
+
def current_cuda_memory(device: Optional[int] = DEFAULT_TARGET_DEVICE) -> Dict[str, float]:
    """Return CUDA allocator statistics in MiB.

    Args:
        device: CUDA device index to query. ``None`` means every visible CUDA
            device; the per-device numbers are then summed. Passing ``None``
            straight to the torch allocator calls would report only the
            current device.

    Returns:
        A dict with ``allocated_mb``, ``reserved_mb`` and ``max_allocated_mb``.
        All three are ``0.0`` when CUDA is unavailable. With ``device=None``,
        ``max_allocated_mb`` is the sum of per-device peaks, which is an upper
        bound because the peaks need not have occurred at the same moment.
    """
    if not torch.cuda.is_available():
        return {"allocated_mb": 0.0, "reserved_mb": 0.0, "max_allocated_mb": 0.0}

    if device is None:
        indices = range(torch.cuda.device_count())
    else:
        indices = [device]
    devices = [torch.device(f"cuda:{index}") for index in indices]

    mib = 1024**2
    return {
        "allocated_mb": sum(torch.cuda.memory_allocated(d) for d in devices) / mib,
        "reserved_mb": sum(torch.cuda.memory_reserved(d) for d in devices) / mib,
        "max_allocated_mb": sum(torch.cuda.max_memory_allocated(d) for d in devices) / mib,
    }
+
+
def reset_cuda_peak_memory(device: Optional[int] = DEFAULT_TARGET_DEVICE) -> None:
    """Reset CUDA peak stats so later max_allocated reflects a local interval.

    Args:
        device: CUDA device index to reset. ``None`` resets every visible CUDA
            device; passing ``None`` straight to torch would reset only the
            current device. Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    if device is None:
        indices = range(torch.cuda.device_count())
    else:
        indices = [device]
    for index in indices:
        torch.cuda.reset_peak_memory_stats(torch.device(f"cuda:{index}"))
+
+
def summarize_kvcomm_cuda_state(
    *,
    topk: int = 20,
    include_gc_tensors: bool = False,
    target_device: Optional[int] = DEFAULT_TARGET_DEVICE,
) -> Dict[str, Any]:
    """Build a report of CUDA tensors held by the KVCOMM class-level stores.

    Args:
        topk: Maximum number of per-tensor entries kept in each section (and in
            the optional GC listing). Section counts and totals are computed
            before truncation.
        include_gc_tensors: When true, also scan ``gc.get_objects()`` for CUDA
            tensors and attach the largest ones under ``gc_cuda_tensors``.
        target_device: CUDA device index to report, or ``None`` for any device.

    Returns:
        A dict with ``memory``, ``target_device`` and ``sections`` keys, plus
        ``gc_cuda_tensors`` when requested.
    """
    # Imported here rather than at module level (presumably to avoid an import
    # cycle with the engine modules; confirm before hoisting).
    from KVCOMM.llm.gpt_chat import LLMChat
    from KVCOMM.llm.kvcomm_engine import KVCOMMEngine

    tracked_stores = (
        ("anchors", KVCOMMEngine.anchors),
        ("weight_dict", KVCOMMEngine.weight_dict),
        ("request_states", KVCOMMEngine._request_states),
        ("staged_commits", KVCOMMEngine._staged_commits),
        ("shared_kv_cache_memory", LLMChat._shared_kv_cache_memory),
    )

    report: Dict[str, Any] = {
        "memory": current_cuda_memory(target_device),
        "target_device": target_device,
        "sections": [],
    }
    for label, store in tracked_stores:
        info = collect_cuda_tensor_info(label, store, target_device=target_device)
        info["entries"] = info["entries"][:topk]
        report["sections"].append(info)

    if not include_gc_tensors:
        return report

    gc_records: List[Dict[str, Any]] = []
    visited_ptrs: Set[int] = set()
    for candidate in gc.get_objects():
        # Best effort: any object that raises while being inspected is skipped.
        try:
            if not isinstance(candidate, torch.Tensor) or not candidate.is_cuda:
                continue
            if target_device is not None and candidate.device.index != target_device:
                continue
            address = candidate.data_ptr()
            if address in visited_ptrs:
                continue
            visited_ptrs.add(address)
            gc_records.append(
                {
                    "shape": tuple(candidate.shape),
                    "dtype": str(candidate.dtype),
                    "device": str(candidate.device),
                    "mb": _tensor_nbytes(candidate) / 1024**2,
                }
            )
        except Exception:
            continue
    gc_records.sort(key=lambda record: record["mb"], reverse=True)
    report["gc_cuda_tensors"] = gc_records[:topk]
    return report
+
+
def append_kvcomm_cuda_state_csv(
    output_dir: Optional[str | Path],
    *,
    tag: str,
    batch_index: Optional[int] = None,
    phase: Optional[str] = None,
    topk: int = 0,
    include_gc_tensors: bool = False,
    target_device: Optional[int] = DEFAULT_TARGET_DEVICE,
    filename: str = "CudaMemory.csv",
) -> None:
    """Append a structured CUDA memory snapshot for experiment comparisons.

    Args:
        output_dir: Directory that holds the CSV; ``None`` disables the call.
        tag: Free-form label stored with the row.
        batch_index: Batch number, written as an empty cell when ``None``.
        phase: Phase label, written as an empty cell when falsy.
        topk: Forwarded to ``summarize_kvcomm_cuda_state``; only section counts
            and totals are written, so the default keeps no per-tensor entries.
        include_gc_tensors: Forwarded to ``summarize_kvcomm_cuda_state``.
        target_device: CUDA device index to report, or ``None`` for any device.
        filename: CSV file name inside *output_dir*.

    The header is written when the file is missing or empty. If an existing
    file has a different header, the file is rewritten under the current column
    set (values of columns that no longer exist are dropped) before the new row
    is appended.
    """
    if output_dir is None:
        return
    if torch.cuda.is_available():
        # Let queued GPU work finish so the allocator numbers are settled.
        torch.cuda.synchronize()

    report = summarize_kvcomm_cuda_state(
        topk=topk,
        include_gc_tensors=include_gc_tensors,
        target_device=target_device,
    )
    memory = report["memory"]
    row: Dict[str, Any] = {
        "timestamp": time.time(),
        "tag": tag,
        "batch_index": batch_index if batch_index is not None else "",
        "phase": phase or "",
        "target_device": "" if target_device is None else target_device,
        "allocated_mb": memory["allocated_mb"],
        "reserved_mb": memory["reserved_mb"],
        "max_allocated_mb": memory["max_allocated_mb"],
    }
    for section in report["sections"]:
        name = section["name"]
        row[f"{name}_cuda_tensor_count"] = section["count"]
        row[f"{name}_cuda_tensor_mb"] = section["total_mb"]

    csv_path = Path(output_dir).expanduser() / filename
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = [
        "timestamp",
        "tag",
        "batch_index",
        "phase",
        "target_device",
        "allocated_mb",
        "reserved_mb",
        "max_allocated_mb",
        "anchors_cuda_tensor_count",
        "anchors_cuda_tensor_mb",
        "weight_dict_cuda_tensor_count",
        "weight_dict_cuda_tensor_mb",
        "request_states_cuda_tensor_count",
        "request_states_cuda_tensor_mb",
        "staged_commits_cuda_tensor_count",
        "staged_commits_cuda_tensor_mb",
        "shared_kv_cache_memory_cuda_tensor_count",
        "shared_kv_cache_memory_cuda_tensor_mb",
    ]

    # A zero-byte file (e.g. created by ``touch``) has no header yet, so it is
    # treated like a missing file instead of receiving header-less rows.
    needs_header = not csv_path.exists() or csv_path.stat().st_size == 0
    if not needs_header:
        with csv_path.open("r", encoding="utf-8", newline="") as handle:
            reader = csv.DictReader(handle)
            existing_fieldnames = reader.fieldnames or []
            stale_header = bool(existing_fieldnames) and existing_fieldnames != fieldnames
            # Only pay for parsing the whole file when it has to be migrated;
            # the common case needs nothing beyond the header line.
            existing_rows = list(reader) if stale_header else []
        if stale_header:
            with csv_path.open("w", encoding="utf-8", newline="") as handle:
                writer = csv.DictWriter(handle, fieldnames=fieldnames)
                writer.writeheader()
                for existing_row in existing_rows:
                    writer.writerow({key: existing_row.get(key, "") for key in fieldnames})

    with csv_path.open("a", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow({key: row.get(key, "") for key in fieldnames})
+
+
def print_kvcomm_cuda_state(
    *,
    tag: Optional[str] = None,
    topk: int = 20,
    include_gc_tensors: bool = False,
    target_device: Optional[int] = DEFAULT_TARGET_DEVICE,
) -> None:
    """Print a human-readable CUDA memory report to stdout.

    Args:
        tag: Optional label; when given, every line is prefixed with ``[tag] ``.
        topk: Maximum number of tensors listed per section.
        include_gc_tensors: Also list CUDA tensors found through the garbage
            collector.
        target_device: CUDA device index to report, or ``None`` for any device.
    """
    report = summarize_kvcomm_cuda_state(
        topk=topk,
        include_gc_tensors=include_gc_tensors,
        target_device=target_device,
    )
    prefix = "" if not tag else f"[{tag}] "
    stats = report["memory"]
    reported_device = report["target_device"]
    if reported_device is None:
        device_label = "all_cuda_devices"
    else:
        device_label = f"cuda:{reported_device}"

    print(
        f"{prefix}CUDA memory ({device_label}): "
        f"allocated={stats['allocated_mb']:.1f}MB "
        f"reserved={stats['reserved_mb']:.1f}MB "
        f"max_allocated={stats['max_allocated_mb']:.1f}MB"
    )

    for block in report["sections"]:
        print(
            f"{prefix}{block['name']}: "
            f"cuda_tensors={block['count']} total={block['total_mb']:.1f}MB"
        )
        for tensor_info in block["entries"]:
            print(
                f"{prefix} - {tensor_info['path']} "
                f"shape={tensor_info['shape']} dtype={tensor_info['dtype']} "
                f"device={tensor_info['device']} size={tensor_info['mb']:.1f}MB"
            )

    if not include_gc_tensors:
        return
    gc_listing = report.get("gc_cuda_tensors", [])
    print(f"{prefix}gc_cuda_tensors: {len(gc_listing)} shown")
    for tensor_info in gc_listing:
        print(
            f"{prefix} - shape={tensor_info['shape']} dtype={tensor_info['dtype']} "
            f"device={tensor_info['device']} size={tensor_info['mb']:.1f}MB"
        )
diff --git a/datasets/MMLU/download.py b/datasets/MMLU/download.py
index c043db3..1f4cd7a 100644
--- a/datasets/MMLU/download.py
+++ b/datasets/MMLU/download.py
@@ -1,26 +1,166 @@
import os
+import shutil
+from typing import Iterator
+
import requests
import tarfile
+from requests.adapters import HTTPAdapter
+from tqdm import tqdm
+from urllib3.util.retry import Retry
+
from KVCOMM.utils.log import logger
+# (connect timeout, read timeout between chunks) — large tar needs a generous read timeout
+_REQUEST_TIMEOUT = (30, 600)
+
+# Hugging Face hosts the same Hendrycks archive; CDN is often much faster than Berkeley direct.
+_HF_MMLU_TAR = (
+ "https://huggingface.co/datasets/Stevross/mmlu/resolve/main/data.tar"
+)
+_BERKELEY_MMLU_TAR = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
+
+_HF_HUB_REPO = "Stevross/mmlu"
+_HF_HUB_FILENAME = "data.tar"
+
+
+def _tar_url_candidates() -> Iterator[str]:
+ """URLs to try, in order. Set MMLU_DATA_TAR_URL to force a single mirror."""
+ custom = os.environ.get("MMLU_DATA_TAR_URL", "").strip()
+ if custom:
+ yield custom
+ return
+ yield _HF_MMLU_TAR
+ yield _BERKELEY_MMLU_TAR
+
+
+def _request_headers() -> dict:
+ token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
+ if token:
+ return {"Authorization": f"Bearer {token}"}
+ return {}
+
+
def _requests_session() -> requests.Session:
    """Build a session whose HTTP and HTTPS adapters retry transient failures.

    GET and HEAD requests are retried up to eight times (connect and read
    errors included) with a backoff factor of 1.2, and also on the status codes
    429, 500, 502, 503 and 504.
    """
    retry_policy = Retry(
        total=8,
        connect=8,
        read=8,
        backoff_factor=1.2,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET", "HEAD"]),
    )
    retrying_adapter = HTTPAdapter(max_retries=retry_policy)
    session = requests.Session()
    for scheme in ("https://", "http://"):
        session.mount(scheme, retrying_adapter)
    return session
+
+
def _download_via_hf_hub(tar_path: str) -> bool:
    """Fetch the archive through ``huggingface_hub`` and place it at *tar_path*.

    Returns:
        ``True`` when the archive is available at *tar_path* afterwards.
        ``False`` when ``huggingface_hub`` is not installed or when the download
        or copy raised; the failure is logged as a warning, not propagated.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        return False

    try:
        hub_copy = hf_hub_download(
            repo_id=_HF_HUB_REPO,
            filename=_HF_HUB_FILENAME,
            repo_type="dataset",
        )
        already_in_place = os.path.abspath(hub_copy) == os.path.abspath(tar_path)
        if not already_in_place:
            shutil.copy2(hub_copy, tar_path)
    except Exception as exc:
        logger.warning("huggingface_hub download failed: {}", exc)
        return False
    return True
+
+
def _stream_download(url: str, tar_path: str) -> None:
    """Stream *url* into *tar_path*, showing progress and checking the size.

    Args:
        url: Archive URL to fetch.
        tar_path: Destination file; it is overwritten.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        OSError: If the number of bytes written differs from the advertised
            ``Content-Length``.
    """
    chunk_size = 1024 * 1024
    # Using the session as a context manager closes its pooled connections
    # once the transfer is over instead of leaving them to the garbage collector.
    with _requests_session() as session, session.get(
        url,
        allow_redirects=True,
        stream=True,
        timeout=_REQUEST_TIMEOUT,
        headers=_request_headers(),
    ) as r:
        r.raise_for_status()
        cl = r.headers.get("Content-Length")
        encoding = (r.headers.get("Content-Encoding") or "identity").strip().lower()
        # Content-Length counts bytes on the wire, while iter_content() yields
        # decoded bytes; the two only match when the body is not content-encoded.
        expected_size = int(cl) if cl is not None and encoding == "identity" else None
        with open(tar_path, "wb") as f, tqdm(
            desc="data.tar",
            total=expected_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            miniters=1,
        ) as pbar:
            for data in r.iter_content(chunk_size=chunk_size):
                if data:
                    f.write(data)
                    pbar.update(len(data))
    if expected_size is not None:
        got = os.path.getsize(tar_path)
        if got != expected_size:
            raise OSError(
                f"Incomplete download: expected {expected_size} bytes, got {got}"
            )
+
def download():
    """Ensure the MMLU archive is downloaded and extracted next to this file.

    Download step (skipped when ``data.tar`` already exists): unless
    ``MMLU_DATA_TAR_URL`` is set, ``huggingface_hub`` is tried first; if that
    is unavailable or fails, each URL from ``_tar_url_candidates()`` is
    streamed in turn. A failed attempt removes the partial file, and the last
    error is re-raised when every URL fails.

    Extraction step (skipped when ``data/`` already exists): the archive is
    unpacked into this file's directory. If the archive cannot be read, the
    archive and any partial ``data/`` directory are removed and the error is
    re-raised so that the next run downloads a fresh copy.
    """
    this_file_path = os.path.split(__file__)[0]
    tar_path = os.path.join(this_file_path, "data.tar")
    if not os.path.exists(tar_path):
        custom_url = os.environ.get("MMLU_DATA_TAR_URL", "").strip()
        ok = False
        if not custom_url:
            logger.info(
                "Trying huggingface_hub (repo={}, file={}; set HF_ENDPOINT for mirrors)",
                _HF_HUB_REPO,
                _HF_HUB_FILENAME,
            )
            ok = _download_via_hf_hub(tar_path)
        if ok:
            logger.info("Saved to {}", tar_path)
        else:
            last_error: Exception | None = None
            for url in _tar_url_candidates():
                logger.info("Downloading {}", url)
                try:
                    _stream_download(url, tar_path)
                    last_error = None
                    break
                except Exception as exc:
                    last_error = exc
                    # Never leave a truncated archive behind for the next attempt.
                    if os.path.exists(tar_path):
                        os.unlink(tar_path)
                    logger.warning("Download from {} failed: {}", url, exc)
            if last_error is not None:
                raise last_error
            logger.info("Saved to {}", tar_path)
    data_path = os.path.join(this_file_path, "data")
    if not os.path.exists(data_path):
        try:
            with tarfile.open(tar_path) as tar:
                # The archive comes from the network, so refuse members that
                # would escape the target directory (absolute paths, "..",
                # outside links). Extraction filters exist on Python 3.12+
                # and on patched older releases; fall back otherwise.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(this_file_path, filter="data")
                else:
                    tar.extractall(this_file_path)
        except (tarfile.TarError, EOFError) as exc:
            if os.path.isdir(data_path):
                shutil.rmtree(data_path, ignore_errors=True)
            if os.path.exists(tar_path):
                os.unlink(tar_path)
            logger.error(
                "data.tar is corrupt or truncated ({}). Removed archive and partial "
                "'data/'; run again to re-download.",
                exc,
            )
            raise
        logger.info("Saved to {}", data_path)
diff --git a/experiments/analyze_kvreuse_anchor.py b/experiments/analyze_kvreuse_anchor.py
new file mode 100644
index 0000000..81e2b2f
--- /dev/null
+++ b/experiments/analyze_kvreuse_anchor.py
@@ -0,0 +1,391 @@
+import argparse
+import csv
+from collections import Counter, defaultdict
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import matplotlib.pyplot as plt
+
+
+def _to_int(value: str, default: int = 0) -> int:
+ try:
+ return int(value)
+ except (TypeError, ValueError):
+ return default
+
+
+def _to_float(value: str, default: float = 0.0) -> float:
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ return default
+
+
def load_events(csv_path: Path) -> List[Dict[str, str]]:
    """Read every data row of *csv_path* as a dict keyed by the header columns.

    Raises:
        ValueError: If the file contains no data rows.
    """
    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        events = list(csv.DictReader(handle))
    if events:
        return events
    raise ValueError(f"Input CSV is empty: {csv_path}")
+
+
def build_summary(
    rows: List[Dict[str, str]],
) -> Tuple[List[Dict[str, str]], Counter, int]:
    """Aggregate raw anchor events into per-anchor statistics.

    Events are grouped by ``(ph_id, anchor_msg)``. For each group the function
    counts candidate and selected events, remembers the largest step at which
    the anchor was selected, and averages ``weight`` over selected events only.

    Returns:
        ``(summary_rows, skip_reason_counter, max_step)``. ``summary_rows``
        holds one dict of strings per anchor, sorted by ``selected_count``
        descending. ``skip_reason_counter`` counts every ``skip_reason`` other
        than ``"none"`` or blank. ``max_step`` is the largest step seen.
    """
    per_anchor: Dict[Tuple[str, str], Dict[str, float]] = {}
    skip_reason_counter: Counter = Counter()
    max_step = 0

    for event in rows:
        anchor = (event.get("ph_id", ""), event.get("anchor_msg", ""))
        stats = per_anchor.get(anchor)
        if stats is None:
            stats = per_anchor[anchor] = {
                "candidate_count": 0.0,
                "selected_count": 0.0,
                "last_selected_step": -1.0,  # -1 marks "never selected"
                "weight_sum": 0.0,
                "weight_n": 0.0,
            }

        step = _to_int(event.get("step", "0"))
        selected = _to_int(event.get("is_selected", "0"))
        stats["candidate_count"] += _to_int(event.get("is_candidate", "0"))
        stats["selected_count"] += selected
        if selected:
            stats["last_selected_step"] = max(stats["last_selected_step"], step)
            stats["weight_sum"] += _to_float(event.get("weight", ""))
            stats["weight_n"] += 1

        reason = event.get("skip_reason", "none") or "none"
        if reason != "none":
            skip_reason_counter[reason] += 1
        max_step = max(max_step, step)

    unsorted_rows: List[Dict[str, str]] = []
    for (ph_id, anchor_msg), stats in per_anchor.items():
        candidates = int(stats["candidate_count"])
        picks = int(stats["selected_count"])
        last_pick = int(stats["last_selected_step"])
        rate = picks / candidates if candidates > 0 else 0.0
        # An anchor that was never selected has been idle for the whole run.
        idle = max_step - last_pick if last_pick >= 0 else max_step
        samples = stats["weight_n"]
        mean_weight = stats["weight_sum"] / samples if samples > 0 else 0.0
        unsorted_rows.append(
            {
                "ph_id": ph_id,
                "anchor_msg": anchor_msg,
                "candidate_count": str(candidates),
                "selected_count": str(picks),
                "selected_rate": f"{rate:.6f}",
                "last_selected_step": str(last_pick),
                "idle_steps": str(int(idle)),
                "avg_weight": f"{mean_weight:.6f}",
            }
        )

    summary_rows = sorted(
        unsorted_rows, key=lambda record: int(record["selected_count"]), reverse=True
    )
    return summary_rows, skip_reason_counter, max_step
+
+
def save_summary_csv(rows: List[Dict[str, str]], output_path: Path) -> None:
    """Write the per-anchor summary *rows* to *output_path* as CSV with a header."""
    columns = [
        "ph_id",
        "anchor_msg",
        "candidate_count",
        "selected_count",
        "selected_rate",
        "last_selected_step",
        "idle_steps",
        "avg_weight",
    ]
    with output_path.open("w", encoding="utf-8", newline="") as handle:
        sink = csv.DictWriter(handle, fieldnames=columns)
        sink.writeheader()
        for record in rows:
            sink.writerow(record)
+
+
def plot_topn_bar(summary_rows: List[Dict[str, str]], output_path: Path, top_n: int) -> None:
    """Save a bar chart of ``selected_count`` for the first *top_n* summary rows.

    Bars follow the order of *summary_rows*; each tick is labelled with the
    row's ``ph_id`` and its position in that order.
    """
    tick_labels: List[str] = []
    heights: List[int] = []
    for position, record in enumerate(summary_rows[:top_n]):
        tick_labels.append(f"{record['ph_id']}#{position}")
        heights.append(int(record["selected_count"]))
    slots = range(len(heights))

    plt.figure(figsize=(12, 5))
    plt.bar(slots, heights)
    plt.xticks(slots, tick_labels, rotation=45, ha="right")
    plt.ylabel("selected_count")
    plt.title(f"KVReuse Anchor Hotness Top-{top_n}")
    plt.tight_layout()
    plt.savefig(output_path, dpi=180)
    plt.close()
+
+
def plot_lorenz(summary_rows: List[Dict[str, str]], output_path: Path) -> None:
    """Save a Lorenz curve of how selections are distributed across anchors.

    Anchors are ordered by ascending ``selected_count``; the curve plots the
    cumulative share of selections against the cumulative share of anchors,
    next to the diagonal of a uniform distribution.
    """
    # An empty summary is drawn as a single anchor with zero selections.
    ascending = sorted(int(record["selected_count"]) for record in summary_rows) or [0]
    grand_total = sum(ascending)
    anchor_total = len(ascending)

    anchor_share = [0.0]
    selection_share = [0.0]
    cumulative = 0
    for rank, count in enumerate(ascending, start=1):
        cumulative += count
        anchor_share.append(rank / anchor_total)
        selection_share.append(cumulative / grand_total if grand_total > 0 else 0.0)

    plt.figure(figsize=(6, 6))
    plt.plot(anchor_share, selection_share, label="Observed")
    plt.plot([0, 1], [0, 1], "--", label="Uniform")
    plt.xlabel("Cumulative share of anchors")
    plt.ylabel("Cumulative share of selections")
    plt.title("KVReuse Anchor Concentration (Lorenz Curve)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=180)
    plt.close()
+
+
def plot_rolling_concentration(
    rows: List[Dict[str, str]],
    output_path: Path,
    rolling_window: int,
) -> None:
    """Save the rolling share of selections taken by the top-1 and top-5 anchors.

    Only rows with ``is_selected == 1`` are used. The window spans the last
    *rolling_window* distinct steps that contain a selection, not a range of
    step values. A titled placeholder figure is saved when nothing was selected.
    """
    picked = sorted(
        (row for row in rows if _to_int(row.get("is_selected", "0")) == 1),
        key=lambda row: _to_int(row.get("step", "0")),
    )
    if not picked:
        plt.figure(figsize=(12, 5))
        plt.title("No selected anchors found")
        plt.tight_layout()
        plt.savefig(output_path, dpi=180)
        plt.close()
        return

    counts_by_step: Dict[int, Counter] = defaultdict(Counter)
    for event in picked:
        anchor_key = f"{event.get('ph_id', '')}||{event.get('anchor_msg', '')}"
        counts_by_step[_to_int(event.get("step", "0"))][anchor_key] += 1

    steps = sorted(counts_by_step)
    top1_series: List[float] = []
    top5_series: List[float] = []
    for position in range(len(steps)):
        window_start = max(0, position - rolling_window + 1)
        window_counts: Counter = Counter()
        for window_step in steps[window_start : position + 1]:
            window_counts.update(counts_by_step[window_step])

        window_total = sum(window_counts.values())
        if window_total == 0:
            top1_share = top5_share = 0.0
        else:
            ranked = sorted(window_counts.values(), reverse=True)
            top1_share = ranked[0] / window_total
            top5_share = sum(ranked[:5]) / window_total
        top1_series.append(top1_share)
        top5_series.append(top5_share)

    plt.figure(figsize=(12, 5))
    plt.plot(steps, top1_series, label="Top1 share")
    plt.plot(steps, top5_series, label="Top5 share")
    plt.ylim(0.0, 1.0)
    plt.xlabel("step")
    plt.ylabel("share in rolling window")
    plt.title(f"KVReuse Rolling Anchor Concentration (window={rolling_window})")
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=180)
    plt.close()
+
+
def plot_global_hot_anchor_trends(
    rows: List[Dict[str, str]],
    summary_rows: List[Dict[str, str]],
    output_path: Path,
    top_k: int,
) -> None:
    """Save cumulative selection counts over steps for the hottest anchors.

    The anchors plotted are the first *top_k* entries of *summary_rows* that
    were selected at least once, so the summary's ordering decides the ranks.
    A titled placeholder figure is saved when no such anchor exists.
    """
    hot_anchors: List[str] = []
    for record in summary_rows:
        if int(record["selected_count"]) > 0:
            hot_anchors.append(f"{record.get('ph_id', '')}||{record.get('anchor_msg', '')}")
    hot_anchors = hot_anchors[:top_k]

    if not hot_anchors:
        plt.figure(figsize=(12, 5))
        plt.title("No selected anchors found")
        plt.tight_layout()
        plt.savefig(output_path, dpi=180)
        plt.close()
        return

    picked = sorted(
        (row for row in rows if _to_int(row.get("is_selected", "0")) == 1),
        key=lambda row: _to_int(row.get("step", "0")),
    )
    steps = sorted({_to_int(row.get("step", "0")) for row in picked})

    selections: Dict[str, Counter] = {anchor: Counter() for anchor in hot_anchors}
    for event in picked:
        anchor_key = f"{event.get('ph_id', '')}||{event.get('anchor_msg', '')}"
        if anchor_key in selections:
            selections[anchor_key][_to_int(event.get("step", "0"))] += 1

    plt.figure(figsize=(12, 5))
    for rank, anchor_key in enumerate(hot_anchors, start=1):
        cumulative: List[int] = []
        so_far = 0
        for step in steps:
            so_far += selections[anchor_key][step]
            cumulative.append(so_far)
        # Only the placeholder id goes into the legend; messages can be long.
        ph_id = anchor_key.split("||", 1)[0]
        plt.plot(steps, cumulative, label=f"Top{rank}: {ph_id}")

    plt.xlabel("step")
    plt.ylabel("cumulative selected_count")
    plt.title(f"KVReuse Global Hot Anchor Trends (top={len(hot_anchors)})")
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=180)
    plt.close()
+
+
def plot_global_hot_anchor_rolling_share(
    rows: List[Dict[str, str]],
    summary_rows: List[Dict[str, str]],
    output_path: Path,
    top_k: int,
    rolling_window: int,
) -> None:
    """Save the rolling share of selections going to the globally hottest anchors.

    The hot set is the first *top_k* entries of *summary_rows* that were
    selected at least once. The window spans the last *rolling_window* distinct
    steps that contain a selection. A titled placeholder figure is saved when
    there are no selections or no hot anchors.
    """
    hot_anchors: List[str] = []
    for record in summary_rows:
        if int(record["selected_count"]) > 0:
            hot_anchors.append(f"{record.get('ph_id', '')}||{record.get('anchor_msg', '')}")
    hot_anchors = hot_anchors[:top_k]

    picked = [row for row in rows if _to_int(row.get("is_selected", "0")) == 1]
    if not (picked and hot_anchors):
        plt.figure(figsize=(12, 5))
        plt.title("No selected anchors found")
        plt.tight_layout()
        plt.savefig(output_path, dpi=180)
        plt.close()
        return

    picked = sorted(picked, key=lambda row: _to_int(row.get("step", "0")))
    counts_by_step: Dict[int, Counter] = defaultdict(Counter)
    for event in picked:
        anchor_key = f"{event.get('ph_id', '')}||{event.get('anchor_msg', '')}"
        counts_by_step[_to_int(event.get("step", "0"))][anchor_key] += 1

    steps = sorted(counts_by_step)
    topk_total_series: List[float] = []
    for position in range(len(steps)):
        window_start = max(0, position - rolling_window + 1)
        window_counts: Counter = Counter()
        for window_step in steps[window_start : position + 1]:
            window_counts.update(counts_by_step[window_step])

        window_total = sum(window_counts.values())
        if window_total == 0:
            share = 0.0
        else:
            share = sum(window_counts[anchor] for anchor in hot_anchors) / window_total
        topk_total_series.append(share)

    plt.figure(figsize=(12, 5))
    plt.plot(steps, topk_total_series, linewidth=2.0, label=f"Top{len(hot_anchors)} total")

    plt.ylim(0.0, 1.0)
    plt.xlabel("step")
    plt.ylabel("share in rolling window")
    plt.title(
        f"KVReuse Global Hot Anchor Rolling Share "
        f"(top={len(hot_anchors)}, window={rolling_window})"
    )
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=180)
    plt.close()
+
+
def main() -> None:
    """Command-line entry point: summarize an events CSV and render all figures.

    Writes ``kvreuse_anchor_summary.csv`` and five PNG charts into the output
    directory, then prints where each artifact was saved.
    """
    parser = argparse.ArgumentParser(description="Analyze KVReuse anchor hot/cold behavior.")
    parser.add_argument(
        "--events-csv",
        type=str,
        required=True,
        help="Path to kvreuse_anchor_events.csv",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="runs/anchor_analysis",
        help="Directory to save CSV and figures.",
    )
    parser.add_argument(
        "--top-n",
        type=int,
        default=20,
        help="Top-N anchors shown in bar chart.",
    )
    parser.add_argument(
        "--rolling-window",
        type=int,
        default=200,
        help="Rolling window size (in step bins) for concentration chart.",
    )
    parser.add_argument(
        "--trend-top-k",
        type=int,
        default=5,
        help="Global Top-K selected anchors shown in cumulative trend chart.",
    )
    args = parser.parse_args()

    events_csv = Path(args.events_csv).expanduser().resolve()
    output_dir = Path(args.output_dir).expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    rows = load_events(events_csv)
    summary_rows, skip_reasons, max_step = build_summary(rows)

    summary_csv = output_dir / "kvreuse_anchor_summary.csv"
    save_summary_csv(summary_rows, summary_csv)

    topn_png = output_dir / "kvreuse_topn_hotness.png"
    lorenz_png = output_dir / "kvreuse_lorenz_curve.png"
    rolling_png = output_dir / "kvreuse_rolling_concentration.png"
    trends_png = output_dir / "kvreuse_global_hot_anchor_trends.png"
    share_png = output_dir / "kvreuse_global_hot_anchor_rolling_share.png"

    plot_topn_bar(summary_rows, topn_png, args.top_n)
    plot_lorenz(summary_rows, lorenz_png)
    plot_rolling_concentration(rows, rolling_png, args.rolling_window)
    plot_global_hot_anchor_trends(rows, summary_rows, trends_png, args.trend_top_k)
    plot_global_hot_anchor_rolling_share(
        rows,
        summary_rows,
        share_png,
        args.trend_top_k,
        args.rolling_window,
    )

    print(f"[done] summary csv: {summary_csv}")
    for figure_path in (topn_png, lorenz_png, rolling_png, trends_png, share_png):
        print(f"[done] figures: {figure_path}")
    print(f"[info] max_step={max_step}, anchors={len(summary_rows)}")
    if skip_reasons:
        print(f"[info] skip_reasons={dict(skip_reasons)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/experiments/run_gsm8k.py b/experiments/run_gsm8k.py
index 72d0a4b..7dce421 100644
--- a/experiments/run_gsm8k.py
+++ b/experiments/run_gsm8k.py
@@ -17,6 +17,7 @@
from KVCOMM.llm.config import KVCommConfig
from KVCOMM.tools.reader.readers import JSONLReader
from KVCOMM.utils.globals import Time
+from KVCOMM.utils.gpu_debug import print_kvcomm_cuda_state
from KVCOMM.utils.log import configure_logging, logger
from KVCOMM.utils.metrics import metrics_recorder
from datasets.gsm8k_dataset import (
@@ -109,7 +110,8 @@ async def main():
current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
Time.instance().value = current_time
- result_file = output_dir / f"{args.domain}_{args.llm_name}_{current_time}.json"
+ llm_tag = args.llm_name.replace("/", "__")
+ result_file = output_dir / f"{args.domain}_{llm_tag}_{current_time}.json"
latency_target = str(output_dir)
agent_names = [name for name, num in zip(args.agent_names, args.agent_nums) for _ in range(num)]
@@ -226,6 +228,7 @@ async def main():
)
logger.opt(colors=True).info(f"[ACCURACY] {accuracy:.4f}")
metrics_recorder.log_cumulative(batch_index=i_batch)
+ print_kvcomm_cuda_state(tag=f"batch_{i_batch}_end", topk=10)
def get_kwargs(
diff --git a/experiments/run_humaneval.py b/experiments/run_humaneval.py
index 647fc50..532555f 100644
--- a/experiments/run_humaneval.py
+++ b/experiments/run_humaneval.py
@@ -18,6 +18,11 @@
from KVCOMM.tools.coding.python_executor import PyExecutor
from KVCOMM.tools.reader.readers import JSONLReader
from KVCOMM.utils.globals import Time
+from KVCOMM.utils.gpu_debug import (
+ append_kvcomm_cuda_state_csv,
+ print_kvcomm_cuda_state,
+ reset_cuda_peak_memory,
+)
from KVCOMM.utils.log import configure_logging, logger
from KVCOMM.utils.metrics import metrics_recorder
@@ -84,6 +89,27 @@ def parse_args():
parser.add_argument("--kv-window-size", type=int, default=None, help="Window size for key-value memory update.")
parser.add_argument("--kv-thread-workers", type=int, default=None, help="Number of thread workers for key-value memory processing.")
parser.add_argument("--kv-worker-timeout", type=float, default=None, help="Timeout for key-value memory workers processing.")
+ parser.add_argument(
+ "--resident-anchor-summary",
+ type=str,
+ default=None,
+ help="Path to kvreuse_anchor_summary.csv; top selected anchors from this file stay resident on GPU.",
+ )
+ parser.add_argument(
+ "--resident-anchor-top-n",
+ type=int,
+ default=None,
+ help="Number of hot anchors from --resident-anchor-summary to keep resident on GPU.",
+ )
+ parser.add_argument(
+ "--test-time",
+ "--test_time",
+ nargs="?",
+ const=True,
+ default=False,
+ type=lambda value: str(value).lower() in {"1", "true", "yes", "y", "on"},
+ help="Run kv_reuse calls through the dense-vs-KV timing comparison path.",
+ )
args = parser.parse_args()
if len(args.agent_names) != len(args.agent_nums):
@@ -100,7 +126,8 @@ async def main():
current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
Time.instance().value = current_time
- result_file = output_dir / f"{args.domain}_{args.llm_name}_{current_time}.json"
+ llm_tag = args.llm_name.replace("/", "__")
+ result_file = output_dir / f"{args.domain}_{llm_tag}_{current_time}.json"
latency_target = str(output_dir)
agent_names = [name for name, num in zip(args.agent_names, args.agent_nums) for _ in range(num)]
@@ -114,10 +141,24 @@ async def main():
window_size=args.kv_window_size,
thread_pool_workers=args.kv_thread_workers,
worker_timeout=args.kv_worker_timeout,
+ resident_anchor_summary=args.resident_anchor_summary,
+ resident_anchor_top_n=args.resident_anchor_top_n,
)
else:
kv_config = KVCommConfig.from_env()
+ logger.info(
+ "HumanEval run config: execution_mode={} threshold={} max_anchor_num={} window_size={} top_k={} test_time={} resident_anchor_summary={} resident_anchor_top_n={} anchor_device=cpu",
+ args.execution_mode,
+ kv_config.threshold,
+ kv_config.max_anchor_num,
+ kv_config.window_size,
+ kv_config.top_k,
+ args.test_time,
+ kv_config.resident_anchor_summary,
+ kv_config.resident_anchor_top_n,
+ )
+
graph = Graph(
domain=args.domain,
llm_name=args.llm_name,
@@ -132,6 +173,13 @@ async def main():
for i_batch in range(num_batches):
logger.opt(colors=True).info(f"[BATCH] {i_batch} {'-' * 40}")
+ reset_cuda_peak_memory()
+ append_kvcomm_cuda_state_csv(
+ output_dir,
+ tag=f"batch_{i_batch}_start",
+ batch_index=i_batch,
+ phase="batch_start",
+ )
start_ts = time.time()
current_batch = dataloader(dataset, args.batch_size, i_batch)
if not current_batch:
@@ -152,7 +200,8 @@ async def main():
if args.execution_mode == "allow_kv_reuse":
mode_kwargs = {
"prefix": args.prefix,
- "output_dir": latency_target
+ "output_dir": latency_target,
+ "test_time": args.test_time,
}
tasks.append(
@@ -210,6 +259,13 @@ async def main():
)
logger.opt(colors=True).info(f"[ACCURACY] {accuracy:.4f}")
metrics_recorder.log_cumulative(batch_index=i_batch)
+ append_kvcomm_cuda_state_csv(
+ output_dir,
+ tag=f"batch_{i_batch}_end",
+ batch_index=i_batch,
+ phase="batch_end",
+ )
+ print_kvcomm_cuda_state(tag=f"batch_{i_batch}_end", topk=10)
def get_kwargs(
mode: Union[
diff --git a/experiments/run_mmlu.py b/experiments/run_mmlu.py
index 3452d4f..f1007c8 100644
--- a/experiments/run_mmlu.py
+++ b/experiments/run_mmlu.py
@@ -64,6 +64,7 @@ def parse_args():
parser.add_argument("--kv-threshold", type=float, default=None, help="Threshold for key-value memory usage.")
parser.add_argument("--kv-max-anchor-num", type=int, default=20, help="Maximum number of anchors for key-value memory.")
parser.add_argument("--kv-window-size", type=int, default=None, help="Window size for key-value memory update.")
+ parser.add_argument("--kv-top-k", type=int, default=None, help="Top-k anchors to activate; overrides top-p selection when set.")
parser.add_argument("--kv-thread-workers", type=int, default=None, help="Number of thread workers for key-value memory processing.")
parser.add_argument("--kv-worker-timeout", type=float, default=None, help="Timeout for key-value memory workers processing.")
@@ -85,6 +86,7 @@ async def main():
threshold=args.kv_threshold,
max_anchor_num=args.kv_max_anchor_num,
window_size=args.kv_window_size,
+ top_k=args.kv_top_k,
thread_pool_workers=args.kv_thread_workers,
worker_timeout=args.kv_worker_timeout,
)
@@ -119,7 +121,9 @@ async def main():
**eval_kwargs,
)
logger.opt(colors=True).info("[MMLU SCORE] {:.4f}", score)
- result_file = output_dir / f"{args.domain}_{args.llm_name}_{timestamp}.json"
+ # HF model ids contain "/"; must not appear raw in filenames (path separator).
+ llm_tag = args.llm_name.replace("/", "__")
+ result_file = output_dir / f"{args.domain}_{llm_tag}_{timestamp}.json"
result_file.touch(exist_ok=True)
payload = {
"score": score,