diff --git a/KVCOMM/llm/config.py b/KVCOMM/llm/config.py index 1bf6ed8..2e4345c 100644 --- a/KVCOMM/llm/config.py +++ b/KVCOMM/llm/config.py @@ -18,8 +18,11 @@ class KVCommConfig: threshold: float = 0.3 max_anchor_num: int = 20 window_size: int = 5 + top_k: int | None = None thread_pool_workers: int = 8 worker_timeout: float = 30.0 + resident_anchor_summary: str | None = None + resident_anchor_top_n: int = 0 @classmethod def from_env(cls) -> "KVCommConfig": @@ -28,8 +31,15 @@ def from_env(cls) -> "KVCommConfig": threshold=float(os.environ.get("THRESHOLD", cls.threshold)), max_anchor_num=int(os.environ.get("MAX_ANCHOR_NUM", cls.max_anchor_num)), window_size=int(os.environ.get("WINDOW_SIZE", cls.window_size)), + top_k=( + int(os.environ.get("KVCOMM_TOP_K", os.environ.get("TOP_K"))) + if os.environ.get("KVCOMM_TOP_K", os.environ.get("TOP_K")) is not None + else cls.top_k + ), thread_pool_workers=int(os.environ.get("KVCOMM_THREAD_WORKERS", cls.thread_pool_workers)), worker_timeout=float(os.environ.get("KVCOMM_WORKER_TIMEOUT", cls.worker_timeout)), + resident_anchor_summary=os.environ.get("KVCOMM_RESIDENT_ANCHOR_SUMMARY"), + resident_anchor_top_n=int(os.environ.get("KVCOMM_RESIDENT_ANCHOR_TOP_N", cls.resident_anchor_top_n)), ).validate() def apply_overrides(self, **overrides: Any) -> "KVCommConfig": @@ -43,8 +53,12 @@ def apply_overrides(self, **overrides: Any) -> "KVCommConfig": def validate(self) -> "KVCommConfig": """Validate value ranges and return self.""" + if self.top_k is not None and self.top_k <= 0: + raise ValueError("top_k must be positive when provided") if self.thread_pool_workers <= 0: raise ValueError("thread_pool_workers must be positive") if self.worker_timeout <= 0: raise ValueError("worker_timeout must be positive") + if self.resident_anchor_top_n < 0: + raise ValueError("resident_anchor_top_n must be non-negative") return self diff --git a/KVCOMM/llm/gpt_chat.py b/KVCOMM/llm/gpt_chat.py index a368379..913928a 100644 --- a/KVCOMM/llm/gpt_chat.py +++ b/KVCOMM/llm/gpt_chat.py @@ -34,7 +34,7 @@ from KVCOMM.llm.config import KVCommConfig from KVCOMM.llm.token_ops import * -from KVCOMM.llm.kvcomm_engine import KVCOMMEngine, _RequestState +from KVCOMM.llm.kvcomm_engine import KVCOMMEngine, _RequestState, _move_tensor_tree from KVCOMM.utils.metrics import GenerationResult from KVCOMM.utils.log import logger @@ -48,6 +48,26 @@ def _escape_loguru_markup(text: Optional[str]) -> str: return text.replace("<", "\\<") +def _hf_model_load_kwargs(model_name: str) -> Tuple[torch.dtype, Optional[str]]: + """torch_dtype and device_map for local HF load. + + Non-Llama models previously defaulted to float32 on CUDA, which roughly + doubles VRAM versus fp16/bf16 and commonly OOMs 7B-class weights on 24GB + cards during from_pretrained. device_map uses ``auto`` so visible GPUs + (e.g. via CUDA_VISIBLE_DEVICES) are used without hard-coding cuda:0. + """ + if torch.cuda.is_available(): + mn = model_name.lower() + if "llama" in mn: + dtype = torch.float16 + elif torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + else: + dtype = torch.float16 + return dtype, "auto" + return torch.float32, None + + _LATENCY_IO_LOCK = threading.Lock() @@ -214,6 +234,8 @@ def __init__(self, model_name: str, prefix: str = None, config: KVCommConfig | N self.model_name = model_name self.config = (config or KVCommConfig.from_env()).validate() + self.anchor_device = torch.device("cpu") + self.kv_storage_device = torch.device("cpu") self._ensure_thread_pool(self.config.thread_pool_workers) self.kv_engine = KVCOMMEngine(self) @@ -550,6 +572,13 @@ def _ensure_global_input_buckets(self) -> Dict[str, Dict[str, Any]]: store.setdefault("input_drop_num", {}) return store + def _offload_kv_payload(self, value: Any) -> Any: + """Store shared KV payloads on CPU; callers materialize copies for compute.""" + return _move_tensor_tree(value, self.kv_storage_device) + + def _materialize_kv_payload(self, value: Any) -> Any: + return _move_tensor_tree(value, self.model.device) + def has_prefix_initialized(self, agent_id: str) -> bool: """Check if prefix KV has been initialized for an agent.""" return LLMChat._initialization.get(agent_id, False) @@ -627,14 +656,6 @@ def update_condition_anchor( condition_cache = generated.past_key_values - for key_name, value in ( - ("condition", condition_cache), - ("condition_ids", token_ids), - ("condition_drop_num", drop_num), - ): - bucket = owner_memory.setdefault(key_name, {}) - bucket.setdefault(message, []).append(value) - anchor_store = state.anchors.setdefault(anchor_key, {}) cond_anchor_list = list(anchor_store.values()) cond_len_bucket = state.anchor_len_dict.setdefault(anchor_key, {}) @@ -654,8 +675,24 @@ def update_condition_anchor( anchor_kv_cache_list=cond_anchor_list, anchor_len_list=anchor_len_list, anchor_activated_list=anchor_activated_list, + request_uid=request_uid, + ph_id=anchor_key, + message=message, + anchor_labels=list(anchor_store.keys()), + log_events=True, ) + for key_name, value in ( + ("condition", condition_cache), + ("condition_ids", token_ids), + ("condition_drop_num", drop_num), + ): + bucket = owner_memory.setdefault(key_name, {}) + if key_name.endswith("_drop_num"): + bucket.setdefault(message, []).append(value) + else: + bucket.setdefault(message, []).append(self._offload_kv_payload(value)) + cond_flag_bucket = state.anchor_dict.setdefault(anchor_key, {}) cond_flag_bucket[message] = prob @@ -771,13 +808,6 @@ def update_input_anchor( ) input_cache = output.past_key_values - global_buckets = self._ensure_global_input_buckets() - global_buckets["input"].setdefault(message, []).append( - input_cache.copy().slice_(start=0, end=token_ids["input_ids"].shape[-1]) - ) - global_buckets["input_ids"].setdefault(message, []).append(token_ids) - global_buckets["input_drop_num"].setdefault(message, []).append(drop_num) - anchor_store = state.anchors.setdefault(anchor_namespace, {}) input_anchor_list = list(anchor_store.values()) uq_len_bucket = state.anchor_len_dict.setdefault(anchor_namespace, {}) @@ -798,11 +828,25 @@ def update_input_anchor( anchor_len_list=anchor_len_list, anchor_activated_list=anchor_activated_list, test_time=test_time, + request_uid=request_uid, + ph_id=anchor_namespace, + message=message, + anchor_labels=list(anchor_store.keys()), + log_events=True, ) logger.opt(colors=True).debug( f"Anchor prediction for input '{safe_message}': {prob}" ) + global_buckets = self._ensure_global_input_buckets() + global_buckets["input"].setdefault(message, []).append( + self._offload_kv_payload( + input_cache.copy().slice_(start=0, end=token_ids["input_ids"].shape[-1]) + ) + ) + global_buckets["input_ids"].setdefault(message, []).append(self._offload_kv_payload(token_ids)) + global_buckets["input_drop_num"].setdefault(message, []).append(drop_num) + state.anchor_dict.setdefault(anchor_namespace, {})[message] = prob global_bucket = state.global_anchor_info.setdefault(anchor_namespace, {}) if not prob: @@ -836,6 +880,7 @@ async def generate_for_agent( ) -> GenerationResult: """Generate a response using the requested strategy with sensible fallbacks.""" latency_target = output_dir or kwargs.get("output_dir") + self.kv_engine.configure_anchor_event_logging(latency_target) if preferred_mode == "dense_prefill": mode = "dense_prefill" elif self.has_active_anchor(request_uid, message): @@ -918,8 +963,8 @@ async def prepare_prefix_kv_segments(self, node_id: str, prefix: str, user_promp if type_ == "text": seg_kv = base_kv.slice(start=s, end=e) - segment_kv_list.append(seg_kv) - token_id_list.append(token_id) + segment_kv_list.append(self._offload_kv_payload(seg_kv)) + token_id_list.append(self._offload_kv_payload(token_id)) self._shared_kv_cache_memory[node_id]["prefix"] = LLMChat._shared_kv_cache_memory[node_id]["prefix"] = segment_kv_list self._shared_kv_cache_memory[node_id]["placeholder_info"] = LLMChat._shared_kv_cache_memory[node_id]["placeholder_info"] = placeholder_info self._shared_kv_cache_memory[node_id]["token_ids"] = LLMChat._shared_kv_cache_memory[node_id]["token_ids"] = token_id_list @@ -931,11 +976,12 @@ def _initialize_shared_resources(self): with LLMChat._model_lock: if LLMChat._shared_model is None: LLMChat._shared_tokenizer = AutoTokenizer.from_pretrained(self.model_name) + load_dtype, device_map = _hf_model_load_kwargs(self.model_name) LLMChat._shared_model = AutoModelForCausalLM.from_pretrained( self.model_name, - torch_dtype=torch.float16 if 'llama' in self.model_name else torch.float32, + torch_dtype=load_dtype, low_cpu_mem_usage=True, - device_map="cuda:0", + device_map=device_map, trust_remote_code=True ) logger.info("Model {} loaded and shared across instances.", self.model_name) @@ -1194,8 +1240,14 @@ async def agen_kvcomm( message = messages prefix_store = self._shared_kv_cache_memory[self.node_id] - prefix_kv_list: List[DynamicCache] = prefix_store.get("prefix", []) - prefix_token_ids: List[Dict[str, torch.Tensor]] = prefix_store.get("token_ids", []) + prefix_kv_list: List[DynamicCache] = [ + self._materialize_kv_payload(cache) + for cache in prefix_store.get("prefix", []) + ] + prefix_token_ids: List[Dict[str, torch.Tensor]] = [ + self._materialize_kv_payload(token_ids) + for token_ids in prefix_store.get("token_ids", []) + ] placeholder_info_map = prefix_store.get("placeholder_info") if not prefix_kv_list: raise RuntimeError( @@ -1370,12 +1422,14 @@ async def agen_kvcomm( ] anchor_active_list: List[int] = list(anchor_info_bucket.values()) - resp.setdefault(message, []).append(response_kv_cache) + resp.setdefault(message, []).append(self._offload_kv_payload(response_kv_cache)) resp_ids.setdefault(message, []).append( - { - "input_ids": response_tokens, - "attention_mask": response_mask, - } + self._offload_kv_payload( + { + "input_ids": response_tokens, + "attention_mask": response_mask, + } + ) ) resp_drop.setdefault(message, []).append(0) @@ -1390,6 +1444,11 @@ async def agen_kvcomm( anchor_kv_cache_list=response_anchor_list, anchor_len_list=anchor_len_list, anchor_activated_list=anchor_active_list, + request_uid=request_uid, + ph_id=current_key, + message=message, + anchor_labels=list(anchor_bucket.keys()), + log_events=True, ) safe_message = _escape_loguru_markup(message) logger.opt(colors=True).debug( @@ -1502,8 +1561,14 @@ async def agen_kvcomm_time_test( message = messages prefix_store = self._shared_kv_cache_memory[self.node_id] - prefix_kv_list: List[DynamicCache] = prefix_store.get("prefix", []) - prefix_token_ids: List[Dict[str, torch.Tensor]] = prefix_store.get("token_ids", []) + prefix_kv_list: List[DynamicCache] = [ + self._materialize_kv_payload(cache) + for cache in prefix_store.get("prefix", []) + ] + prefix_token_ids: List[Dict[str, torch.Tensor]] = [ + self._materialize_kv_payload(token_ids) + for token_ids in prefix_store.get("token_ids", []) + ] placeholder_info_map = prefix_store.get("placeholder_info") if not prefix_kv_list: raise RuntimeError( @@ -1649,6 +1714,26 @@ async def agen_kvcomm_time_test( ttft_value = kvcomm_ttft_value else: ttft_value = dense_prefill_ttft + if mode == "dense_prefill": + base_cache = merged_prefix_kv + real_cache = full_kv_cache.slice(start=0, end=prefix_token_length) + real_placeholder_cache, real_prefix_cache = real_cache.split_cache_by_placeholders( + placeholder_indices + ) + base_placeholder_cache, base_prefix_cache = base_cache.split_cache_by_placeholders( + placeholder_indices + ) + self.kv_engine.set_anchor( + request_uid, + message, + ph_id_list, + real_placeholder_cache, + real_prefix_cache, + base_placeholder_cache, + base_prefix_cache, + max_anchor_num=max_anchor_num, + window_length=window_length, + ) response_kv_cache = full_kv_cache.slice_(start=prefix_token_length) response_kv_cache = self.kv_engine.apply_rotary_pos_emb( response_kv_cache, @@ -1676,12 +1761,14 @@ async def agen_kvcomm_time_test( ] anchor_active_list: List[int] = list(anchor_info_bucket.values()) - resp.setdefault(message, []).append(response_kv_cache) + resp.setdefault(message, []).append(self._offload_kv_payload(response_kv_cache)) resp_ids.setdefault(message, []).append( - { - "input_ids": response_tokens, - "attention_mask": response_mask, - } + self._offload_kv_payload( + { + "input_ids": response_tokens, + "attention_mask": response_mask, + } + ) ) resp_drop.setdefault(message, []).append(0) @@ -1697,6 +1784,11 @@ async def agen_kvcomm_time_test( anchor_len_list=anchor_len_list, anchor_activated_list=anchor_active_list, test_time=True, + request_uid=request_uid, + ph_id=current_key, + message=message, + anchor_labels=list(anchor_bucket.keys()), + log_events=True, ) safe_message = _escape_loguru_markup(message) logger.opt(colors=True).debug( diff --git a/KVCOMM/llm/kvcomm_engine.py b/KVCOMM/llm/kvcomm_engine.py index 4678620..92edcc3 100644 --- a/KVCOMM/llm/kvcomm_engine.py +++ b/KVCOMM/llm/kvcomm_engine.py @@ -9,9 +9,11 @@ from __future__ import annotations import copy +import csv import threading from collections.abc import MutableMapping from collections.abc import Sequence +from pathlib import Path from time import perf_counter from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -407,6 +409,21 @@ def _to_device(cache: DynamicCache, device: Union[str, torch.device]) -> Dynamic return cache +def _move_tensor_tree(value: Any, device: Union[str, torch.device]) -> Any: + """Recursively move tensors/cache payloads to the target device.""" + if isinstance(value, torch.Tensor): + return value.detach().to(device) + if isinstance(value, DynamicCache): + return value.copy().to(device) + if isinstance(value, dict): + return {key: _move_tensor_tree(item, device) for key, item in value.items()} + if isinstance(value, list): + return [_move_tensor_tree(item, device) for item in value] + if isinstance(value, tuple): + return tuple(_move_tensor_tree(item, device) for item in value) + return copy.deepcopy(value) if isinstance(value, set) else value + + def _split_cache_by_placeholders( cache: DynamicCache, placeholder_dict: Dict[str, Tuple[int, int]], @@ -535,6 +552,21 @@ def _clone_default(value: Any) -> Any: return copy.copy(value) if hasattr(value, "__copy__") else value +def _scoped_copy(value: Any) -> Any: + """Deep-copy request state while sharing explicitly resident GPU anchors.""" + if isinstance(value, dict): + if value.get("_kvcomm_resident_anchor") is True: + return value + return {key: _scoped_copy(item) for key, item in value.items()} + if isinstance(value, list): + return [_scoped_copy(item) for item in value] + if isinstance(value, tuple): + return tuple(_scoped_copy(item) for item in value) + if isinstance(value, set): + return {_scoped_copy(item) for item in value} + return copy.deepcopy(value) + + class _ScopedDict(MutableMapping): """Request-scoped view over a shared dictionary with deferred commits.""" @@ -546,7 +578,7 @@ def _ensure_local(self, key: str) -> None: if key in self._local: return if key in self._base: - self._local[key] = copy.deepcopy(self._base[key]) + self._local[key] = _scoped_copy(self._base[key]) def __getitem__(self, key: str) -> Any: if key in self._local: @@ -555,7 +587,7 @@ def __getitem__(self, key: str) -> Any: raise KeyError(key) return value if key in self._base: - value = copy.deepcopy(self._base[key]) + value = _scoped_copy(self._base[key]) self._local[key] = value if value is _DELETED: raise KeyError(key) @@ -605,7 +637,7 @@ def setdefault(self, key: str, default: Any = None): return new_value return value if key in self._base: - value = copy.deepcopy(self._base[key]) + value = _scoped_copy(self._base[key]) self._local[key] = value return value new_value = _clone_default(default) @@ -686,14 +718,196 @@ class KVCOMMEngine: _request_states: Dict[str, _RequestState] = {} _active_requests: set[str] = set() _staged_commits: List[_RequestState] = [] + _anchor_event_lock = threading.Lock() + _anchor_event_step = 0 + _anchor_event_csv_path: Optional[Path] = None + _anchor_lifecycle_step = 0 + _anchor_lifecycle_csv_path: Optional[Path] = None + _resident_anchor_keys: set[Tuple[str, str]] = set() + _resident_anchor_source: Optional[Tuple[str, int]] = None def __init__(self, llm: "LLMChat"): self.llm = llm self._warning_prefix = "[KVCOMMEngine]" + self.configure_resident_anchors_from_config() def _log_warning(self, message: str) -> None: logger.opt(colors=True).warning("{} {}", self._warning_prefix, message) + @classmethod + def configure_anchor_event_logging(cls, output_dir: Optional[Union[str, Path]]) -> None: + """Configure output csv path for KVReuse anchor diagnostics.""" + if output_dir is None: + return + out_dir = Path(output_dir).expanduser() + out_dir.mkdir(parents=True, exist_ok=True) + with cls._anchor_event_lock: + cls._anchor_event_csv_path = out_dir / "kvreuse_anchor_events.csv" + cls._anchor_lifecycle_csv_path = out_dir / "kvreuse_anchor_lifecycle.csv" + + @classmethod + def _next_anchor_event_step(cls) -> int: + with cls._anchor_event_lock: + cls._anchor_event_step += 1 + return cls._anchor_event_step + + @classmethod + def _write_anchor_event_rows(cls, rows: List[Dict[str, Any]]) -> None: + csv_path = cls._anchor_event_csv_path + if csv_path is None or not rows: + return + fieldnames = [ + "step", + "request_uid", + "ph_id", + "message", + "anchor_msg", + "is_candidate", + "is_selected", + "sim_score", + "weight", + "skip_reason", + "placeholder_len", + "available_anchor_num", + "selected_anchor_num", + ] + with cls._anchor_event_lock: + file_exists = csv_path.exists() + with csv_path.open("a", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + for row in rows: + writer.writerow({k: row.get(k, "") for k in fieldnames}) + + @classmethod + def _next_anchor_lifecycle_step(cls) -> int: + with cls._anchor_event_lock: + cls._anchor_lifecycle_step += 1 + return cls._anchor_lifecycle_step + + @classmethod + def _write_anchor_lifecycle_row(cls, row: Dict[str, Any]) -> None: + csv_path = cls._anchor_lifecycle_csv_path + if csv_path is None: + return + fieldnames = [ + "lifecycle_step", + "event", + "request_uid", + "ph_id", + "message", + "node_id", + "role", + "frequency", + "placeholder_len", + "accumulate_len", + "anchor_count_before", + "anchor_count_after", + "reason", + ] + with cls._anchor_event_lock: + file_exists = csv_path.exists() + with csv_path.open("a", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + writer.writerow({k: row.get(k, "") for k in fieldnames}) + + def _anchor_storage_device(self) -> torch.device: + device = getattr(self.llm, "anchor_device", None) + if device is None: + return torch.device("cpu") + return torch.device(device) + + def _compute_device(self) -> torch.device: + return torch.device(self.llm.model.device) + + def _materialize_anchor_entries( + self, + anchor_entries: List[Dict[str, Any]], + *, + device: Optional[Union[str, torch.device]] = None, + ) -> List[Dict[str, Any]]: + target_device = torch.device(device) if device is not None else self._compute_device() + return [_move_tensor_tree(entry, target_device) for entry in anchor_entries] + + def _offload_anchor_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: + return _move_tensor_tree(entry, self._anchor_storage_device()) + + @classmethod + def configure_resident_anchors( + cls, + summary_path: Optional[Union[str, Path]], + top_n: int, + ) -> None: + """Load the fixed hot-anchor set that should stay resident on GPU.""" + if not summary_path or top_n <= 0: + cls._resident_anchor_keys = set() + cls._resident_anchor_source = None + return + + path = Path(summary_path).expanduser() + source = (str(path), int(top_n)) + if cls._resident_anchor_source == source: + return + + rows: List[Dict[str, Any]] = [] + try: + with path.open("r", encoding="utf-8", newline="") as handle: + reader = csv.DictReader(handle) + for row in reader: + rows.append(row) + except OSError as exc: + logger.opt(colors=True).warning( + "[KVCOMMEngine] Failed to load resident anchor summary {}: {}", + path, + exc, + ) + cls._resident_anchor_keys = set() + cls._resident_anchor_source = source + return + + def _selected_count(row: Dict[str, Any]) -> int: + try: + return int(float(row.get("selected_count", 0) or 0)) + except (TypeError, ValueError): + return 0 + + rows.sort(key=_selected_count, reverse=True) + selected = rows[:top_n] + cls._resident_anchor_keys = { + (row.get("ph_id", ""), row.get("anchor_msg", "")) + for row in selected + if row.get("ph_id") and row.get("anchor_msg") + } + cls._resident_anchor_source = source + logger.opt(colors=True).info( + "[KVCOMMEngine] Loaded {} resident hot anchors from {} (top_n={})", + len(cls._resident_anchor_keys), + path, + top_n, + ) + + def configure_resident_anchors_from_config(self) -> None: + config = getattr(self.llm, "config", None) + if config is None: + return + self.configure_resident_anchors( + getattr(config, "resident_anchor_summary", None), + int(getattr(config, "resident_anchor_top_n", 0) or 0), + ) + + def _is_resident_anchor(self, ph_id: str, message: str) -> bool: + return (ph_id, message) in self._resident_anchor_keys + + def _store_anchor_entry(self, ph_id: str, message: str, entry: Dict[str, Any]) -> Dict[str, Any]: + if self._is_resident_anchor(ph_id, message): + resident_entry = _move_tensor_tree(entry, self._compute_device()) + resident_entry["_kvcomm_resident_anchor"] = True + return resident_entry + return self._offload_anchor_entry(entry) + @staticmethod def _stack_cache_tensors(cache: DynamicCache) -> Tuple[torch.Tensor, torch.Tensor]: return torch.stack(cache.key_cache), torch.stack(cache.value_cache) @@ -788,7 +1002,7 @@ def _get_cached_anchor_weights( return None if entry.get("anchor_signature") != signature: return None - return entry + return _move_tensor_tree(entry, self._compute_device()) def _set_cached_anchor_weights( self, @@ -800,7 +1014,7 @@ def _set_cached_anchor_weights( """Store computed anchor weights for reuse within the same request.""" state = self.resolve_request_state(request_uid) bucket = state.weight_dict.setdefault(ph_id, {}) - bucket[message] = entry + bucket[message] = _move_tensor_tree(entry, self._anchor_storage_device()) @staticmethod def _select_anchor_indices(anchor_list: List[Dict[str, Any]], placeholder_len: int) -> List[int]: @@ -884,6 +1098,7 @@ def offset_kv_cache_pair( real_key_embedding, real_value_embedding = self._stack_cache_tensors(base_placeholder_cache) anchor_signature = self.anchor_signature(anchor_list) + anchor_list_on_device = self._materialize_anchor_entries(anchor_list) cache_entry = self._get_cached_anchor_weights( request_uid, @@ -893,7 +1108,7 @@ def offset_kv_cache_pair( ) if cache_entry is None: - anchor_index = self._select_anchor_indices(anchor_list, placeholder_len) + anchor_index = self._select_anchor_indices(anchor_list_on_device, placeholder_len) if not anchor_index: if anchor_list: self._log_warning( @@ -902,7 +1117,7 @@ def offset_kv_cache_pair( return base_placeholder_cache.copy(), base_prefix_cache.copy() cache_entry = self._compute_anchor_weight_entry( - anchor_list, + anchor_list_on_device, anchor_index, real_key_embedding, real_value_embedding, @@ -925,14 +1140,14 @@ def offset_kv_cache_pair( weights_value_for_placeholder = cache_entry["weights_value_for_placeholder"] prefix_key_delta_stack = torch.stack( - [anchor_list[i][f"{self.llm.node_id}_pf_key_delta"] for i in anchor_index] + [anchor_list_on_device[i][f"{self.llm.node_id}_pf_key_delta"] for i in anchor_index] ) layer_total_delta_key_for_prefix = ( weights_key_for_prefix * prefix_key_delta_stack ).sum(0) prefix_value_delta_stack = torch.stack( - [anchor_list[i][f"{self.llm.node_id}_pf_value_delta"] for i in anchor_index] + [anchor_list_on_device[i][f"{self.llm.node_id}_pf_value_delta"] for i in anchor_index] ) layer_total_value_delta_for_prefix = ( weights_value_for_prefix * prefix_value_delta_stack @@ -940,7 +1155,7 @@ def offset_kv_cache_pair( placeholder_key_delta_stack = torch.stack( [ - anchor_list[i][f"{self.llm.node_id}_ph_key_delta"][..., :placeholder_len, :] + anchor_list_on_device[i][f"{self.llm.node_id}_ph_key_delta"][..., :placeholder_len, :] for i in anchor_index ] ) @@ -950,7 +1165,7 @@ def offset_kv_cache_pair( placeholder_value_delta_stack = torch.stack( [ - anchor_list[i][f"{self.llm.node_id}_ph_value_delta"][..., :placeholder_len, :] + anchor_list_on_device[i][f"{self.llm.node_id}_ph_value_delta"][..., :placeholder_len, :] for i in anchor_index ] ) @@ -995,11 +1210,86 @@ def predict_as_anchor( anchor_len_list: List[Tuple[int, int]], anchor_activated_list: List[int], top_p: float = 0.9, + top_k: Optional[int] = None, entropy_eps: float = 1e-40, test_time: bool = False, + request_uid: str = "", + ph_id: str = "", + message: str = "", + anchor_labels: Optional[List[str]] = None, + log_events: bool = False, ) -> Tuple[bool, List[int]]: + step = self._next_anchor_event_step() if log_events else -1 + label_list = anchor_labels if anchor_labels is not None else [""] * len(anchor_kv_cache_list) + + def _emit_events( + *, + available: List[int], + selected: List[int], + sim_values: Optional[torch.Tensor], + skip_reason: str = "none", + placeholder_len: int = 0, + ) -> None: + if not log_events: + return + selected_set = set(selected) + rows: List[Dict[str, Any]] = [] + if not available: + rows.append( + { + "step": step, + "request_uid": request_uid, + "ph_id": ph_id, + "message": message, + "anchor_msg": "__none__", + "is_candidate": 0, + "is_selected": 0, + "sim_score": "", + "weight": "", + "skip_reason": skip_reason, + "placeholder_len": placeholder_len, + "available_anchor_num": 0, + "selected_anchor_num": 0, + } + ) + self._write_anchor_event_rows(rows) + return + + for idx, anchor_idx in enumerate(available): + sim_val = ( + float(sim_values[idx].detach().cpu().item()) + if sim_values is not None and idx < sim_values.shape[0] + else "" + ) + row_skip_reason = skip_reason if skip_reason != "none" else "none" + rows.append( + { + "step": step, + "request_uid": request_uid, + "ph_id": ph_id, + "message": message, + "anchor_msg": label_list[anchor_idx] if anchor_idx < len(label_list) else "", + "is_candidate": 1, + "is_selected": 1 if anchor_idx in selected_set else 0, + "sim_score": sim_val, + "weight": sim_val if anchor_idx in selected_set else "", + "skip_reason": row_skip_reason, + "placeholder_len": placeholder_len, + "available_anchor_num": len(available), + "selected_anchor_num": len(selected), + } + ) + self._write_anchor_event_rows(rows) + if len(anchor_kv_cache_list) in [0, 1]: + _emit_events( + available=[], + selected=[], + sim_values=None, + skip_reason="insufficient_anchors", + ) return True, anchor_activated_list + anchor_kv_cache_list = self._materialize_anchor_entries(anchor_kv_cache_list) if test_time: torch.cuda.synchronize() @@ -1012,6 +1302,13 @@ def predict_as_anchor( "The length of anchor_len_list is not equal to the length of anchor_available, " f"with {len(anchor_len_list)} and {len(anchor_available)}." ) + _emit_events( + available=[], + selected=[], + sim_values=None, + skip_reason="len_mismatch", + placeholder_len=k, + ) return True, anchor_activated_list if len(anchor_available) > 1: @@ -1028,6 +1325,13 @@ def predict_as_anchor( f"Entropy {entropy:.4f} exceeds threshold {threshold * torch.log2(torch.tensor(sim.shape[0])):.4f}, " "skip activating anchors." ) + _emit_events( + available=anchor_available, + selected=[], + sim_values=sim, + skip_reason="entropy_skip", + placeholder_len=k, + ) if test_time: torch.cuda.synchronize() end_time = perf_counter() @@ -1036,10 +1340,19 @@ def predict_as_anchor( ) return True, anchor_activated_list sorted_sim, sorted_indices = torch.sort(sim, descending=True) - cumulative_sum = torch.cumsum(sorted_sim, dim=0) - cutoff_index_candidates = (cumulative_sum < top_p).nonzero(as_tuple=True)[0] - cutoff_index = cutoff_index_candidates[-1] if len(cutoff_index_candidates) > 0 else len(sorted_sim) - 1 - selected_indices = sorted_indices[:cutoff_index + 1] + effective_top_k = self.llm.config.top_k if top_k is None else top_k + if effective_top_k is not None: + selected_indices = sorted_indices[: min(effective_top_k, len(sorted_indices))].tolist() + else: + cumulative_sum = torch.cumsum(sorted_sim, dim=0) + cutoff_index_candidates = (cumulative_sum < top_p).nonzero(as_tuple=True)[0] + cutoff_index = cutoff_index_candidates[-1] if len(cutoff_index_candidates) > 0 else len(sorted_sim) - 1 + selected_indices = sorted_indices[:cutoff_index + 1].tolist() + selected_anchor_indices = [ + anchor_available[i] + for i in selected_indices + if i < len(anchor_available) + ] for i in selected_indices: if anchor_available[i] >= len(anchor_activated_list): self._log_warning( @@ -1049,6 +1362,13 @@ def predict_as_anchor( ) continue anchor_activated_list[anchor_available[i]] += 1 + _emit_events( + available=anchor_available, + selected=selected_anchor_indices, + sim_values=sim, + skip_reason="none", + placeholder_len=k, + ) if test_time: torch.cuda.synchronize() end_time = perf_counter() @@ -1057,6 +1377,13 @@ def predict_as_anchor( ) return False, anchor_activated_list logger.opt(colors=True).debug("No available anchors to activate.") + _emit_events( + available=anchor_available, + selected=[], + sim_values=None, + skip_reason="no_available", + placeholder_len=k, + ) return True, anchor_activated_list def update_anchor(self, request_uid: str, ph_id: str, window_length: int = 5) -> None: @@ -1066,15 +1393,35 @@ def update_anchor(self, request_uid: str, ph_id: str, window_length: int = 5) -> state = self.resolve_request_state(request_uid) anchor_store = state.anchors.setdefault(ph_id, {}) anchor_info_dict = state.anchor_info_dict.setdefault(ph_id, {}) - info_list = list(anchor_info_dict.values())[:window_length] - if not info_list: + info_items = list(anchor_info_dict.items())[:window_length] + removable_items = [ + item + for item in info_items + if not self._is_resident_anchor(ph_id, item[0]) + ] + if not removable_items: return - min_idx = info_list.index(min(info_list)) - message = list(anchor_info_dict.keys())[min_idx] + message, _ = min(removable_items, key=lambda item: item[1]) + anchor_count_before = len(anchor_store) anchor_store.pop(message, None) state.anchor_len_dict.setdefault(ph_id, {}).pop(message, None) freq = anchor_info_dict.pop(message, None) state.global_anchor_info.setdefault(ph_id, {}).pop(message, None) + self._write_anchor_lifecycle_row( + { + "lifecycle_step": self._next_anchor_lifecycle_step(), + "event": "remove", + "request_uid": request_uid, + "ph_id": ph_id, + "message": message, + "node_id": self.llm.node_id, + "role": self.llm.role, + "frequency": freq, + "anchor_count_before": anchor_count_before, + "anchor_count_after": len(anchor_store), + "reason": f"low_frequency_in_oldest_window_{window_length}", + } + ) self._log_warning( f"Removed anchor for message '{message}' in {self.llm.node_id} ({self.llm.role}) due to low frequency: {freq}" ) @@ -1143,12 +1490,13 @@ def _make_anchor(i, ph_id, real_ph, base_ph, real_pf, base_pf): if i not in anchor_dict: accumulate_len += placeholder_len continue - entry = anchor_dict[i] + entry = self._store_anchor_entry(ph_id_list[i], message, anchor_dict[i]) over_store = len(anchor_store.setdefault(ph_id_list[i], {})) > max_anchor_num if over_store: self.update_anchor(request_uid, ph_id_list[i], window_length) if message not in anchor_store[ph_id_list[i]]: + anchor_count_before = len(anchor_store[ph_id_list[i]]) anchor_store[ph_id_list[i]][message] = entry info_bucket = state.anchor_info_dict.setdefault(ph_id_list[i], {}) info_bucket[message] = 0 @@ -1163,6 +1511,23 @@ def _make_anchor(i, ph_id, real_ph, base_ph, real_pf, base_pf): message, [0, placeholder_len], ) + self._write_anchor_lifecycle_row( + { + "lifecycle_step": self._next_anchor_lifecycle_step(), + "event": "add", + "request_uid": request_uid, + "ph_id": ph_id_list[i], + "message": message, + "node_id": self.llm.node_id, + "role": self.llm.role, + "frequency": 0, + "placeholder_len": placeholder_len, + "accumulate_len": accumulate_len, + "anchor_count_before": anchor_count_before, + "anchor_count_after": len(anchor_store[ph_id_list[i]]), + "reason": "new_anchor", + } + ) else: anchor_store[ph_id_list[i]][message].update(entry) accumulate_len += placeholder_len @@ -1222,8 +1587,8 @@ def fetch_shared_cache( if "user_question" in ph_id: return ( - shared_memory["input"][message][-1], - shared_memory["input_ids"][message][-1], + _move_tensor_tree(shared_memory["input"][message][-1], self._compute_device()), + _move_tensor_tree(shared_memory["input_ids"][message][-1], self._compute_device()), shared_memory["input_drop_num"][message][-1], ) @@ -1254,7 +1619,11 @@ def _get_slot(bucket_key: str): f"fetch_shared_cache: placeholder {ph_id} for message='{message}' not found." ) - return ph_cache, ph_cache_ids, drop_num + return ( + _move_tensor_tree(ph_cache, self._compute_device()), + _move_tensor_tree(ph_cache_ids, self._compute_device()), + drop_num, + ) @staticmethod def trim_token_ids(ids_dict: Dict[str, torch.Tensor], drop_num: int) -> Dict[str, torch.Tensor]: diff --git a/KVCOMM/utils/gpu_debug.py b/KVCOMM/utils/gpu_debug.py new file mode 100644 index 0000000..4b929c3 --- /dev/null +++ b/KVCOMM/utils/gpu_debug.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import csv +import gc +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +import torch +from transformers.cache_utils import DynamicCache + +DEFAULT_TARGET_DEVICE = 0 + + +def _tensor_nbytes(tensor: torch.Tensor) -> int: + return tensor.numel() * tensor.element_size() + + +def _is_scoped_dict_like(obj: Any) -> bool: + return hasattr(obj, "_base") and hasattr(obj, "_local") + + +def _walk_cuda_tensors( + obj: Any, + *, + path: str, + target_device: Optional[int], + seen_objects: Set[int], + seen_tensors: Set[int], + out: List[Dict[str, Any]], +) -> None: + obj_id = id(obj) + if obj_id in seen_objects: + return + seen_objects.add(obj_id) + + if isinstance(obj, torch.Tensor): + if obj.is_cuda and (target_device is None or obj.device.index == target_device): + ptr = obj.data_ptr() + if ptr not in seen_tensors: + seen_tensors.add(ptr) + out.append( + { + "path": path, + "shape": tuple(obj.shape), + "dtype": str(obj.dtype), + "device": str(obj.device), + "mb": _tensor_nbytes(obj) / 1024**2, + } + ) + return + + if isinstance(obj, DynamicCache): + if hasattr(obj, "layers"): + for idx, layer in enumerate(getattr(obj, "layers", [])): + _walk_cuda_tensors( + getattr(layer, "keys", None), + path=f"{path}.layers[{idx}].keys", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + _walk_cuda_tensors( + getattr(layer, "values", None), + path=f"{path}.layers[{idx}].values", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + else: + for idx, key in enumerate(getattr(obj, "key_cache", [])): + _walk_cuda_tensors( + key, + path=f"{path}.key_cache[{idx}]", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + for idx, value in enumerate(getattr(obj, "value_cache", [])): + _walk_cuda_tensors( + value, + path=f"{path}.value_cache[{idx}]", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + return + + if isinstance(obj, dict): + for key, value in obj.items(): + _walk_cuda_tensors( + value, + path=f"{path}[{repr(key)}]", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + return + + if isinstance(obj, (list, tuple, set)): + for idx, value in enumerate(obj): + _walk_cuda_tensors( + value, + path=f"{path}[{idx}]", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + return + + if _is_scoped_dict_like(obj): + _walk_cuda_tensors( + getattr(obj, "_base", None), + path=f"{path}._base", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + _walk_cuda_tensors( + getattr(obj, "_local", None), + path=f"{path}._local", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + return + + if hasattr(obj, "__dict__"): + for key, value in vars(obj).items(): + _walk_cuda_tensors( + value, + path=f"{path}.{key}", + target_device=target_device, + seen_objects=seen_objects, + seen_tensors=seen_tensors, + out=out, + ) + + +def collect_cuda_tensor_info( + name: str, + obj: Any, + *, + target_device: Optional[int] = DEFAULT_TARGET_DEVICE, +) -> Dict[str, Any]: + entries: List[Dict[str, Any]] = [] + _walk_cuda_tensors( + obj, + path=name, + target_device=target_device, + seen_objects=set(), + seen_tensors=set(), + out=entries, + ) + entries.sort(key=lambda item: item["mb"], reverse=True) + total_mb = sum(item["mb"] for item in entries) + return { + "name": name, + "count": len(entries), + "total_mb": total_mb, + "entries": entries, + } + + +def current_cuda_memory(device: Optional[int] = DEFAULT_TARGET_DEVICE) -> Dict[str, float]: + if not torch.cuda.is_available(): + return {"allocated_mb": 0.0, "reserved_mb": 0.0, "max_allocated_mb": 0.0} + device_arg = torch.device(f"cuda:{device}") if device is not None else None + return { + "allocated_mb": torch.cuda.memory_allocated(device_arg) / 1024**2, + "reserved_mb": torch.cuda.memory_reserved(device_arg) / 1024**2, + "max_allocated_mb": torch.cuda.max_memory_allocated(device_arg) / 1024**2, + } + + +def reset_cuda_peak_memory(device: Optional[int] = DEFAULT_TARGET_DEVICE) -> None: + """Reset CUDA peak stats so later max_allocated reflects a local interval.""" + if not torch.cuda.is_available(): + return + device_arg = torch.device(f"cuda:{device}") if device is not None else None + torch.cuda.reset_peak_memory_stats(device_arg) + + +def summarize_kvcomm_cuda_state( + *, + topk: int = 20, + include_gc_tensors: bool = False, + target_device: Optional[int] = DEFAULT_TARGET_DEVICE, +) -> Dict[str, Any]: + from KVCOMM.llm.gpt_chat import LLMChat + from KVCOMM.llm.kvcomm_engine import KVCOMMEngine + + sections = [ + ("anchors", KVCOMMEngine.anchors), + ("weight_dict", KVCOMMEngine.weight_dict), + ("request_states", KVCOMMEngine._request_states), + ("staged_commits", KVCOMMEngine._staged_commits), + ("shared_kv_cache_memory", LLMChat._shared_kv_cache_memory), + ] + + report = { + "memory": current_cuda_memory(target_device), + "target_device": target_device, + "sections": [], + } + for name, obj in sections: + section = collect_cuda_tensor_info(name, obj, target_device=target_device) + section["entries"] = section["entries"][:topk] + report["sections"].append(section) + + if include_gc_tensors: + gc_entries: List[Dict[str, Any]] = [] + seen_ptrs: Set[int] = set() + for obj in gc.get_objects(): + try: + if ( + isinstance(obj, torch.Tensor) + and obj.is_cuda + and (target_device is None or obj.device.index == target_device) + ): + ptr = obj.data_ptr() + if ptr in seen_ptrs: + continue + seen_ptrs.add(ptr) + gc_entries.append( + { + "shape": tuple(obj.shape), + "dtype": str(obj.dtype), + "device": str(obj.device), + "mb": _tensor_nbytes(obj) / 1024**2, + } + ) + except Exception: + continue + gc_entries.sort(key=lambda item: item["mb"], reverse=True) + report["gc_cuda_tensors"] = gc_entries[:topk] + + return report + + +def append_kvcomm_cuda_state_csv( + output_dir: Optional[str | Path], + *, + tag: str, + batch_index: Optional[int] = None, + phase: Optional[str] = None, + topk: int = 0, + include_gc_tensors: bool = False, + target_device: Optional[int] = DEFAULT_TARGET_DEVICE, + filename: str = "CudaMemory.csv", +) -> None: + """Append a structured CUDA memory snapshot for experiment comparisons.""" + if output_dir is None: + return + if torch.cuda.is_available(): + torch.cuda.synchronize() + + report = summarize_kvcomm_cuda_state( + topk=topk, + include_gc_tensors=include_gc_tensors, + target_device=target_device, + ) + memory = report["memory"] + row: Dict[str, Any] = { + "timestamp": time.time(), + "tag": tag, + "batch_index": batch_index if batch_index is not None else "", + "phase": phase or "", + "target_device": "" if target_device is None else target_device, + "allocated_mb": memory["allocated_mb"], + "reserved_mb": memory["reserved_mb"], + "max_allocated_mb": memory["max_allocated_mb"], + } + for section in report["sections"]: + name = section["name"] + row[f"{name}_cuda_tensor_count"] = section["count"] + row[f"{name}_cuda_tensor_mb"] = section["total_mb"] + + csv_path = Path(output_dir).expanduser() / filename + csv_path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "timestamp", + "tag", + "batch_index", + "phase", + "target_device", + "allocated_mb", + "reserved_mb", + "max_allocated_mb", + "anchors_cuda_tensor_count", + "anchors_cuda_tensor_mb", + "weight_dict_cuda_tensor_count", + "weight_dict_cuda_tensor_mb", + "request_states_cuda_tensor_count", + "request_states_cuda_tensor_mb", + "staged_commits_cuda_tensor_count", + "staged_commits_cuda_tensor_mb", + "shared_kv_cache_memory_cuda_tensor_count", + "shared_kv_cache_memory_cuda_tensor_mb", + ] + file_exists = csv_path.exists() + if file_exists: + with csv_path.open("r", encoding="utf-8", newline="") as handle: + reader = csv.DictReader(handle) + existing_fieldnames = reader.fieldnames or [] + existing_rows = list(reader) + if existing_fieldnames and existing_fieldnames != fieldnames: + with csv_path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for existing_row in existing_rows: + writer.writerow({key: existing_row.get(key, "") for key in fieldnames}) + with csv_path.open("a", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + writer.writerow({key: row.get(key, "") for key in fieldnames}) + + +def print_kvcomm_cuda_state( + *, + tag: Optional[str] = None, + topk: int = 20, + include_gc_tensors: bool = False, + target_device: Optional[int] = DEFAULT_TARGET_DEVICE, +) -> None: + report = summarize_kvcomm_cuda_state( + topk=topk, + include_gc_tensors=include_gc_tensors, + target_device=target_device, + ) + prefix = f"[{tag}] " if tag else "" + memory = report["memory"] + device_label = ( + f"cuda:{report['target_device']}" + if report["target_device"] is not None + else "all_cuda_devices" + ) + print( + f"{prefix}CUDA memory ({device_label}): " + f"allocated={memory['allocated_mb']:.1f}MB " + f"reserved={memory['reserved_mb']:.1f}MB " + f"max_allocated={memory['max_allocated_mb']:.1f}MB" + ) + for section in report["sections"]: + print( + f"{prefix}{section['name']}: " + f"cuda_tensors={section['count']} total={section['total_mb']:.1f}MB" + ) + for item in section["entries"]: + print( + f"{prefix} - {item['path']} " + f"shape={item['shape']} dtype={item['dtype']} " + f"device={item['device']} size={item['mb']:.1f}MB" + ) + if include_gc_tensors: + gc_entries = report.get("gc_cuda_tensors", []) + print(f"{prefix}gc_cuda_tensors: {len(gc_entries)} shown") + for item in gc_entries: + print( + f"{prefix} - shape={item['shape']} dtype={item['dtype']} " + f"device={item['device']} size={item['mb']:.1f}MB" + ) diff --git a/datasets/MMLU/download.py b/datasets/MMLU/download.py index c043db3..1f4cd7a 100644 --- a/datasets/MMLU/download.py +++ b/datasets/MMLU/download.py @@ -1,26 +1,166 @@ import os +import shutil +from typing import Iterator + import requests import tarfile +from requests.adapters import HTTPAdapter +from tqdm import tqdm +from urllib3.util.retry import Retry + from KVCOMM.utils.log import logger +# (connect timeout, read timeout between chunks) — large tar needs a generous read timeout +_REQUEST_TIMEOUT = (30, 600) + +# Hugging Face hosts the same Hendrycks archive; CDN is often much faster than Berkeley direct. +_HF_MMLU_TAR = ( + "https://huggingface.co/datasets/Stevross/mmlu/resolve/main/data.tar" +) +_BERKELEY_MMLU_TAR = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" + +_HF_HUB_REPO = "Stevross/mmlu" +_HF_HUB_FILENAME = "data.tar" + + +def _tar_url_candidates() -> Iterator[str]: + """URLs to try, in order. Set MMLU_DATA_TAR_URL to force a single mirror.""" + custom = os.environ.get("MMLU_DATA_TAR_URL", "").strip() + if custom: + yield custom + return + yield _HF_MMLU_TAR + yield _BERKELEY_MMLU_TAR + + +def _request_headers() -> dict: + token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + if token: + return {"Authorization": f"Bearer {token}"} + return {} + + +def _requests_session() -> requests.Session: + """Session with retries for flaky TLS / CDN resets.""" + s = requests.Session() + retries = Retry( + total=8, + connect=8, + read=8, + backoff_factor=1.2, + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=frozenset(["GET", "HEAD"]), + ) + adapter = HTTPAdapter(max_retries=retries) + s.mount("https://", adapter) + s.mount("http://", adapter) + return s + + +def _download_via_hf_hub(tar_path: str) -> bool: + """Use huggingface_hub (respects HF_ENDPOINT, token, hub cache; more robust than raw GET).""" + try: + from huggingface_hub import hf_hub_download + except ImportError: + return False + try: + cached = hf_hub_download( + repo_id=_HF_HUB_REPO, + filename=_HF_HUB_FILENAME, + repo_type="dataset", + ) + if os.path.abspath(cached) == os.path.abspath(tar_path): + return True + shutil.copy2(cached, tar_path) + return True + except Exception as exc: + logger.warning("huggingface_hub download failed: {}", exc) + return False + + +def _stream_download(url: str, tar_path: str) -> None: + expected_size = None + session = _requests_session() + with session.get( + url, + allow_redirects=True, + stream=True, + timeout=_REQUEST_TIMEOUT, + headers=_request_headers(), + ) as r: + r.raise_for_status() + cl = r.headers.get("Content-Length") + expected_size = int(cl) if cl is not None else None + chunk_size = 1024 * 1024 + with open(tar_path, "wb") as f, tqdm( + desc="data.tar", + total=expected_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + ) as pbar: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(len(data)) + if expected_size is not None: + got = os.path.getsize(tar_path) + if got != expected_size: + raise OSError( + f"Incomplete download: expected {expected_size} bytes, got {got}" + ) + def download(): this_file_path = os.path.split(__file__)[0] tar_path = os.path.join(this_file_path, "data.tar") if not os.path.exists(tar_path): - url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" - logger.info("Downloading {}", url) - r = requests.get(url, allow_redirects=True) - with open(tar_path, 'wb') as f: - f.write(r.content) - logger.info("Saved to {}", tar_path) + custom_url = os.environ.get("MMLU_DATA_TAR_URL", "").strip() + ok = False + if not custom_url: + logger.info( + "Trying huggingface_hub (repo={}, file={}; set HF_ENDPOINT for mirrors)", + _HF_HUB_REPO, + _HF_HUB_FILENAME, + ) + ok = _download_via_hf_hub(tar_path) + if ok: + logger.info("Saved to {}", tar_path) + else: + last_error: Exception | None = None + for url in _tar_url_candidates(): + logger.info("Downloading {}", url) + try: + _stream_download(url, tar_path) + last_error = None + break + except Exception as exc: + last_error = exc + if os.path.exists(tar_path): + os.unlink(tar_path) + logger.warning("Download from {} failed: {}", url, exc) + if last_error is not None: + raise last_error + logger.info("Saved to {}", tar_path) data_path = os.path.join(this_file_path, "data") if not os.path.exists(data_path): - tar = tarfile.open(tar_path) - tar.extractall(this_file_path) - tar.close() + try: + with tarfile.open(tar_path) as tar: + tar.extractall(this_file_path) + except (tarfile.ReadError, tarfile.TarError, EOFError) as exc: + if os.path.isdir(data_path): + shutil.rmtree(data_path, ignore_errors=True) + if os.path.exists(tar_path): + os.unlink(tar_path) + logger.error( + "data.tar is corrupt or truncated ({}). Removed archive and partial " + "'data/'; run again to re-download.", + exc, + ) + raise logger.info("Saved to {}", data_path) diff --git a/experiments/analyze_kvreuse_anchor.py b/experiments/analyze_kvreuse_anchor.py new file mode 100644 index 0000000..81e2b2f --- /dev/null +++ b/experiments/analyze_kvreuse_anchor.py @@ -0,0 +1,391 @@ +import argparse +import csv +from collections import Counter, defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt + + +def _to_int(value: str, default: int = 0) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _to_float(value: str, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def load_events(csv_path: Path) -> List[Dict[str, str]]: + with csv_path.open("r", encoding="utf-8", newline="") as f: + reader = csv.DictReader(f) + rows = list(reader) + if not rows: + raise ValueError(f"Input CSV is empty: {csv_path}") + return rows + + +def build_summary( + rows: List[Dict[str, str]], +) -> Tuple[List[Dict[str, str]], Counter, int]: + grouped: Dict[Tuple[str, str], Dict[str, float]] = {} + skip_reason_counter: Counter = Counter() + max_step = 0 + + for row in rows: + ph_id = row.get("ph_id", "") + anchor_msg = row.get("anchor_msg", "") + key = (ph_id, anchor_msg) + if key not in grouped: + grouped[key] = { + "candidate_count": 0.0, + "selected_count": 0.0, + "last_selected_step": -1.0, + "weight_sum": 0.0, + "weight_n": 0.0, + } + + bucket = grouped[key] + is_candidate = _to_int(row.get("is_candidate", "0")) + is_selected = _to_int(row.get("is_selected", "0")) + step = _to_int(row.get("step", "0")) + weight = _to_float(row.get("weight", "")) + skip_reason = row.get("skip_reason", "none") or "none" + + bucket["candidate_count"] += is_candidate + bucket["selected_count"] += is_selected + if is_selected: + bucket["last_selected_step"] = max(bucket["last_selected_step"], step) + bucket["weight_sum"] += weight + bucket["weight_n"] += 1 + + if skip_reason != "none": + skip_reason_counter[skip_reason] += 1 + max_step = max(max_step, step) + + summary_rows: List[Dict[str, str]] = [] + for (ph_id, anchor_msg), bucket in grouped.items(): + candidate_count = int(bucket["candidate_count"]) + selected_count = int(bucket["selected_count"]) + last_selected_step = int(bucket["last_selected_step"]) + selected_rate = (selected_count / candidate_count) if candidate_count > 0 else 0.0 + idle_steps = (max_step - last_selected_step) if last_selected_step >= 0 else max_step + avg_weight = (bucket["weight_sum"] / bucket["weight_n"]) if bucket["weight_n"] > 0 else 0.0 + summary_rows.append( + { + "ph_id": ph_id, + "anchor_msg": anchor_msg, + "candidate_count": str(candidate_count), + "selected_count": str(selected_count), + "selected_rate": f"{selected_rate:.6f}", + "last_selected_step": str(last_selected_step), + "idle_steps": str(int(idle_steps)), + "avg_weight": f"{avg_weight:.6f}", + } + ) + + summary_rows.sort(key=lambda x: int(x["selected_count"]), reverse=True) + return summary_rows, skip_reason_counter, max_step + + +def save_summary_csv(rows: List[Dict[str, str]], output_path: Path) -> None: + fields = [ + "ph_id", + "anchor_msg", + "candidate_count", + "selected_count", + "selected_rate", + "last_selected_step", + "idle_steps", + "avg_weight", + ] + with output_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fields) + writer.writeheader() + writer.writerows(rows) + + +def plot_topn_bar(summary_rows: List[Dict[str, str]], output_path: Path, top_n: int) -> None: + top = summary_rows[:top_n] + labels = [f"{r['ph_id']}#{i}" for i, r in enumerate(top)] + values = [int(r["selected_count"]) for r in top] + + plt.figure(figsize=(12, 5)) + plt.bar(range(len(values)), values) + plt.xticks(range(len(values)), labels, rotation=45, ha="right") + plt.ylabel("selected_count") + plt.title(f"KVReuse Anchor Hotness Top-{top_n}") + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + + +def plot_lorenz(summary_rows: List[Dict[str, str]], output_path: Path) -> None: + counts = [int(r["selected_count"]) for r in summary_rows] + if not counts: + counts = [0] + counts = sorted(counts) + total = sum(counts) + + x, y = [0.0], [0.0] + running = 0 + n = len(counts) + for i, c in enumerate(counts, start=1): + running += c + x.append(i / n) + y.append((running / total) if total > 0 else 0.0) + + plt.figure(figsize=(6, 6)) + plt.plot(x, y, label="Observed") + plt.plot([0, 1], [0, 1], "--", label="Uniform") + plt.xlabel("Cumulative share of anchors") + plt.ylabel("Cumulative share of selections") + plt.title("KVReuse Anchor Concentration (Lorenz Curve)") + plt.legend() + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + + +def plot_rolling_concentration( + rows: List[Dict[str, str]], + output_path: Path, + rolling_window: int, +) -> None: + selected_rows = [r for r in rows if _to_int(r.get("is_selected", "0")) == 1] + if not selected_rows: + plt.figure(figsize=(12, 5)) + plt.title("No selected anchors found") + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + return + + selected_rows.sort(key=lambda r: _to_int(r.get("step", "0"))) + step_anchor_counts: Dict[int, Counter] = defaultdict(Counter) + for r in selected_rows: + step = _to_int(r.get("step", "0")) + anchor_key = f"{r.get('ph_id', '')}||{r.get('anchor_msg', '')}" + step_anchor_counts[step][anchor_key] += 1 + + steps = sorted(step_anchor_counts.keys()) + top1_series, top5_series = [], [] + for idx, step in enumerate(steps): + left = max(0, idx - rolling_window + 1) + window_steps = steps[left : idx + 1] + merged = Counter() + for s in window_steps: + merged.update(step_anchor_counts[s]) + total = sum(merged.values()) + if total == 0: + top1_series.append(0.0) + top5_series.append(0.0) + continue + freqs = sorted(merged.values(), reverse=True) + top1_series.append(freqs[0] / total) + top5_series.append(sum(freqs[:5]) / total) + + plt.figure(figsize=(12, 5)) + plt.plot(steps, top1_series, label="Top1 share") + plt.plot(steps, top5_series, label="Top5 share") + plt.ylim(0.0, 1.0) + plt.xlabel("step") + plt.ylabel("share in rolling window") + plt.title(f"KVReuse Rolling Anchor Concentration (window={rolling_window})") + plt.legend() + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + + +def plot_global_hot_anchor_trends( + rows: List[Dict[str, str]], + summary_rows: List[Dict[str, str]], + output_path: Path, + top_k: int, +) -> None: + top_anchors = [ + f"{r.get('ph_id', '')}||{r.get('anchor_msg', '')}" + for r in summary_rows + if int(r["selected_count"]) > 0 + ][:top_k] + + if not top_anchors: + plt.figure(figsize=(12, 5)) + plt.title("No selected anchors found") + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + return + + selected_rows = [r for r in rows if _to_int(r.get("is_selected", "0")) == 1] + selected_rows.sort(key=lambda r: _to_int(r.get("step", "0"))) + + steps = sorted({_to_int(r.get("step", "0")) for r in selected_rows}) + per_step_counts: Dict[str, Counter] = {anchor: Counter() for anchor in top_anchors} + for r in selected_rows: + anchor_key = f"{r.get('ph_id', '')}||{r.get('anchor_msg', '')}" + if anchor_key in per_step_counts: + per_step_counts[anchor_key][_to_int(r.get("step", "0"))] += 1 + + plt.figure(figsize=(12, 5)) + for rank, anchor_key in enumerate(top_anchors, start=1): + running = 0 + cumulative = [] + for step in steps: + running += per_step_counts[anchor_key][step] + cumulative.append(running) + ph_id, _ = anchor_key.split("||", 1) + plt.plot(steps, cumulative, label=f"Top{rank}: {ph_id}") + + plt.xlabel("step") + plt.ylabel("cumulative selected_count") + plt.title(f"KVReuse Global Hot Anchor Trends (top={len(top_anchors)})") + plt.legend() + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + + +def plot_global_hot_anchor_rolling_share( + rows: List[Dict[str, str]], + summary_rows: List[Dict[str, str]], + output_path: Path, + top_k: int, + rolling_window: int, +) -> None: + top_anchors = [ + f"{r.get('ph_id', '')}||{r.get('anchor_msg', '')}" + for r in summary_rows + if int(r["selected_count"]) > 0 + ][:top_k] + + selected_rows = [r for r in rows if _to_int(r.get("is_selected", "0")) == 1] + if not selected_rows or not top_anchors: + plt.figure(figsize=(12, 5)) + plt.title("No selected anchors found") + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + return + + selected_rows.sort(key=lambda r: _to_int(r.get("step", "0"))) + step_anchor_counts: Dict[int, Counter] = defaultdict(Counter) + for r in selected_rows: + step = _to_int(r.get("step", "0")) + anchor_key = f"{r.get('ph_id', '')}||{r.get('anchor_msg', '')}" + step_anchor_counts[step][anchor_key] += 1 + + steps = sorted(step_anchor_counts.keys()) + topk_total_series: List[float] = [] + + for idx, _step in enumerate(steps): + left = max(0, idx - rolling_window + 1) + window_steps = steps[left : idx + 1] + merged = Counter() + for s in window_steps: + merged.update(step_anchor_counts[s]) + + total = sum(merged.values()) + if total == 0: + topk_total_series.append(0.0) + continue + + topk_total = sum(merged[anchor] for anchor in top_anchors) + topk_total_series.append(topk_total / total) + + plt.figure(figsize=(12, 5)) + plt.plot(steps, topk_total_series, linewidth=2.0, label=f"Top{len(top_anchors)} total") + + plt.ylim(0.0, 1.0) + plt.xlabel("step") + plt.ylabel("share in rolling window") + plt.title( + f"KVReuse Global Hot Anchor Rolling Share " + f"(top={len(top_anchors)}, window={rolling_window})" + ) + plt.legend() + plt.tight_layout() + plt.savefig(output_path, dpi=180) + plt.close() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Analyze KVReuse anchor hot/cold behavior.") + parser.add_argument( + "--events-csv", + type=str, + required=True, + help="Path to kvreuse_anchor_events.csv", + ) + parser.add_argument( + "--output-dir", + type=str, + default="runs/anchor_analysis", + help="Directory to save CSV and figures.", + ) + parser.add_argument( + "--top-n", + type=int, + default=20, + help="Top-N anchors shown in bar chart.", + ) + parser.add_argument( + "--rolling-window", + type=int, + default=200, + help="Rolling window size (in step bins) for concentration chart.", + ) + parser.add_argument( + "--trend-top-k", + type=int, + default=5, + help="Global Top-K selected anchors shown in cumulative trend chart.", + ) + args = parser.parse_args() + + events_csv = Path(args.events_csv).expanduser().resolve() + output_dir = Path(args.output_dir).expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + rows = load_events(events_csv) + summary_rows, skip_reasons, max_step = build_summary(rows) + + summary_csv = output_dir / "kvreuse_anchor_summary.csv" + save_summary_csv(summary_rows, summary_csv) + + plot_topn_bar(summary_rows, output_dir / "kvreuse_topn_hotness.png", args.top_n) + plot_lorenz(summary_rows, output_dir / "kvreuse_lorenz_curve.png") + plot_rolling_concentration(rows, output_dir / "kvreuse_rolling_concentration.png", args.rolling_window) + plot_global_hot_anchor_trends( + rows, + summary_rows, + output_dir / "kvreuse_global_hot_anchor_trends.png", + args.trend_top_k, + ) + plot_global_hot_anchor_rolling_share( + rows, + summary_rows, + output_dir / "kvreuse_global_hot_anchor_rolling_share.png", + args.trend_top_k, + args.rolling_window, + ) + + print(f"[done] summary csv: {summary_csv}") + print(f"[done] figures: {output_dir / 'kvreuse_topn_hotness.png'}") + print(f"[done] figures: {output_dir / 'kvreuse_lorenz_curve.png'}") + print(f"[done] figures: {output_dir / 'kvreuse_rolling_concentration.png'}") + print(f"[done] figures: {output_dir / 'kvreuse_global_hot_anchor_trends.png'}") + print(f"[done] figures: {output_dir / 'kvreuse_global_hot_anchor_rolling_share.png'}") + print(f"[info] max_step={max_step}, anchors={len(summary_rows)}") + if skip_reasons: + print(f"[info] skip_reasons={dict(skip_reasons)}") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_gsm8k.py b/experiments/run_gsm8k.py index 72d0a4b..7dce421 100644 --- a/experiments/run_gsm8k.py +++ b/experiments/run_gsm8k.py @@ -17,6 +17,7 @@ from KVCOMM.llm.config import KVCommConfig from KVCOMM.tools.reader.readers import JSONLReader from KVCOMM.utils.globals import Time +from KVCOMM.utils.gpu_debug import print_kvcomm_cuda_state from KVCOMM.utils.log import configure_logging, logger from KVCOMM.utils.metrics import metrics_recorder from datasets.gsm8k_dataset import ( @@ -109,7 +110,8 @@ async def main(): current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) Time.instance().value = current_time - result_file = output_dir / f"{args.domain}_{args.llm_name}_{current_time}.json" + llm_tag = args.llm_name.replace("/", "__") + result_file = output_dir / f"{args.domain}_{llm_tag}_{current_time}.json" latency_target = str(output_dir) agent_names = [name for name, num in zip(args.agent_names, args.agent_nums) for _ in range(num)] @@ -226,6 +228,7 @@ async def main(): ) logger.opt(colors=True).info(f"[ACCURACY] {accuracy:.4f}") metrics_recorder.log_cumulative(batch_index=i_batch) + print_kvcomm_cuda_state(tag=f"batch_{i_batch}_end", topk=10) def get_kwargs( diff --git a/experiments/run_humaneval.py b/experiments/run_humaneval.py index 647fc50..532555f 100644 --- a/experiments/run_humaneval.py +++ b/experiments/run_humaneval.py @@ -18,6 +18,11 @@ from KVCOMM.tools.coding.python_executor import PyExecutor from KVCOMM.tools.reader.readers import JSONLReader from KVCOMM.utils.globals import Time +from KVCOMM.utils.gpu_debug import ( + append_kvcomm_cuda_state_csv, + print_kvcomm_cuda_state, + reset_cuda_peak_memory, +) from KVCOMM.utils.log import configure_logging, logger from KVCOMM.utils.metrics import metrics_recorder @@ -84,6 +89,27 @@ def parse_args(): parser.add_argument("--kv-window-size", type=int, default=None, help="Window size for key-value memory update.") parser.add_argument("--kv-thread-workers", type=int, default=None, help="Number of thread workers for key-value memory processing.") parser.add_argument("--kv-worker-timeout", type=float, default=None, help="Timeout for key-value memory workers processing.") + parser.add_argument( + "--resident-anchor-summary", + type=str, + default=None, + help="Path to kvreuse_anchor_summary.csv; top selected anchors from this file stay resident on GPU.", + ) + parser.add_argument( + "--resident-anchor-top-n", + type=int, + default=None, + help="Number of hot anchors from --resident-anchor-summary to keep resident on GPU.", + ) + parser.add_argument( + "--test-time", + "--test_time", + nargs="?", + const=True, + default=False, + type=lambda value: str(value).lower() in {"1", "true", "yes", "y", "on"}, + help="Run kv_reuse calls through the dense-vs-KV timing comparison path.", + ) args = parser.parse_args() if len(args.agent_names) != len(args.agent_nums): @@ -100,7 +126,8 @@ async def main(): current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) Time.instance().value = current_time - result_file = output_dir / f"{args.domain}_{args.llm_name}_{current_time}.json" + llm_tag = args.llm_name.replace("/", "__") + result_file = output_dir / f"{args.domain}_{llm_tag}_{current_time}.json" latency_target = str(output_dir) agent_names = [name for name, num in zip(args.agent_names, args.agent_nums) for _ in range(num)] @@ -114,10 +141,24 @@ async def main(): window_size=args.kv_window_size, thread_pool_workers=args.kv_thread_workers, worker_timeout=args.kv_worker_timeout, + resident_anchor_summary=args.resident_anchor_summary, + resident_anchor_top_n=args.resident_anchor_top_n, ) else: kv_config = KVCommConfig.from_env() + logger.info( + "HumanEval run config: execution_mode={} threshold={} max_anchor_num={} window_size={} top_k={} test_time={} resident_anchor_summary={} resident_anchor_top_n={} anchor_device=cpu", + args.execution_mode, + kv_config.threshold, + kv_config.max_anchor_num, + kv_config.window_size, + kv_config.top_k, + args.test_time, + kv_config.resident_anchor_summary, + kv_config.resident_anchor_top_n, + ) + graph = Graph( domain=args.domain, llm_name=args.llm_name, @@ -132,6 +173,13 @@ async def main(): for i_batch in range(num_batches): logger.opt(colors=True).info(f"[BATCH] {i_batch} {'-' * 40}") + reset_cuda_peak_memory() + append_kvcomm_cuda_state_csv( + output_dir, + tag=f"batch_{i_batch}_start", + batch_index=i_batch, + phase="batch_start", + ) start_ts = time.time() current_batch = dataloader(dataset, args.batch_size, i_batch) if not current_batch: @@ -152,7 +200,8 @@ async def main(): if args.execution_mode == "allow_kv_reuse": mode_kwargs = { "prefix": args.prefix, - "output_dir": latency_target + "output_dir": latency_target, + "test_time": args.test_time, } tasks.append( @@ -210,6 +259,13 @@ async def main(): ) logger.opt(colors=True).info(f"[ACCURACY] {accuracy:.4f}") metrics_recorder.log_cumulative(batch_index=i_batch) + append_kvcomm_cuda_state_csv( + output_dir, + tag=f"batch_{i_batch}_end", + batch_index=i_batch, + phase="batch_end", + ) + print_kvcomm_cuda_state(tag=f"batch_{i_batch}_end", topk=10) def get_kwargs( mode: Union[ diff --git a/experiments/run_mmlu.py b/experiments/run_mmlu.py index 3452d4f..f1007c8 100644 --- a/experiments/run_mmlu.py +++ b/experiments/run_mmlu.py @@ -64,6 +64,7 @@ def parse_args(): parser.add_argument("--kv-threshold", type=float, default=None, help="Threshold for key-value memory usage.") parser.add_argument("--kv-max-anchor-num", type=int, default=20, help="Maximum number of anchors for key-value memory.") parser.add_argument("--kv-window-size", type=int, default=None, help="Window size for key-value memory update.") + parser.add_argument("--kv-top-k", type=int, default=None, help="Top-k anchors to activate; overrides top-p selection when set.") parser.add_argument("--kv-thread-workers", type=int, default=None, help="Number of thread workers for key-value memory processing.") parser.add_argument("--kv-worker-timeout", type=float, default=None, help="Timeout for key-value memory workers processing.") @@ -85,6 +86,7 @@ async def main(): threshold=args.kv_threshold, max_anchor_num=args.kv_max_anchor_num, window_size=args.kv_window_size, + top_k=args.kv_top_k, thread_pool_workers=args.kv_thread_workers, worker_timeout=args.kv_worker_timeout, ) @@ -119,7 +121,9 @@ async def main(): **eval_kwargs, ) logger.opt(colors=True).info("[MMLU SCORE] {:.4f}", score) - result_file = output_dir / f"{args.domain}_{args.llm_name}_{timestamp}.json" + # HF model ids contain "/"; must not appear raw in filenames (path separator). + llm_tag = args.llm_name.replace("/", "__") + result_file = output_dir / f"{args.domain}_{llm_tag}_{timestamp}.json" result_file.touch(exist_ok=True) payload = { "score": score,