From 973c88b6a50b48996c4aab5a8fe6c9006ea1f2f7 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 10 Apr 2026 12:04:25 -0700 Subject: [PATCH] Rewrite attention sink from eviction to ring buffer (#18821) Summary: Replace the eviction-based attention sink implementation with a torch.export compatible ring buffer approach, and rewrite all tests. Key changes: - RopeWithAttentionSink: simplified to pass through original positions (no more position shifting or k re-rotation) - KVCacheWithAttentionSink: uses ring buffer with index_copy_ instead of dynamic eviction (torch.cat/narrow/shift). Cache layout: [sink slots | ring buffer]. Sets is_ring_buffer=True so AttentionMHA.forward handles masking natively. - CachePositionsManagerWithSink: new module that maps positions to cache indices, with sink tokens in fixed slots and window tokens in ring buffer region. - AttentionMHA.forward: ring buffer models skip start_pos bounds check and compute their own causal mask after KV cache update. - Remove eviction_batch_size from all interfaces (no longer needed). - Remove attention_sink_forward monkey-patch and rerotate_k dead code. - Add llama_attention_sink.yaml example config. - Rewrite 16 eviction-based tests with 18 ring buffer tests covering sink preservation, ring wrapping, causal masking, and degenerate cases. 
Differential Revision: D100216687 --- examples/models/llama/BUCK | 5 + examples/models/llama/attention.py | 11 +- .../llama/config/llama_attention_sink.yaml | 30 + .../models/llama/config/test_llm_config.py | 6 +- examples/models/llama/model.py | 17 +- .../source_transformation/attention_sink.py | 388 +++++----- .../test_attention_sink.py | 725 ++++++++---------- extension/llm/export/config/llm_config.py | 4 +- 8 files changed, 575 insertions(+), 611 deletions(-) create mode 100644 examples/models/llama/config/llama_attention_sink.yaml diff --git a/examples/models/llama/BUCK b/examples/models/llama/BUCK index bd3c23a492e..74e3fec2445 100644 --- a/examples/models/llama/BUCK +++ b/examples/models/llama/BUCK @@ -278,9 +278,14 @@ fbcode_target(_kind = runtime.python_test, "source_transformation/test_attention_sink.py", ], supports_static_listing = False, + preload_deps = [ + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + "//executorch/extension/llm/custom_ops:custom_ops_aot_py", + ], deps = [ "fbsource//third-party/pypi/parameterized:parameterized", "//caffe2:torch", + "//executorch/extension/pybindings:portable_lib", ":export_library", ], ) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 1cb7ba866b7..d6dff173072 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -550,7 +550,14 @@ def forward( if self.use_kv_cache: assert input_pos is not None - if self.enable_dynamic_shape: + is_ring_buffer = getattr(self.kv_cache, "is_ring_buffer", False) + + if is_ring_buffer: + # Ring buffer models compute their own mask after KV cache + # update; skip start_pos bounds check since start_pos can + # exceed max_context_len for sliding window / attention sink. 
+ attn_mask = None + elif self.enable_dynamic_shape: start_pos = input_pos[-1].item() torch._check_is_size(start_pos) torch._check(start_pos < self.max_context_len) @@ -569,7 +576,7 @@ def forward( ) k, v = self.kv_cache.update(input_pos, k, v) - if getattr(self.kv_cache, "is_ring_buffer", False): + if is_ring_buffer: attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( input_pos[0].item(), seqlen ) diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml new file mode 100644 index 00000000000..81b34d84457 --- /dev/null +++ b/examples/models/llama/config/llama_attention_sink.yaml @@ -0,0 +1,30 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_sdpa_with_kv_cache: True + use_kv_cache: True + dtype_override: fp32 + enable_dynamic_shape: True + # Attention Sink: "sink_size,window_size" + # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt) + # window_size=124: sliding window size + # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 + use_attention_sink: "4,124" + +export: + # max_context_length controls the RoPE frequency table size. + # It must be >= sink_size + window_size (128), but larger values are + # recommended to support generation beyond the sliding window. + # The model default (e.g., 8192 or 131072) is typically used if not specified. + # For testing, we use the model's default by not setting this explicitly. 
+ +quantization: + qmode: 8da4w + group_size: 128 + embedding_quantize: 4,32 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/llama/config/test_llm_config.py b/examples/models/llama/config/test_llm_config.py index ec85e4c2e92..c5823eb3097 100644 --- a/examples/models/llama/config/test_llm_config.py +++ b/examples/models/llama/config/test_llm_config.py @@ -25,7 +25,9 @@ class TestValidation(unittest.TestCase): def test_invalid_attention_sink(self): with self.assertRaises(ValueError): - ModelConfig(use_attention_sink="4,2048") + ModelConfig(use_attention_sink="4") + with self.assertRaises(ValueError): + ModelConfig(use_attention_sink="4,2048,1024") def test_invalid_local_global_attention_format(self): with self.assertRaises(ValueError): @@ -79,7 +81,7 @@ def test_valid_llm_config(self): ), model=ModelConfig( dtype_override="fp32", - use_attention_sink="4,2048,1024", + use_attention_sink="4,2048", use_kv_cache=True, local_global_attention="[16, 32]", ), diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 8b35d7d3155..f02621b66b2 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -203,19 +203,28 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): from .source_transformation.attention_sink import enable_attention_sink attention_sink_params = self.llm_config.model.use_attention_sink.split(",") - assert len(attention_sink_params) == 3 + assert len(attention_sink_params) == 2, ( + f"use_attention_sink expects exactly 2 comma-separated values " + f"(sink_size,window_size), got {len(attention_sink_params)}" + ) sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - eviction_batch_size = int(attention_sink_params[2]) - assert self.llm_config.export.max_context_length == sink_size + window_size + # max_context_length must be >= sink_size + window_size to have enough RoPE frequencies + # A larger max_context_length is allowed (and 
recommended) to support generation beyond + # the sliding window size. + assert ( + self.llm_config.export.max_context_length >= sink_size + window_size + ), ( + f"max_context_length ({self.llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) self.model_ = enable_attention_sink( module=self.model_, params=model_args, sink_size=sink_size, window_size=window_size, - eviction_batch_size=eviction_batch_size, ) missing, unexpected = None, None diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 22bd8a3e228..c981ae339f3 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -7,26 +7,35 @@ # Components for supporting Attention Sink. See # https://arxiv.org/abs/2309.17453 for more details about Attention Sink. -import types -from typing import Optional +# This implementation is torch.export compatible using a ring buffer approach +# for the sliding window portion while preserving the sink tokens. -import torch +from typing import Optional, Tuple -from executorch.examples.models.llama.attention import AttentionMHA, KVCache -from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.rope import ( - apply_rotary_emb_to_k, - hf_apply_rotary_emb_to_k, - Rope, +import torch +import torch.nn as nn +from executorch.examples.models.llama.attention import ( + _create_causal_mask_for_ring_buffer, + AttentionMHA, + KVCache, + RingKVCache, ) +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter class RopeWithAttentionSink(Rope): """ - Rope that helps adjust position encoding when tokens are shifted in KVCache. 
- For AttentionSink, when tokens are shifted in KVCache, we need to use positions - in KVCache instead of positions in the actual text. + Rope subclass for Attention Sink models. + + For torch.export compatibility, this passes through the original position + unchanged - the sliding window is handled by the cache index management + (ring buffer), not by position shifting. + + Note: This class uses the model's max_context_len (params.max_context_len) for + RoPE frequency table size, which should be large enough to support generation + beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( @@ -34,77 +43,123 @@ def __init__( params: ModelArgs, window_size: int, sink_size: int, - eviction_batch_size: int, ): super().__init__(params) - if self.params.use_hf_rope: - self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k - else: - self.apply_rotary_emb_to_k = apply_rotary_emb_to_k - self.max_context_length = window_size + sink_size - assert self.max_context_length == self.params.max_context_len - self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 + self.window_size = window_size + self.sink_size = sink_size + # max_context_len from params is used for RoPE frequencies (should be large) + self.max_context_length = self.params.max_context_len def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get rotary embedding frequencies. + For attention sink, we use the original position - the sliding window + is handled by the cache index management, not by position shifting. 
+ """ assert input_pos is not None + # Use torch._check for export compatibility (data-dependent guard) + torch._check(input_pos[0].item() + seq_len <= self.max_context_length) + return super().get_freqs(input_pos, seq_len) - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. - num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return super().get_freqs(input_pos + self.position_shift, seq_len) - def rerotate_k( - self, - k: torch.Tensor, - original_position: int, - new_position: int, - ): - """ - Rerotate k from original_position to new_position. This is done by rerotating - k with (new_position * theta - original_position * theta) with the following matrix: - (cos(delta), -sin(delta) - sin(delta), cos(delta)) - where delta = new_position * theta - original_position * theta +def _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len +): + """ + Create causal mask for attention sink. + + Unlike regular ring buffer mask, this mask: + 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) + 2. 
Uses sliding window for other tokens + + Args: + cache_positions: Tensor of actual positions stored at each cache index + window_size: Size of the sliding window + sink_size: Number of sink tokens to always attend to + start_pos: Starting position of the current query + seq_len: Length of the current query sequence + """ + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + + # Valid if position is filled (>= 0) and causal (delta >= 0) + is_valid = (cache_positions >= 0) & (delta >= 0) + + # Sink tokens (original positions 0 to sink_size-1) are always visible + is_sink = cache_positions < sink_size + + # Window tokens must be within sliding window + is_in_window = delta < window_size + + # Final mask: valid AND (is_sink OR is_in_window) + attn_mask = is_valid & (is_sink | is_in_window) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + +class CachePositionsManagerWithSink(nn.Module): + """ + Manages cache positions for attention sink + sliding window. + + For sink_size=0: behaves exactly like original CachePositionsManager. + For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. + + IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). 
+ """ + + def __init__(self, cache_size: int, sink_size: int = 0): + super().__init__() + # cache_size is the actual size of the kv cache dimension + self.max_context_length = cache_size + self.sink_size = sink_size + self.ring_size = cache_size - sink_size + # Initialize to -1 to indicate empty/unfilled slots + self.register_buffer( + "cache_positions", + torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), + ) - The shape of k is (batch_size, seq_len, n_local_heads, head_dim) + def calculate_positions_and_update_indices( + self, input_pos: torch.Tensor, seq_len: int + ) -> torch.Tensor: + """ + Calculate indices into k_cache, v_cache for placing k_val, v_val. - Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961 + Sink tokens (positions < sink_size) map to cache slots [0, sink_size). + Window tokens (positions >= sink_size) use ring buffer in [sink_size, cache_size). """ - seq_len = k.shape[1] - original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) - original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) - new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) - new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) - rerotation_cos = ( - new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + + # Sink tokens go to fixed slots; window tokens use ring buffer + indices = torch.where( + orig_indices < self.sink_size, + orig_indices, + self.sink_size + (orig_indices - self.sink_size) % self.ring_size, ) - rerotation_sin = ( - new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin + + # Update cache_positions exactly like original CachePositionsManager + full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) + arange_tensor = 
torch.arange(self.max_context_length, dtype=torch.long) + cache_positions = torch.where( + arange_tensor < start_pos, self.cache_positions, full_t ) + self.cache_positions.copy_(cache_positions) + self.cache_positions.index_copy_(0, indices, orig_indices) - return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) + return indices class KVCacheWithAttentionSink(KVCache): """ - KV cache that supports attention sink. It keeps the initial few tokens as attention sink. - For other tokens, it uses a sliding window to keep the most recent tokens. + KV cache that supports attention sink with torch.export compatibility. + + Uses a ring buffer approach for the sliding window portion while keeping + the first sink_size tokens fixed. This avoids dynamic shape operations. - Parameters: - window_size: the size of the sliding window - sink_size: the number of initial tokens to keep as attention sink - eviction_batch_size: the number of tokens to evict in batch when there is not enough space in the KV cache + Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( @@ -115,13 +170,14 @@ def __init__( rope: RopeWithAttentionSink, window_size: int, sink_size: int, - eviction_batch_size: int, max_batch_size: int = 1, dtype=torch.float32, ): + # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) + total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, - max_context_length=window_size + sink_size, + max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, @@ -130,108 +186,74 @@ def __init__( self.rope = rope self.window_size = window_size self.sink_size = sink_size - self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 + self.is_ring_buffer = True - def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: - """ - Evict old tokens from the cache to make rooms for new 
tokens. - - Parameters: - input_pos: the start position of the incoming token in the actual sequence - seq_len: the length of the incoming sequence - rope: the rope object to use for rerotating k + # Cache positions manager for determining write locations + # Pass the total cache size (same as self.max_context_length after super().__init__) + self.cache_positions_manager = CachePositionsManagerWithSink( + total_cache_size, sink_size + ) - Returns: - the number of tokens to evict from the cache which is also the number of - positions to shift for incoming tokens + def create_causal_mask_for_ring_buffer(self, start_pos: int, seq_len: int): """ - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. - num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - num_to_keep = ( - input_pos_item + self.position_shift - self.sink_size - num_to_evict - ) - num_empty_space = self.window_size - num_to_keep - dim_to_slice = 2 - k_to_keep = self.k_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ) - k_to_keep = self.rope.rerotate_k( - k=k_to_keep.transpose(1, 2), - original_position=(self.sink_size + num_to_evict), # pyre-ignore [6] - new_position=self.sink_size, - ).transpose(1, 2) - self.k_cache = torch.cat( - [ - self.k_cache.narrow(dim_to_slice, 0, self.sink_size), - k_to_keep, - torch.zeros_like( - self.k_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, + Create causal mask for the attention with attention sink. + Sink tokens are ALWAYS visible, plus recent tokens in the window. 
+ """ + cache_positions = self.cache_positions_manager.cache_positions + if self.sink_size > 0: + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len ) - self.v_cache = torch.cat( - [ - self.v_cache.narrow(dim_to_slice, 0, self.sink_size), - self.v_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ), - torch.zeros_like( - self.v_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, + else: + # Pure ring buffer mode - use original mask with window_size = actual window + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return self.position_shift + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV cache with new key-value pairs. + Uses ring buffer indexing for positions >= sink_size. + """ + seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" + # Verify that window tokens (those mapping to the ring buffer) don't + # exceed ring_size, which would cause duplicate indices in index_copy_. + # Sink tokens (positions < sink_size) map to fixed slots and are safe. 
+ start_pos = input_pos[0].item() + num_sink_tokens = max(0, min(seq_len, self.sink_size - start_pos)) + num_window_tokens = seq_len - num_sink_tokens + assert num_window_tokens <= self.cache_positions_manager.ring_size, ( + f"Window tokens ({num_window_tokens}) exceed ring buffer capacity " + f"({self.cache_positions_manager.ring_size}), which would cause " + f"non-deterministic behavior with index_copy_" + ) -def attention_sink_forward( - self, - x: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, -): - assert self.use_kv_cache - assert input_pos is not None - - bsz, seqlen, _ = x.shape - - # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) - # Prepare for space in KV cache and get position shift - position_shift = self.kv_cache.evict_tokens(input_pos, seqlen) + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) - # RoPE relative positional embeddings with shifted position in KV cache - q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + return self.k_cache, self.v_cache - output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask) - return self.wo(output) + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + """ + For ring buffer implementation, no explicit eviction is needed. + The ring buffer automatically overwrites old values. + Returns 0 to indicate no position shift is needed. 
+ """ + return 0 def _replace_rope( @@ -251,7 +273,6 @@ def _replace_attention( rope_with_attention_sink: RopeWithAttentionSink, sink_size: int, window_size: int, - eviction_batch_size: int, ): for _, child_module in module._modules.items(): if len(list(child_module.children())) > 0: # pyre-ignore [16] @@ -260,26 +281,34 @@ def _replace_attention( rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, - eviction_batch_size=eviction_batch_size, ) if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache - kv_cache_with_attention_sink = KVCacheWithAttentionSink( - n_heads=kv_cache.n_heads, - head_dim=kv_cache.head_dim, - enable_dynamic_shape=kv_cache.enable_dynamic_shape, - rope=rope_with_attention_sink, - max_batch_size=kv_cache.max_batch_size, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=kv_cache.k_cache.dtype, - ) - child_module.kv_cache = kv_cache_with_attention_sink - child_module.forward = types.MethodType( # pyre-ignore - attention_sink_forward, child_module - ) + if sink_size == 0: + # No sink tokens needed — use standard RingKVCache directly + child_module.kv_cache = RingKVCache( + kv_cache.max_batch_size, + window_size, # RingKVCache expects user-provided window size + kv_cache.n_heads, + kv_cache.head_dim, + kv_cache.enable_dynamic_shape, + kv_cache.k_cache.dtype, + ) + else: + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + dtype=kv_cache.k_cache.dtype, + ) + child_module.kv_cache = kv_cache_with_attention_sink + # Don't replace forward - let the original AttentionMHA.forward handle it + # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask def enable_attention_sink( @@ 
-287,19 +316,17 @@ def enable_attention_sink( params: ModelArgs, sink_size: int, window_size: int, - eviction_batch_size: int, ) -> torch.nn.Module: """ Transform the model to be able to run inference with Attention Sink. - There mainly three steps: + There are mainly two steps: - Replace Rope with RopeWithAttentionSink - - Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward + - Replace Attention's KVCache with KVCacheWithAttentionSink """ rope_with_attention_sink = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, - eviction_batch_size=eviction_batch_size, ) _replace_rope(module, rope_with_attention_sink) _replace_attention( @@ -307,6 +334,5 @@ def enable_attention_sink( rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, - eviction_batch_size=eviction_batch_size, ) return module diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py index fc882ebf4ab..51474d75969 100644 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -10,6 +10,7 @@ from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, KVCacheWithAttentionSink, RopeWithAttentionSink, ) @@ -18,497 +19,381 @@ class RopeWithAttentionSinkTest(unittest.TestCase): - def _init_rope(self, params: ModelArgs, eviction_batch_size: int): - return RopeWithAttentionSink( - params=params, - window_size=252, - sink_size=4, - eviction_batch_size=eviction_batch_size, - ) - def setUp(self): torch.manual_seed(42) self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 ) - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=1 + 
self.rope = RopeWithAttentionSink( + params=self.params, + window_size=124, + sink_size=4, ) @parameterized.expand( [ - [0, 10, 1, 0], # No shift - [250, 10, 1, 246], # Some shift - [256, 10, 1, 246], # All shift - [0, 10, 30, 0], # No shift with batch eviction - [250, 10, 30, 220], # Some shift with batch eviction - [256, 10, 30, 226], # All shift with batch eviction + [0, 10], + [50, 10], + [200, 10], + [0, 1], + [100, 5], ] ) - def test_get_freqs( - self, input_pos, seq_len, eviction_batch_size, expected_result_pos - ): - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=eviction_batch_size - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + def test_get_freqs_passthrough(self, input_pos, seq_len): + """get_freqs should return frequencies for the exact input position (no shifting).""" + freqs_cos, freqs_sin = self.rope.get_freqs( input_pos=torch.tensor([input_pos], dtype=torch.int32), seq_len=seq_len, ) - torch.testing.assert_close( - freqs_cos, - self.rope_with_attention_sink.freqs_cos.narrow( - 0, expected_result_pos, seq_len - ), - ) - torch.testing.assert_close( - freqs_sin, - self.rope_with_attention_sink.freqs_sin.narrow( - 0, expected_result_pos, seq_len - ), - ) + expected_cos = self.rope.freqs_cos.narrow(0, input_pos, seq_len) + expected_sin = self.rope.freqs_sin.narrow(0, input_pos, seq_len) - @parameterized.expand( - [ - [128, 127], # Rotate left - [128, 128], # No rotation - [128, 129], # Rotate right - ] - ) - def test_rotate(self, original_position, new_position): - seq_len = 32 - - size = (1, seq_len, self.params.n_heads, self.params.head_dim) - q = torch.rand(*size, dtype=torch.float32) - k = torch.rand( - *size, - dtype=torch.float32, - ) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([original_position], dtype=torch.int32), - seq_len=seq_len, + torch.testing.assert_close(freqs_cos, expected_cos) + torch.testing.assert_close(freqs_sin, 
expected_sin) + + +class CachePositionsManagerWithSinkTest(unittest.TestCase): + + def test_sink_indices_fixed(self): + """Positions < sink_size should map to themselves (fixed slots).""" + manager = CachePositionsManagerWithSink(cache_size=12, sink_size=4) + # Fill sink tokens: positions 0,1,2,3 + indices = manager.calculate_positions_and_update_indices( + torch.tensor([0], dtype=torch.long), seq_len=4 ) - _, pre_rotated_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, + self.assertEqual(indices.tolist(), [0, 1, 2, 3]) + + def test_window_indices_ring_buffer(self): + """Positions >= sink_size should use ring buffer in [sink_size, cache_size).""" + manager = CachePositionsManagerWithSink(cache_size=12, sink_size=4) + # ring_size = 12 - 4 = 8 + # Position 4 -> slot 4, position 5 -> slot 5, etc. + indices = manager.calculate_positions_and_update_indices( + torch.tensor([4], dtype=torch.long), seq_len=3 ) + self.assertEqual(indices.tolist(), [4, 5, 6]) - rerotated_k = self.rope_with_attention_sink.rerotate_k( - k=pre_rotated_k, - original_position=original_position, - new_position=new_position, + def test_window_wraps_around(self): + """Window tokens should wrap around in the ring buffer region.""" + manager = CachePositionsManagerWithSink(cache_size=12, sink_size=4) + # ring_size = 8, positions 12..14 -> (12-4)%8=0 -> slot 4, slot 5, slot 6 + indices = manager.calculate_positions_and_update_indices( + torch.tensor([12], dtype=torch.long), seq_len=3 ) + self.assertEqual(indices.tolist(), [4, 5, 6]) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([new_position], dtype=torch.int32), - seq_len=seq_len, + def test_sink_never_overwritten(self): + """After wrapping, sink slots (0..sink_size-1) should retain original positions.""" + manager = CachePositionsManagerWithSink(cache_size=12, sink_size=4) + # Fill sink + some window + manager.calculate_positions_and_update_indices( + 
torch.tensor([0], dtype=torch.long), seq_len=10 ) - _, expected_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, + # Wrap around: position 12 maps to slot 4 + manager.calculate_positions_and_update_indices( + torch.tensor([12], dtype=torch.long), seq_len=3 ) + # Sink positions should still show 0,1,2,3 + self.assertEqual(manager.cache_positions[:4].tolist(), [0, 1, 2, 3]) - torch.testing.assert_close(rerotated_k, expected_k) + def test_cache_positions_updated(self): + """cache_positions should track the actual position stored at each slot.""" + manager = CachePositionsManagerWithSink(cache_size=8, sink_size=2) + # ring_size = 6 + # Fill positions 0..7 + manager.calculate_positions_and_update_indices( + torch.tensor([0], dtype=torch.long), seq_len=8 + ) + self.assertEqual(manager.cache_positions.tolist(), [0, 1, 2, 3, 4, 5, 6, 7]) + # Position 8 wraps to slot 2 (sink_size + (8-2)%6 = 2) + manager.calculate_positions_and_update_indices( + torch.tensor([8], dtype=torch.long), seq_len=1 + ) + self.assertEqual(manager.cache_positions.tolist(), [0, 1, 8, 3, 4, 5, 6, 7]) class KVCacheWithAttentionSinkTest(unittest.TestCase): - _single_evict_test_cases = [ - [4, 1], - ] - - _batch_evict_test_cases = [ - [4, 8], - ] - - _sliding_window_test_cases = [ - [0, 1], - ] - - def _init_cache(self, sink_size, eviction_batch_size): + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.window_size = 28 + self.sink_size = 4 + self.dtype = torch.float32 self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, - max_context_len=self.window_size + sink_size, + max_context_len=256, ) - self.rope_with_attention_sink = RopeWithAttentionSink( + self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, + sink_size=self.sink_size, ) + # Total cache size = sink_size + window_size * 2 = 4 + 56 = 60 + self.cache_size 
= self.sink_size + self.window_size * 2 self.kv_cache = KVCacheWithAttentionSink( n_heads=self.params.n_heads, head_dim=self.params.head_dim, enable_dynamic_shape=self.params.enable_dynamic_shape, - rope=self.rope_with_attention_sink, + rope=self.rope, max_batch_size=self.max_batch_size, window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=self.dtype, - ) - - def _rand_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - return k, v - - def _zero_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, + sink_size=self.sink_size, dtype=self.dtype, ) - return k, v - - def _get_dim_to_slice(self): - return 2 - - def _get_expected_rotated_k(self, k, original_position, new_position): - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - - def setUp(self): - torch.manual_seed(42) - self.max_batch_size = 1 - self.window_size = 28 - self.dtype = torch.float32 - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_empty_cache(self, sink_size, eviction_batch_size): - self._init_cache(sink_size, eviction_batch_size) - - # KV cache is empty, evict does nothing - input_pos = torch.tensor([0], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand( - 
_single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_without_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = 2 - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has enough spaces for new tokens, no shift - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(10) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) - - expected_k = torch.cat( - [ - k, - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v, - zero_v, - ], - dim=dimension_to_slice, - ) - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + def _rand_kv(self, seq_len): + size = (self.max_batch_size, self.params.n_heads, seq_len, self.params.head_dim) + return torch.rand(*size, dtype=self.dtype), torch.rand(*size, dtype=self.dtype) - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens + def test_evict_tokens_returns_zero(self): + """Ring buffer implementation needs no eviction; evict_tokens always returns 0.""" input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) + self.assertEqual(self.kv_cache.evict_tokens(input_pos, 1), 0) - self.kv_cache.update(input_pos, k1, v1) + input_pos = torch.tensor([100], dtype=torch.int32) + self.assertEqual(self.kv_cache.evict_tokens(input_pos, 10), 0) - input_pos = 
torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 24) == -2 + def test_update_initial_fill(self): + """First tokens should fill cache slots sequentially.""" + k, v = self._rand_kv(10) + input_pos = torch.tensor([0], dtype=torch.long) + k_out, v_out = self.kv_cache.update(input_pos, k, v) - zero_k, zero_v = self._zero_kv_with_length(24) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 1, 4), - zero_v, - ], - dim=dimension_to_slice, + # Slots 0..9 should contain our data + torch.testing.assert_close(k_out[:, :, :10, :], k) + torch.testing.assert_close(v_out[:, :, :10, :], v) + # Remaining slots should be zeros + torch.testing.assert_close( + k_out[:, :, 10:, :], + torch.zeros_like(k_out[:, :, 10:, :]), ) - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + def test_sink_tokens_preserved_after_wrap(self): + """Sink tokens (positions 0..sink_size-1) must never be overwritten.""" + # Fill entire cache + k_init, v_init = self._rand_kv(self.cache_size) + input_pos = torch.tensor([0], dtype=torch.long) + self.kv_cache.update(input_pos, k_init, v_init) - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() + sink_k = k_init[:, :, : self.sink_size, :].clone() + sink_v = v_init[:, :, : self.sink_size, :].clone() - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) + # Write past the cache size — should wrap in window region only + k_new, v_new = 
self._rand_kv(5) + input_pos = torch.tensor([self.cache_size], dtype=torch.long) + k_out, v_out = self.kv_cache.update(input_pos, k_new, v_new) - self.kv_cache.update(input_pos, k, v) + # Sink tokens must be unchanged + torch.testing.assert_close(k_out[:, :, : self.sink_size, :], sink_k) + torch.testing.assert_close(v_out[:, :, : self.sink_size, :], sink_v) - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(27) + def test_ring_buffer_wrapping(self): + """Window tokens should wrap correctly in the ring buffer region.""" + ring_size = self.cache_size - self.sink_size # 56 - self.kv_cache.update(input_pos, k1, v1) + # Fill cache initially + k_init, v_init = self._rand_kv(self.cache_size) + self.kv_cache.update(torch.tensor([0], dtype=torch.long), k_init, v_init) - input_pos = torch.tensor([32], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 5, 22), 10, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 5, 22), - zero_v, - ], - dim=dimension_to_slice, + # Write at position that wraps: pos = sink_size + ring_size = 4 + 56 = 60 + # This should map to slot sink_size + (60-4)%56 = 4 + 0 = slot 4 + k_wrap, v_wrap = self._rand_kv(3) + self.kv_cache.update( + torch.tensor([self.sink_size + ring_size], dtype=torch.long), + k_wrap, + v_wrap, ) - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_some_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - 
# KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - - zero_k, zero_v = self._zero_kv_with_length(20) - expected_k = torch.cat( - [ - self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), - self._get_expected_rotated_k(k1, 5, 3), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 2, 3), - v1, - zero_v, - ], - dim=dimension_to_slice, + # Slots 4,5,6 should now have the new data + k_out = self.kv_cache.k_cache + torch.testing.assert_close( + k_out[:, :, self.sink_size : self.sink_size + 3, :], k_wrap ) - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_all_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) + def test_sequential_generation(self): + """Simulate sequential token generation and verify sink protection.""" + # Prefill 10 tokens + k_prefill, v_prefill = self._rand_kv(10) + self.kv_cache.update(torch.tensor([0], dtype=torch.long), k_prefill, v_prefill) - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) + sink_k = k_prefill[:, :, : self.sink_size, :].clone() - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(23) + # Generate tokens one by one, well past 
cache size + for pos in range(10, self.cache_size + 20): + k_tok, v_tok = self._rand_kv(1) + self.kv_cache.update(torch.tensor([pos], dtype=torch.long), k_tok, v_tok) - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([28], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + # Sink tokens must still be the original ones + torch.testing.assert_close( + self.kv_cache.k_cache[:, :, : self.sink_size, :], sink_k + ) + + def test_causal_mask_attends_to_sink(self): + """The causal mask should always allow attending to sink tokens.""" + # Fill some tokens + k, v = self._rand_kv(20) + self.kv_cache.update(torch.tensor([0], dtype=torch.long), k, v) + + # Get mask for position 15 + mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos=15, seq_len=1) + + # Sink slots (0..3) should be attended to (mask value = 0, not -inf) + for i in range(self.sink_size): + self.assertEqual( + mask[0, i].item(), + 0.0, + f"Sink slot {i} should be attendable", + ) + + def test_causal_mask_blocks_future(self): + """The causal mask should block future (unfilled) positions.""" + # Fill only 5 tokens + k, v = self._rand_kv(5) + self.kv_cache.update(torch.tensor([0], dtype=torch.long), k, v) + + mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos=4, seq_len=1) + + # Unfilled slots should be masked (-inf) + for i in range(5, self.cache_size): + self.assertEqual( + mask[0, i].item(), + float("-inf"), + f"Unfilled slot {i} should be masked", + ) - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 1, 22), 6, 0 - ), - zero_k, - ], - dim=dimension_to_slice, + @parameterized.expand( + [ + [0], # No sink, pure sliding window + ] + ) + def test_no_sink_degenerates_to_ring_buffer(self, sink_size): + """With sink_size=0, behavior should match a plain ring buffer.""" + params = ModelArgs( + use_kv_cache=True, enable_dynamic_shape=True, 
max_context_len=256 ) - expected_v = torch.cat( - [ - v1.narrow(dimension_to_slice, 1, 22), - zero_v, - ], - dim=dimension_to_slice, + rope = RopeWithAttentionSink( + params=params, window_size=self.window_size, sink_size=0 ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - - zero_k, zero_v = self._zero_kv_with_length(12) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 9, 16), 14, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 9, 16), - zero_v, - ], - dim=dimension_to_slice, + cache = KVCacheWithAttentionSink( + n_heads=params.n_heads, + head_dim=params.head_dim, + enable_dynamic_shape=params.enable_dynamic_shape, + rope=rope, + max_batch_size=1, + window_size=self.window_size, + sink_size=0, + dtype=self.dtype, ) + cache_size = self.window_size * 2 # 56 - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + # Fill and wrap + k_init, v_init = self._rand_kv(cache_size) + cache.update(torch.tensor([0], dtype=torch.long), k_init, v_init) - 
@parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() + k_new, v_new = self._rand_kv(3) + cache.update(torch.tensor([cache_size], dtype=torch.long), k_new, v_new) - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) + # Slot 0,1,2 should have new data (no sink protection) + torch.testing.assert_close(cache.k_cache[:, :, :3, :], k_new) - self.kv_cache.update(input_pos, k, v) - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) +class AttentionSinkE2ETest(unittest.TestCase): + """ + End-to-end test: construct a full Transformer with attention sink, + optionally with custom SDPA + custom KV cache, and generate tokens + beyond the context window size. + """ - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - - zero_k, zero_v = self._zero_kv_with_length(10) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 7, 18), 12, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 7, 18), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + def _make_args(self, max_context_len=128): + return ModelArgs( + dim=64, + n_heads=4, + n_kv_heads=2, + head_dim=16, + hidden_dim=128, + max_batch_size=1, + max_seq_len=32, + max_context_len=max_context_len, + use_kv_cache=True, + enable_dynamic_shape=True, + n_layers=2, + vocab_size=32, + ) + + def _build_model(self, args, 
sink_size, window_size, use_custom_sdpa=False): + from executorch.examples.models.llama.llama_transformer import ( + construct_transformer, + ) + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + model = construct_transformer(args) + model = enable_attention_sink( + model, params=args, sink_size=sink_size, window_size=window_size + ) + + if use_custom_sdpa: + from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, + ) + from executorch.examples.models.llama.source_transformation.sdpa import ( + replace_sdpa_with_custom_op, + ) + + try: + replace_sdpa_with_custom_op(model) + except ImportError: + raise unittest.SkipTest( + "Custom SDPA ops not available (missing pybindings)" + ) + replace_kv_cache_with_custom_kv_cache(model) + + model.eval() + return model + + def _run_generation(self, model, args, num_tokens): + """Run prefill + decode for num_tokens total, return all outputs.""" + outputs = [] + with torch.no_grad(): + # Prefill with 4 tokens + prefill_tokens = torch.randint(0, args.vocab_size, (1, 4)) + result = model( + tokens=prefill_tokens, + attn_options={"input_pos": torch.tensor([0], dtype=torch.long)}, + ) + out = result[0] if isinstance(result, tuple) else result + outputs.append(out) + + # Decode one token at a time + for pos in range(4, num_tokens): + token = torch.randint(0, args.vocab_size, (1, 1)) + result = model( + tokens=token, + attn_options={"input_pos": torch.tensor([pos], dtype=torch.long)}, + ) + out = result[0] if isinstance(result, tuple) else result + outputs.append(out) + + return outputs + + def test_beyond_context_window_basic(self): + """Generate tokens well beyond the KV cache size using standard SDPA.""" + sink_size = 4 + window_size = 16 + # KV cache size = sink_size + window_size * 2 = 36 + # max_context_len = 128 (for RoPE table) + args = self._make_args(max_context_len=128) + model = 
self._build_model(args, sink_size, window_size, use_custom_sdpa=False) + + # Generate 80 tokens — well beyond KV cache size of 36 + outputs = self._run_generation(model, args, num_tokens=80) + + self.assertEqual(len(outputs), 77) # 1 prefill + 76 decode steps + for out in outputs: + self.assertTrue( + torch.isfinite(out).all(), "Output contains non-finite values" + ) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index f4bdfbf1a0d..4d587854c7f 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -218,9 +218,9 @@ def __post_init__(self): def _validate_attention_sink(self): if self.use_attention_sink: attention_sink_params = self.use_attention_sink.split(",") - if len(attention_sink_params) != 3: + if len(attention_sink_params) != 2: raise ValueError( - "The value of use_attention_sink must be structured like ',,'" + "The value of use_attention_sink must be structured like ','" )