def validate_eagle_tap_layers(layers: list, num_hidden_layers: int) -> None:
    """Validate EAGLE-3 tap indices (HF/vLLM convention; 0 = embedding output).

    Index ``k > 0`` selects the output after decoder layer ``k - 1``. Indices
    must be non-bool ints, unique, ascending, and in
    ``[0, num_hidden_layers]``. Order defines the fc concatenation order.

    Args:
        layers: Tap indices to check. ``None`` or an empty iterable means tap
            collection is disabled and is accepted without further checks.
        num_hidden_layers: Number of decoder layers; also the largest valid
            tap index.

    Raises:
        ValueError: If ``layers`` is not iterable or breaks any rule above.
    """
    if layers is None:
        return
    try:
        # Snapshot once so every check below sees the same values.
        indices = list(layers)
    except TypeError:
        # A scalar such as ``3`` used to escape as TypeError, and a falsy one
        # such as ``0`` was silently treated as "disabled"; report both as the
        # configuration errors they are.
        raise ValueError(
            f"eagle_tap_layers must be a list of ints, got {layers!r}"
        ) from None
    if not indices:
        return
    # bool is a subclass of int, so it has to be excluded explicitly.
    if any(isinstance(t, bool) or not isinstance(t, int) for t in indices):
        raise ValueError(f"eagle_tap_layers must be non-bool ints, got {layers}")
    if len(set(indices)) != len(indices):
        raise ValueError(f"eagle_tap_layers has duplicates: {layers}")
    if any(t < 0 or t > num_hidden_layers for t in indices):
        raise ValueError(
            f"eagle_tap_layers {layers} out of range [0, {num_hidden_layers}]"
        )
    if indices != sorted(indices):
        raise ValueError(f"eagle_tap_layers must be ascending (fc order): {layers}")
def _decode(
    self,
    tokens: torch.LongTensor,
    input_pos: torch.LongTensor,
    collect_taps: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Embed -> decoder layers -> final norm.

    Returns the normed hidden states and, when ``collect_taps`` is set, the
    concatenated tap features for ``config.eagle_tap_layers`` (in ascending
    index order) as ``(B, T, len(tap_layers) * hidden_size)``; else None.
    A ``None`` or empty ``config.eagle_tap_layers`` also yields None.

    Tap indices follow the HF/vLLM hidden-state convention: index 0 is the
    embedding output (before any decoder layer) and index k is the output
    *after* decoder layer k-1.

    Raises:
        ValueError: If ``collect_taps`` is set and the tap configuration fails
            validation, or if the number of collected taps does not match the
            number requested.
    """
    # Resolve the request before any compute so a bad config fails fast.
    tap_layers = self.config.eagle_tap_layers if collect_taps else None
    if tap_layers is not None:
        # The config can be mutated after construction, so revalidate here.
        validate_eagle_tap_layers(tap_layers, len(self.layers))
    # A falsy config (None / empty) means "no taps". Building the set once
    # also avoids a list scan per decoder layer.
    wanted = frozenset(tap_layers) if tap_layers else frozenset()

    x = self.embed_tokens(tokens) * self.embed_normalizer
    taps = [x] if 0 in wanted else []  # index 0 == embedding output

    sliding_mask, full_mask = self._build_masks(input_pos)
    for state_index, layer in enumerate(self.layers, start=1):
        x = layer(x, input_pos, sliding_mask, full_mask)
        if state_index in wanted:
            taps.append(x)  # output of layer k-1 == hidden-state index k

    # Defensive invariant: every requested index must have been visited.
    if len(taps) != len(wanted):
        raise ValueError(
            f"collected {len(taps)} taps but eagle_tap_layers requests "
            f"{len(wanted)} ({tap_layers}); check the index convention"
        )

    x = self.norm(x)
    taps_out = torch.cat(taps, dim=-1) if taps else None
    return x, taps_out


def set_eagle_tap_layers(self, layers: list) -> None:
    """Set and validate EAGLE-3 tap layers.

    Validation runs against ``config.num_hidden_layers``. The accepted value
    is stored as a fresh list; ``None`` or an empty value clears the taps.

    Raises:
        ValueError: If ``layers`` fails ``validate_eagle_tap_layers``.
    """
    validate_eagle_tap_layers(layers, self.config.num_hidden_layers)
    # The validator accepts None/empty as "disabled"; store that as [] rather
    # than crashing on list(None).
    self.config.eagle_tap_layers = list(layers) if layers else []


def forward_logits_taps(
    self,
    tokens: torch.LongTensor,
    input_pos: torch.LongTensor,
    last_logits_only: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Return soft-capped logits and EAGLE-3 tap features.

    Defaults to final-position logits. Set ``last_logits_only=False`` to
    materialize per-position float32 logits over the full vocabulary.

    Returns:
        logits: (B, 1, vocab_size) soft-capped float32, or (B, T, vocab_size)
            when ``last_logits_only=False``.
        taps: (B, T, len(eagle_tap_layers) * hidden_size) or None.
    """
    hidden, taps = self._decode(tokens, input_pos, collect_taps=True)
    if last_logits_only:
        # Slice with -1: (not -1) so the sequence dimension is kept.
        hidden = hidden[:, -1:, :]
    logits = self.lm_head(hidden).float()
    cap = self.logit_softcap.float()
    return torch.tanh(logits / cap) * cap, taps
"""Unit tests for the gemma4-31B EAGLE-3 hidden-state tap.

Covers the tap-index convention (HF/vLLM: index 0 = embedding, index k = output
after decoder layer k-1), exact concatenation order/content, config validation
(including the runtime-mutation path), and that the default decode path is
unaffected by enabling the tap.
"""

import pytest
import torch

from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig


def tiny_config(num_layers=6, tap_layers=None) -> Gemma4_31BConfig:
    """Return a small config; ``tap_layers=None`` means no taps."""
    return Gemma4_31BConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=num_layers,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=8,
        num_global_key_value_heads=1,
        global_head_dim=8,
        sliding_window=8,
        max_seq_len=32,
        eagle_tap_layers=tap_layers or [],
    )


def build(num_layers=6, tap_layers=None):
    """Build a seeded float32 model in eval mode from ``tiny_config``."""
    # Seeding here also fixes the random tokens the tests draw afterwards.
    torch.manual_seed(0)
    return Gemma4_31B(tiny_config(num_layers, tap_layers)).to(torch.float32).eval()


def reset_kv(model):
    """Zero the (stateful) KV caches so independent forwards don't couple."""
    for name, buf in model.named_buffers():
        if ".kv_cache." in name:
            buf.zero_()


def reference_states(model, tokens, input_pos):
    """Recompute _decode's per-index states: 0=embedding, k=after layer k-1."""
    x = model.embed_tokens(tokens) * model.embed_normalizer
    states = {0: x}
    sliding_mask, full_mask = model._build_masks(input_pos)
    for i, layer in enumerate(model.layers):
        x = layer(x, input_pos, sliding_mask, full_mask)
        states[i + 1] = x
    return states


def test_tap_off_does_not_change_logits():
    """Logits are the same with taps configured and with taps cleared."""
    model = build(tap_layers=[1, 2, 3])
    T = 7
    tokens = torch.randint(0, model.config.vocab_size, (1, T))
    pos = torch.arange(T)
    with torch.no_grad():
        reset_kv(model)
        logits_on, taps_on = model.forward_logits_taps(
            tokens, pos, last_logits_only=False
        )
        model.config.eagle_tap_layers = []
        reset_kv(model)
        logits_off, taps_off = model.forward_logits_taps(
            tokens, pos, last_logits_only=False
        )
    assert taps_off is None
    assert taps_on.shape == (1, T, 3 * model.config.hidden_size)
    torch.testing.assert_close(logits_on, logits_off)


@pytest.mark.parametrize(
    "num_layers,tap_layers",
    [
        (6, [0, 1, 3]),
        (6, [0, 6]),  # both ends of the valid range [0, num_hidden_layers]
        (60, [2, 30, 57]),
    ],
)
def test_tap_collects_exact_states_in_order(num_layers, tap_layers):
    """Taps equal the reference states concatenated in ``tap_layers`` order."""
    model = build(num_layers=num_layers, tap_layers=tap_layers)
    T = 5
    tokens = torch.randint(0, model.config.vocab_size, (1, T))
    pos = torch.arange(T)
    with torch.no_grad():
        reset_kv(model)
        _, taps = model.forward_logits_taps(tokens, pos)
        reset_kv(model)
        states = reference_states(model, tokens, pos)
    expected = torch.cat([states[i] for i in tap_layers], dim=-1)
    assert taps.shape == (1, T, len(tap_layers) * model.config.hidden_size)
    torch.testing.assert_close(taps, expected, rtol=0, atol=0)


def test_last_logits_only_default_matches_full():
    """The default last-position logits match the last row of the full logits."""
    model = build(tap_layers=[1])
    T = 4
    tokens = torch.randint(0, model.config.vocab_size, (1, T))
    pos = torch.arange(T)
    with torch.no_grad():
        reset_kv(model)
        full, _ = model.forward_logits_taps(tokens, pos, last_logits_only=False)
        reset_kv(model)
        last, _ = model.forward_logits_taps(tokens, pos)
    assert last.shape == (1, 1, model.config.vocab_size)
    torch.testing.assert_close(last[:, 0], full[:, -1])


@pytest.mark.parametrize(
    "bad",
    [
        [99],  # far out of range
        [7],  # one past the last valid index (num_layers=6)
        [-1],  # negative
        [1, 1],  # duplicate
        [1.0, 2],  # non-int
        [True],  # bool
        [3, 1],  # not ascending
    ],
)
def test_invalid_tap_config_rejected(bad):
    """Constructing a config with an invalid tap list raises ValueError."""
    with pytest.raises(ValueError):
        tiny_config(num_layers=6, tap_layers=bad)


def test_set_eagle_tap_layers_validates():
    """The setter stores a valid list and rejects a non-ascending one."""
    model = build()
    model.set_eagle_tap_layers([0, 2, 4])
    assert model.config.eagle_tap_layers == [0, 2, 4]
    with pytest.raises(ValueError):
        model.set_eagle_tap_layers([4, 2])


def test_runtime_mutation_is_revalidated_in_decode():
    """A config mutated to an invalid value is caught on the next forward."""
    model = build(tap_layers=[1, 2])
    model.config.eagle_tap_layers = [True]
    tokens = torch.randint(0, model.config.vocab_size, (1, 4))
    pos = torch.arange(4)
    with pytest.raises(ValueError):
        model.forward_logits_taps(tokens, pos, last_logits_only=False)


# NOTE: this module is registered under ``testpaths`` in pytest.ini as
# examples/models/gemma4_31b/test_eagle_tap.py.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__, "-q"]))