Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/models/llama/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,14 @@ fbcode_target(_kind = runtime.python_test,
"source_transformation/test_attention_sink.py",
],
supports_static_listing = False,
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/extension/llm/custom_ops:custom_ops_aot_py",
],
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/extension/pybindings:portable_lib",
":export_library",
],
)
Expand Down
11 changes: 9 additions & 2 deletions examples/models/llama/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,14 @@ def forward(

if self.use_kv_cache:
assert input_pos is not None
if self.enable_dynamic_shape:
is_ring_buffer = getattr(self.kv_cache, "is_ring_buffer", False)

if is_ring_buffer:
# Ring buffer models compute their own mask after KV cache
# update; skip start_pos bounds check since start_pos can
# exceed max_context_len for sliding window / attention sink.
attn_mask = None
elif self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
Expand All @@ -569,7 +576,7 @@ def forward(
)
k, v = self.kv_cache.update(input_pos, k, v)

if getattr(self.kv_cache, "is_ring_buffer", False):
if is_ring_buffer:
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
input_pos[0].item(), seqlen
)
Expand Down
30 changes: 30 additions & 0 deletions examples/models/llama/config/llama_attention_sink.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
base:
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
  use_sdpa_with_kv_cache: True
  use_kv_cache: True
  dtype_override: fp32
  enable_dynamic_shape: True
  # Attention Sink, format: "sink_size,window_size"
  # sink_size=4: keep the first 4 tokens (e.g., BOS + system prompt)
  # window_size=124: sliding window size
  # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
  # NOTE(review): the *2 factor is presumably required by the ring-buffer
  # KV cache implementation — confirm against the attention-sink cache code.
  use_attention_sink: "4,124"

export:
  # max_context_length controls the RoPE frequency table size.
  # It must be >= sink_size + window_size (128), but larger values are
  # recommended to support generation beyond the sliding window.
  # The model default (e.g., 8192 or 131072) is typically used if not specified.
  # For testing, we use the model's default by not setting this explicitly.

quantization:
  qmode: 8da4w
  group_size: 128
  embedding_quantize: 4,32

backend:
  xnnpack:
    enabled: True
    extended_ops: True
6 changes: 4 additions & 2 deletions examples/models/llama/config/test_llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
class TestValidation(unittest.TestCase):
def test_invalid_attention_sink(self):
with self.assertRaises(ValueError):
ModelConfig(use_attention_sink="4,2048")
ModelConfig(use_attention_sink="4")
with self.assertRaises(ValueError):
ModelConfig(use_attention_sink="4,2048,1024")

def test_invalid_local_global_attention_format(self):
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -79,7 +81,7 @@ def test_valid_llm_config(self):
),
model=ModelConfig(
dtype_override="fp32",
use_attention_sink="4,2048,1024",
use_attention_sink="4,2048",
use_kv_cache=True,
local_global_attention="[16, 32]",
),
Expand Down
17 changes: 13 additions & 4 deletions examples/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,28 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
from .source_transformation.attention_sink import enable_attention_sink

attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
assert len(attention_sink_params) == 2, (
f"use_attention_sink expects exactly 2 comma-separated values "
f"(sink_size,window_size), got {len(attention_sink_params)}"
)
Comment on lines 205 to +209
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR changes use_attention_sink from 3 parameters to 2, but other call sites still assume 3 (e.g., examples/models/llama/eval_llama_lib.py asserts len==3 around line ~350, and examples/models/llama/export_llama_lib.py’s CLI help still documents 3 values around line ~594). Please update those to avoid runtime assertion failures / misleading CLI docs.

Copilot uses AI. Check for mistakes.
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

assert self.llm_config.export.max_context_length == sink_size + window_size
# max_context_length must be >= sink_size + window_size to have enough RoPE frequencies
# A larger max_context_length is allowed (and recommended) to support generation beyond
# the sliding window size.
assert (
self.llm_config.export.max_context_length >= sink_size + window_size
), (
f"max_context_length ({self.llm_config.export.max_context_length}) must be >= "
f"sink_size + window_size ({sink_size + window_size})"
)

self.model_ = enable_attention_sink(
module=self.model_,
params=model_args,
sink_size=sink_size,
window_size=window_size,
eviction_batch_size=eviction_batch_size,
)

missing, unexpected = None, None
Expand Down
Loading
Loading