From f86468129433b9bf5a2d567e8425fe424def46f3 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Thu, 21 May 2026 10:39:53 -0500 Subject: [PATCH 1/4] Add concurrency and retry solutions for the 429 HF calls --- .github/workflows/checks.yml | 13 ++ README.md | 4 +- docs/source/content/getting_started.md | 34 ++++- tests/conftest.py | 14 ++ tests/unit/utilities/test_hf_utils.py | 183 +++++++++++++++++++++++++ transformer_lens/__init__.py | 11 ++ transformer_lens/utilities/hf_utils.py | 111 ++++++++++++++- 7 files changed, 364 insertions(+), 6 deletions(-) create mode 100644 tests/unit/utilities/test_hf_utils.py diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 3cd9cfe2c..0749bba3f 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -43,6 +43,19 @@ permissions: actions: write contents: write +# Cancel in-progress runs on the same PR when a new push arrives. +# Push-to-main, tag, and workflow_call events are not cancelled so that +# release and deploy jobs always run to completion. +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +# Enable retry-on-429 for any code path that loads from HuggingFace. +# Pytest also enables retry via tests/conftest.py; this env var covers any +# non-pytest invocations (scripts, notebooks, docs builds, etc.). +env: + TRANSFORMERLENS_HF_RETRY: "1" + jobs: compatibility-checks: name: Compatibility Checks diff --git a/README.md b/README.md index 7c79cc390..9c129da44 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ CD](https://github.com/TransformerLensOrg/TransformerLens/actions/workflows/chec [![Docs CD](https://github.com/TransformerLensOrg/TransformerLens/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/TransformerLensOrg/TransformerLens/actions/workflows/pages/pages-build-deployment) -A Library for Mechanistic Interpretability of Generative Language Models. Maintained by [Bryce Meyer](https://github.com/bryce13950) and created by [Neel Nanda](https://neelnanda.io/about) +A Library for Mechanistic Interpretability of Generative Language Models. Maintained by [Bryce Meyer](https://github.com/bryce13950) and [Jonah Larson](https://github.com/jlarson4); created by [Neel Nanda](https://neelnanda.io/about) [![Read the Docs Here](https://img.shields.io/badge/-Read%20the%20Docs%20Here-blue?style=for-the-badge&logo=Read-the-Docs&logoColor=white&link=https://TransformerLensOrg.github.io/TransformerLens/)](https://TransformerLensOrg.github.io/TransformerLens/) @@ -50,6 +50,8 @@ bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") logits, activations = bridge.run_with_cache("Hello World") ``` +> Gated models (Llama, Mistral, Gemma, ...) require `HF_TOKEN` in your environment. See [Environment Variables](https://TransformerLensOrg.github.io/TransformerLens/content/getting_started.html#environment-variables) for the full list. + `TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. By default it preserves raw HuggingFace weights – logits and activations match HF, *not* legacy `HookedTransformer` (which folds LayerNorm and centers weights by default). Call `bridge.enable_compatibility_mode()` after booting for HookedTransformer-equivalent numerics. The legacy `HookedTransformer.from_pretrained` API is still available but deprecated — see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide. ## Key Tutorials diff --git a/docs/source/content/getting_started.md b/docs/source/content/getting_started.md index 181ae24ce..f2b9b1f0b 100644 --- a/docs/source/content/getting_started.md +++ b/docs/source/content/getting_started.md @@ -35,8 +35,36 @@ The bridge currently covers 50+ architectures spanning Llama, Mistral, Qwen, Gem The bridge is organized around a small set of generalized components wired together by an architecture adapter, which keeps the model code much easier to navigate than the older unified implementation. For a tour of the bridge's canonical hook names, the component layout, and the expected tensor shapes at each hook point, see the [Model Structure](model_structure.md) page. A small alias layer preserves the older TransformerLens hook names (e.g. `blocks.{i}.hook_resid_pre`) so legacy notebooks keep working — but new code should prefer the canonical names. -## Huggingface Gated Access +## Environment Variables + +TransformerLens reads a handful of environment variables. None are required for basic use; each enables a specific opt-in behavior. + +### `HF_TOKEN` + +Your [HuggingFace access token](https://huggingface.co/settings/tokens). Required for gated models (Llama, Mistral/Mixtral, Gemma families, and others) and used to authenticate any HuggingFace API call TransformerLens makes on your behalf. You will need to accept any model-specific agreements on the HuggingFace Hub before TransformerLens can load a gated model; if you skip this step, the error message will link you directly to the agreement page. + +```bash +export HF_TOKEN="hf_..." +``` + +### `TRANSFORMERLENS_HF_RETRY` + +Set to `"1"` to wrap `transformers.AutoConfig.from_pretrained`, `AutoModel.from_pretrained`, `AutoTokenizer.from_pretrained`, `AutoProcessor.from_pretrained`, and `AutoFeatureExtractor.from_pretrained` with a retry-on-429 helper. When HuggingFace returns HTTP 429 (rate-limited), the call is retried up to three times with exponential backoff, honoring the `Retry-After` response header when present. -Some of the models available in TransformerLens require gated access to be used. Luckily TransformerLens provides a way to access those models via the configuration of an environmental variable. Simply configure your [HuggingFace access token](https://huggingface.co/settings/tokens) as `HF_TOKEN` in your environment. +Intended primarily for CI environments where parallel workflow runs can trip HF's rate limits. Off by default so production callers see unmodified `transformers` behavior. The wrapping is idempotent and applied globally to the class methods; see [`enable_hf_retry`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/utilities/hf_utils.py) for the implementation. The TransformerLens test suite enables this automatically via `tests/conftest.py`. + +```bash +export TRANSFORMERLENS_HF_RETRY=1 +``` + +### `TRANSFORMERLENS_ALLOW_MPS` + +Set to `"1"` to opt in to Apple Silicon (MPS) as a target device for model inference. Off by default because not all PyTorch operations used by TransformerLens have stable MPS implementations across PyTorch versions; if you enable this and hit a backend error, the most reliable fallback is to leave the variable unset and let TransformerLens select CPU instead. + +```bash +export TRANSFORMERLENS_ALLOW_MPS=1 +``` + +## Huggingface Gated Access -You will need to make sure you accept the agreements for any gated models, but once you do, the models will work with TransformerLens without issue. If you attempt to use one of these models before you have accepted any related agreements, the console output will be very helpful and point you to the URL where you need to accept an agreement. The most popular gated families supported by TransformerLens are the Llama, Mistral/Mixtral, and Gemma models. +For convenience, gated-model access depends only on `HF_TOKEN` above. Once you have set the token and accepted any model-specific agreements on the HuggingFace Hub, gated models load through TransformerLens with no additional configuration. The most popular gated families supported by TransformerLens are the Llama, Mistral/Mixtral, and Gemma models. diff --git a/tests/conftest.py b/tests/conftest.py index 3229823af..6b2a76389 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,20 @@ def pytest_configure(config): torch.cuda.manual_seed_all(42) +@pytest.fixture(autouse=True, scope="session") +def _enable_hf_retry_for_tests(): + """Wrap HuggingFace Auto*.from_pretrained with retry-on-429 for the entire + test session. + + Deferred to fixture (rather than pytest_configure) so jaxtyping's import + hook can instrument transformer_lens before we import the helper. + """ + from transformer_lens.utilities.hf_utils import enable_hf_retry + + enable_hf_retry() + yield + + def pytest_sessionfinish(session, exitstatus): """Clean up at the end of test session.""" if torch.cuda.is_available(): diff --git a/tests/unit/utilities/test_hf_utils.py b/tests/unit/utilities/test_hf_utils.py new file mode 100644 index 000000000..7d506805f --- /dev/null +++ b/tests/unit/utilities/test_hf_utils.py @@ -0,0 +1,183 @@ +"""Tests for the HuggingFace Hub 429 retry helper.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, List + +import pytest + +from transformer_lens.utilities import hf_utils +from transformer_lens.utilities.hf_utils import call_hf_with_retry + + +class _FakeHTTPError(Exception): + """Stand-in for HfHubHTTPError / requests.HTTPError — exposes .response.status_code.""" + + def __init__(self, status_code: int, retry_after: str | None = None) -> None: + super().__init__(f"HTTP {status_code}") + headers: dict[str, str] = {} + if retry_after is not None: + headers["Retry-After"] = retry_after + self.response = SimpleNamespace(status_code=status_code, headers=headers) + + +@pytest.fixture(autouse=True) +def _no_sleep(monkeypatch: pytest.MonkeyPatch) -> List[float]: + """Capture sleep calls and don't actually sleep — keeps tests fast.""" + waits: List[float] = [] + monkeypatch.setattr(hf_utils.time, "sleep", lambda s: waits.append(s)) + return waits + + +@pytest.fixture +def _deterministic_random(monkeypatch: pytest.MonkeyPatch) -> None: + """Force random.random() == 0.5 so jitter factor (0.8 + 0.4*r) == 1.0 exactly. + + Lets backoff tests assert exact values instead of ranges, without coupling + to the specific jitter window. + """ + monkeypatch.setattr(hf_utils.random, "random", lambda: 0.5) + + +def _make_flaky(fail_times: int, exc_factory: Any) -> Any: + """Build a callable that raises `fail_times` then returns 'ok'.""" + state = {"calls": 0} + + def _inner(*args: Any, **kwargs: Any) -> str: + state["calls"] += 1 + if state["calls"] <= fail_times: + raise exc_factory() + return "ok" + + _inner.state = state # type: ignore[attr-defined] + return _inner + + +class TestCallHfWithRetry: + def test_returns_immediately_on_success(self) -> None: + func = _make_flaky(0, lambda: _FakeHTTPError(429)) + assert call_hf_with_retry(func) == "ok" + assert func.state["calls"] == 1 + + def test_retries_on_429_then_succeeds(self, _no_sleep: List[float]) -> None: + func = _make_flaky(2, lambda: _FakeHTTPError(429)) + assert call_hf_with_retry(func, max_attempts=3, base_delay=1.0) == "ok" + assert func.state["calls"] == 3 + assert len(_no_sleep) == 2 + + def test_raises_after_max_attempts(self, _no_sleep: List[float]) -> None: + func = _make_flaky(99, lambda: _FakeHTTPError(429)) + with pytest.raises(_FakeHTTPError): + call_hf_with_retry(func, max_attempts=3, base_delay=1.0) + assert func.state["calls"] == 3 + # Sleeps happen between attempts, not after the final one. + assert len(_no_sleep) == 2 + + def test_non_429_propagates_immediately(self, _no_sleep: List[float]) -> None: + func = _make_flaky(99, lambda: _FakeHTTPError(503)) + with pytest.raises(_FakeHTTPError): + call_hf_with_retry(func, max_attempts=3, base_delay=1.0) + assert func.state["calls"] == 1 + assert _no_sleep == [] + + def test_non_http_exception_propagates_immediately(self, _no_sleep: List[float]) -> None: + def boom() -> None: + raise ValueError("not a network error") + + with pytest.raises(ValueError): + call_hf_with_retry(boom, max_attempts=3, base_delay=1.0) + assert _no_sleep == [] + + def test_honors_retry_after_header(self, _no_sleep: List[float]) -> None: + func = _make_flaky(1, lambda: _FakeHTTPError(429, retry_after="7.5")) + assert call_hf_with_retry(func, max_attempts=3, base_delay=1.0) == "ok" + assert func.state["calls"] == 2 + assert _no_sleep == [7.5] + + def test_falls_back_to_backoff_when_retry_after_unparseable( + self, _no_sleep: List[float], _deterministic_random: None + ) -> None: + func = _make_flaky(1, lambda: _FakeHTTPError(429, retry_after="soon")) + call_hf_with_retry(func, max_attempts=3, base_delay=10.0) + # base_delay * 2**0 * jitter_factor(0.5) = 10 * 1 * 1.0 = 10.0 exactly + assert _no_sleep == [10.0] + + def test_exponential_backoff_grows( + self, _no_sleep: List[float], _deterministic_random: None + ) -> None: + func = _make_flaky(3, lambda: _FakeHTTPError(429)) + with pytest.raises(_FakeHTTPError): + call_hf_with_retry(func, max_attempts=3, base_delay=10.0) + # Two backoffs between three attempts; last attempt has no sleep. + # attempt 0: 10 * 2**0 * 1.0 = 10; attempt 1: 10 * 2**1 * 1.0 = 20. + assert _no_sleep == [10.0, 20.0] + + def test_backoff_capped_at_max_delay( + self, _no_sleep: List[float], _deterministic_random: None + ) -> None: + """A huge base_delay must be clamped by _HF_RETRY_MAX_DELAY_SECONDS.""" + func = _make_flaky(1, lambda: _FakeHTTPError(429)) + call_hf_with_retry(func, max_attempts=2, base_delay=10_000.0) + # Without cap: 10000 * 2**0 * 1.0 = 10000s. With 120s cap: exactly 120.0. + assert _no_sleep == [hf_utils._HF_RETRY_MAX_DELAY_SECONDS] + + +class TestEnableHfRetry: + """Verify the global Auto*.from_pretrained wrapper installed by enable_hf_retry.""" + + def test_session_fixture_wraps_autoconfig(self) -> None: + """tests/conftest.py:_enable_hf_retry_for_tests must have wrapped AutoConfig.""" + from transformers import AutoConfig + + assert getattr( + AutoConfig.from_pretrained, hf_utils._TL_RETRY_WRAPPED_ATTR, False + ), "enable_hf_retry was not applied to AutoConfig — check conftest fixture" + + def test_session_fixture_wraps_autotokenizer(self) -> None: + from transformers import AutoTokenizer + + assert getattr( + AutoTokenizer.from_pretrained, hf_utils._TL_RETRY_WRAPPED_ATTR, False + ) + + def test_idempotent(self) -> None: + """A second enable_hf_retry call must not re-wrap (or otherwise break) the classes.""" + from transformers import AutoConfig + + before = AutoConfig.from_pretrained.__func__ + hf_utils.enable_hf_retry() + after = AutoConfig.from_pretrained.__func__ + assert before is after + + +class TestDownloadFileFromHf: + """End-to-end coverage: download_file_from_hf must actually use the retry helper. + + Without this, a refactor that calls hf_hub_download directly again — exactly the + regression this change is meant to prevent — would slip past the unit tests above. + """ + + def test_retries_underlying_hf_hub_download_on_429( + self, + monkeypatch: pytest.MonkeyPatch, + _no_sleep: List[float], + tmp_path: Any, + ) -> None: + fake_file = tmp_path / "data.json" + fake_file.write_text('{"ok": true}') + state = {"calls": 0} + + def fake_hub_download(**kwargs: Any) -> str: + state["calls"] += 1 + if state["calls"] < 2: + raise _FakeHTTPError(429) + return str(fake_file) + + monkeypatch.setattr(hf_utils, "hf_hub_download", fake_hub_download) + + result = hf_utils.download_file_from_hf("any/repo", "data.json") + + assert result == {"ok": True} + assert state["calls"] == 2 + assert len(_no_sleep) == 1 diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index e889d52cb..118f0b60c 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -31,6 +31,17 @@ from .SVDInterpreter import SVDInterpreter +# Opt-in: wrap transformers Auto*.from_pretrained with retry-on-429. +# Set TRANSFORMERLENS_HF_RETRY=1 in environments that hit HuggingFace rate limits +# (typically CI). Off by default so normal users see unmodified HF behavior. +# See transformer_lens.utilities.hf_utils.enable_hf_retry for details. +import os as _os # noqa: E402 + +if _os.environ.get("TRANSFORMERLENS_HF_RETRY") == "1": + from .utilities.hf_utils import enable_hf_retry as _enable_hf_retry # noqa: E402 + + _enable_hf_retry() + __all__ = [ "HookedTransformerConfig", "FactoredMatrix", diff --git a/transformer_lens/utilities/hf_utils.py b/transformer_lens/utilities/hf_utils.py index 580255b1a..69cd9c063 100644 --- a/transformer_lens/utilities/hf_utils.py +++ b/transformer_lens/utilities/hf_utils.py @@ -8,10 +8,13 @@ import errno import inspect import json +import logging import os +import random import shutil import stat -from typing import Any, Callable, Dict +import time +from typing import Any, Callable, Dict, TypeVar import torch from datasets.arrow_dataset import Dataset @@ -21,6 +24,109 @@ from huggingface_hub.constants import HF_HUB_CACHE CACHE_DIR = HF_HUB_CACHE +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +_HF_RETRY_MAX_ATTEMPTS = 3 +_HF_RETRY_BASE_DELAY_SECONDS = 10.0 +_HF_RETRY_MAX_DELAY_SECONDS = 120.0 + + +def _is_hf_rate_limit_error(exc: BaseException) -> bool: + """Duck-typed check for HTTP 429 — covers HfHubHTTPError, requests.HTTPError, and subclasses.""" + response = getattr(exc, "response", None) + return response is not None and getattr(response, "status_code", None) == 429 + + +def _retry_after_seconds(exc: BaseException) -> float | None: + """Parse the Retry-After header from a 429 response, if present and numeric.""" + response = getattr(exc, "response", None) + if response is None: + return None + headers = getattr(response, "headers", None) or {} + raw = headers.get("Retry-After") if hasattr(headers, "get") else None + if raw is None: + return None + try: + return float(raw) + except (TypeError, ValueError): + return None + + +_TL_RETRY_WRAPPED_ATTR = "_tl_hf_retry_wrapped" + + +def enable_hf_retry() -> None: + """Globally wrap transformers ``Auto*.from_pretrained`` with retry-on-429. + + After calling this, every load through ``AutoConfig.from_pretrained``, + ``AutoModel.from_pretrained``, ``AutoTokenizer.from_pretrained``, + ``AutoProcessor.from_pretrained``, or ``AutoFeatureExtractor.from_pretrained`` + will go through :func:`call_hf_with_retry`, retrying on HTTP 429 with + exponential backoff (honoring the ``Retry-After`` header when present). + + Intended for CI / test environments that hit HF rate limits during parallel + workflow runs. Opt-in via the ``TRANSFORMERLENS_HF_RETRY=1`` environment + variable or by calling this function explicitly. Not enabled by default so + that production callers see unmodified ``transformers`` behavior. + + Idempotent: safe to call multiple times; subsequent calls are no-ops. + """ + from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoModel, + AutoProcessor, + AutoTokenizer, + ) + + for cls in (AutoConfig, AutoModel, AutoTokenizer, AutoProcessor, AutoFeatureExtractor): + original = cls.from_pretrained + if getattr(original, _TL_RETRY_WRAPPED_ATTR, False): + continue + underlying = original.__func__ if hasattr(original, "__func__") else original + + def _wrapped(klass, *args: Any, _orig: Any = underlying, **kwargs: Any) -> Any: + return call_hf_with_retry(_orig, klass, *args, **kwargs) + + setattr(_wrapped, _TL_RETRY_WRAPPED_ATTR, True) + cls.from_pretrained = classmethod(_wrapped) + + +def call_hf_with_retry( + func: Callable[..., T], + *args: Any, + max_attempts: int = _HF_RETRY_MAX_ATTEMPTS, + base_delay: float = _HF_RETRY_BASE_DELAY_SECONDS, + **kwargs: Any, +) -> T: + """Call ``func(*args, **kwargs)``, retrying on HTTP 429 from HuggingFace Hub. + + Behavior: + - Honors the ``Retry-After`` response header when present. + - Otherwise uses exponential backoff with ±20% jitter, capped at + ``_HF_RETRY_MAX_DELAY_SECONDS``. + - Only retries on 429; all other exceptions propagate immediately. + """ + for attempt in range(max_attempts): + try: + return func(*args, **kwargs) + except Exception as exc: + if not _is_hf_rate_limit_error(exc) or attempt == max_attempts - 1: + raise + wait = _retry_after_seconds(exc) + if wait is None: + wait = min(base_delay * (2**attempt), _HF_RETRY_MAX_DELAY_SECONDS) + wait *= 0.8 + 0.4 * random.random() + logger.warning( + "HuggingFace Hub rate-limited (HTTP 429); retrying in %.1fs (attempt %d/%d)", + wait, + attempt + 1, + max_attempts, + ) + time.sleep(wait) + raise RuntimeError("call_hf_with_retry exited loop without returning or raising") def get_hf_token() -> str | None: @@ -75,7 +181,8 @@ def download_file_from_hf( If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object. """ - file_path = hf_hub_download( + file_path = call_hf_with_retry( + hf_hub_download, repo_id=repo_name, filename=file_name, subfolder=subfolder, From e95a4309b7ce4c3bfe88453e85587a1be0a7c554 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Thu, 21 May 2026 11:28:05 -0500 Subject: [PATCH 2/4] reduce individual huggingface calls in favor of using fixtures or mocks --- tests/acceptance/conftest.py | 34 +++++++++++++ tests/acceptance/model_bridge/conftest.py | 19 +++++-- tests/acceptance/test_hooked_transformer.py | 25 +++++----- tests/conftest.py | 32 ++++++++++++ tests/integration/model_bridge/conftest.py | 29 +++++++++-- .../test_bridge_creation_modes.py | 50 +++++++++---------- tests/unit/model_bridge/test_key_analysis.py | 4 +- .../test_weight_processing_adapter_paths.py | 4 +- tests/unit/test_n_params_total.py | 4 +- tests/unit/test_utils.py | 11 ++-- 10 files changed, 158 insertions(+), 54 deletions(-) diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py index 50ad394a6..a56bfe356 100644 --- a/tests/acceptance/conftest.py +++ b/tests/acceptance/conftest.py @@ -2,6 +2,10 @@ Session-scoped fixtures avoid redundant model loads across test files. All models used here must be in the CI cache (see .github/workflows/checks.yml). + +NB: imports of ``transformer_lens`` are deferred into fixture bodies so that +jaxtyping's pytest_configure import hook can install before the package is +first imported. """ import pytest @@ -13,3 +17,33 @@ def gpt2_model(): from transformer_lens import HookedTransformer return HookedTransformer.from_pretrained("gpt2", device="cpu") + + +@pytest.fixture(scope="session") +def bloom_560m_hooked(): + """Session-scoped HookedTransformer for bigscience/bloom-560m. + + Loaded with ``default_prepend_bos=False`` to match what the bloom-similarity + tests expect. Bloom-560m is ~1.2 GB so sharing is meaningful. + """ + from transformer_lens import HookedTransformer + + return HookedTransformer.from_pretrained( + "bigscience/bloom-560m", default_prepend_bos=False, device="cpu" + ) + + +@pytest.fixture(scope="session") +def bloom_560m_hf_model(): + """Session-scoped raw HuggingFace bloom-560m model (for parity checks).""" + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") + + +@pytest.fixture(scope="session") +def bloom_560m_hf_tokenizer(): + """Session-scoped bloom-560m tokenizer.""" + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained("bigscience/bloom-560m") diff --git a/tests/acceptance/model_bridge/conftest.py b/tests/acceptance/model_bridge/conftest.py index 7012b890d..e99e26e0b 100644 --- a/tests/acceptance/model_bridge/conftest.py +++ b/tests/acceptance/model_bridge/conftest.py @@ -2,23 +2,30 @@ Session-scoped fixtures avoid redundant model loads across test files. All models used here must be in the CI cache (see .github/workflows/checks.yml). + +NB: imports of ``transformer_lens`` are deferred into fixture bodies so that +jaxtyping's pytest_configure import hook can install before the package is +first imported. Module-level imports here break running these tests in +isolation (RuntimeError: "jaxtyping cannot check these packages because they +are already imported"). """ import pytest -from transformer_lens import HookedTransformer -from transformer_lens.model_bridge import TransformerBridge - @pytest.fixture(scope="session") def gpt2_bridge(): """TransformerBridge wrapping gpt2 (no compatibility mode).""" + from transformer_lens.model_bridge import TransformerBridge + return TransformerBridge.boot_transformers("gpt2", device="cpu") @pytest.fixture(scope="session") def gpt2_bridge_compat(): """TransformerBridge wrapping gpt2 with compatibility mode enabled.""" + from transformer_lens.model_bridge import TransformerBridge + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") bridge.enable_compatibility_mode() return bridge @@ -27,6 +34,8 @@ def gpt2_bridge_compat(): @pytest.fixture(scope="session") def gpt2_bridge_compat_no_processing(): """TransformerBridge wrapping gpt2 with compatibility mode but no weight processing.""" + from transformer_lens.model_bridge import TransformerBridge + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") bridge.enable_compatibility_mode(no_processing=True) return bridge @@ -35,10 +44,14 @@ def gpt2_bridge_compat_no_processing(): @pytest.fixture(scope="session") def gpt2_hooked_processed(): """HookedTransformer gpt2 with default weight processing.""" + from transformer_lens import HookedTransformer + return HookedTransformer.from_pretrained("gpt2", device="cpu") @pytest.fixture(scope="session") def gpt2_hooked_unprocessed(): """HookedTransformer gpt2 without weight processing.""" + from transformer_lens import HookedTransformer + return HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu") diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index e830c5bd7..8020c1923 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -216,12 +216,12 @@ def test_from_pretrained_revision(): raise AssertionError("Should have raised an error") -def test_bloom_similarity_with_hf_model_with_kv_cache_activated(): - tf_model = HookedTransformer.from_pretrained( - "bigscience/bloom-560m", default_prepend_bos=False, device="cpu" - ) - hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") - hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") +def test_bloom_similarity_with_hf_model_with_kv_cache_activated( + bloom_560m_hooked, bloom_560m_hf_model, bloom_560m_hf_tokenizer +): + tf_model = bloom_560m_hooked + hf_model = bloom_560m_hf_model + hf_tokenizer = bloom_560m_hf_tokenizer output_tf = tf_model.generate( text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10 @@ -236,13 +236,12 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated(): assert output_tf == output_hf_str -def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream(): - tf_model = HookedTransformer.from_pretrained( - "bigscience/bloom-560m", default_prepend_bos=False, device="cpu" - ) - - hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") - hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") +def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream( + bloom_560m_hooked, bloom_560m_hf_model, bloom_560m_hf_tokenizer +): + tf_model = bloom_560m_hooked + hf_model = bloom_560m_hf_model + hf_tokenizer = bloom_560m_hf_tokenizer final_output = "" for result in tf_model.generate_stream( diff --git a/tests/conftest.py b/tests/conftest.py index 6b2a76389..4e67489fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,6 +64,38 @@ def _enable_hf_retry_for_tests(): yield +@pytest.fixture(scope="session") +def gpt2_tokenizer(): + """Session-scoped GPT-2 tokenizer (no add_bos_token). + + Shared across the unit, integration, and acceptance test trees to avoid + repeated AutoTokenizer.from_pretrained("gpt2") calls — each one triggers + a HuggingFace Hub freshness check even when the tokenizer is cached. + Tokenizers are immutable for read-only use, so session scope is safe. + """ + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained("gpt2") + + +@pytest.fixture(scope="session") +def gpt2_hooked_processed(): + """Session-scoped HookedTransformer gpt2 with default weight processing. + + Top-level fixture for unit tests and any other consumer without a closer + fixture in scope. Sub-conftests in tests/acceptance/model_bridge/ and + tests/integration/model_bridge/ define their own same-named fixture, + which shadows this one within those subtrees. + + Safe for read-only use: ``.parameters()``, ``.state_dict()``, + ``.to_tokens()``, ``.cfg``. Do NOT mutate (no ``.process_weights_()``, + no permanent hooks, no ``.train()``/``.eval()`` that you don't restore). + """ + from transformer_lens import HookedTransformer + + return HookedTransformer.from_pretrained("gpt2", device="cpu") + + def pytest_sessionfinish(session, exitstatus): """Clean up at the end of test session.""" if torch.cuda.is_available(): diff --git a/tests/integration/model_bridge/conftest.py b/tests/integration/model_bridge/conftest.py index 56a9a1945..8d9bcfe7e 100644 --- a/tests/integration/model_bridge/conftest.py +++ b/tests/integration/model_bridge/conftest.py @@ -2,23 +2,30 @@ Session-scoped fixtures avoid redundant model loads across test files. All models used here must be in the CI cache (see .github/workflows/checks.yml). + +NB: imports of ``transformer_lens`` are deferred into fixture bodies so that +jaxtyping's pytest_configure import hook can install before the package is +first imported. Module-level imports here break running these tests in +isolation (RuntimeError: "jaxtyping cannot check these packages because they +are already imported"). """ import pytest -from transformer_lens import HookedTransformer -from transformer_lens.model_bridge.bridge import TransformerBridge - @pytest.fixture(scope="session") def distilgpt2_bridge(): """TransformerBridge wrapping distilgpt2 (no compatibility mode).""" + from transformer_lens.model_bridge.bridge import TransformerBridge + return TransformerBridge.boot_transformers("distilgpt2", device="cpu") @pytest.fixture(scope="session") def distilgpt2_bridge_compat(): """TransformerBridge wrapping distilgpt2 with compatibility mode enabled.""" + from transformer_lens.model_bridge.bridge import TransformerBridge + bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") bridge.enable_compatibility_mode() return bridge @@ -27,12 +34,16 @@ def distilgpt2_bridge_compat(): @pytest.fixture(scope="session") def gpt2_bridge(): """TransformerBridge wrapping gpt2 (no compatibility mode).""" + from transformer_lens.model_bridge.bridge import TransformerBridge + return TransformerBridge.boot_transformers("gpt2", device="cpu") @pytest.fixture(scope="session") def gpt2_bridge_compat(): """TransformerBridge wrapping gpt2 with compatibility mode enabled.""" + from transformer_lens.model_bridge.bridge import TransformerBridge + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") bridge.enable_compatibility_mode() return bridge @@ -41,30 +52,40 @@ def gpt2_bridge_compat(): @pytest.fixture(scope="session") def gpt2_hooked_processed(): """HookedTransformer gpt2 with default weight processing.""" + from transformer_lens import HookedTransformer + return HookedTransformer.from_pretrained("gpt2", device="cpu") @pytest.fixture(scope="session") def gpt2_hooked_unprocessed(): """HookedTransformer gpt2 without weight processing.""" + from transformer_lens import HookedTransformer + return HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu") @pytest.fixture(scope="session") def distilgpt2_hooked_processed(): """HookedTransformer distilgpt2 with default weight processing.""" + from transformer_lens import HookedTransformer + return HookedTransformer.from_pretrained("distilgpt2", device="cpu") @pytest.fixture(scope="session") def distilgpt2_hooked_unprocessed(): """HookedTransformer distilgpt2 without weight processing.""" + from transformer_lens import HookedTransformer + return HookedTransformer.from_pretrained_no_processing("distilgpt2", device="cpu") @pytest.fixture(scope="session") def gpt2_bridge_compat_no_processing(): """TransformerBridge wrapping gpt2 with compat mode, no weight processing.""" + from transformer_lens.model_bridge.bridge import TransformerBridge + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") bridge.enable_compatibility_mode(no_processing=True) return bridge @@ -73,6 +94,8 @@ def gpt2_bridge_compat_no_processing(): @pytest.fixture(scope="session") def distilgpt2_bridge_compat_no_processing(): """TransformerBridge wrapping distilgpt2 with compat mode, no weight processing.""" + from transformer_lens.model_bridge.bridge import TransformerBridge + bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") bridge.enable_compatibility_mode(no_processing=True, disable_warnings=True) return bridge diff --git a/tests/integration/model_bridge/test_bridge_creation_modes.py b/tests/integration/model_bridge/test_bridge_creation_modes.py index 45d5e86d4..fbf0b969a 100644 --- a/tests/integration/model_bridge/test_bridge_creation_modes.py +++ b/tests/integration/model_bridge/test_bridge_creation_modes.py @@ -3,29 +3,24 @@ import pytest import torch -from transformer_lens import HookedTransformer from transformer_lens.model_bridge.bridge import TransformerBridge class TestBridgeCreationModes: """Test different modes of creating and configuring TransformerBridge.""" - @pytest.fixture - def reference_model(self): - """Create reference HookedTransformer.""" - return HookedTransformer.from_pretrained("distilgpt2", device="cpu") - @pytest.fixture def test_text(self): """Test text for evaluation.""" return "Hello world" - def test_bridge_no_processing(self, reference_model, test_text): + def test_bridge_no_processing( + self, distilgpt2_bridge_compat_no_processing, distilgpt2_hooked_processed, test_text + ): """Test bridge with no weight processing.""" - bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") - bridge.enable_compatibility_mode(no_processing=True) + bridge = distilgpt2_bridge_compat_no_processing - ref_loss = reference_model(test_text, return_type="loss") + ref_loss = distilgpt2_hooked_processed(test_text, return_type="loss") bridge_loss = bridge(test_text, return_type="loss") # With no processing, losses should be close but not identical @@ -34,12 +29,13 @@ def test_bridge_no_processing(self, reference_model, test_text): ), f"Losses should be reasonably close: {ref_loss} vs {bridge_loss}" assert 3.0 < bridge_loss < 8.0, f"Bridge loss should be reasonable: {bridge_loss}" - def test_bridge_full_compatibility(self, reference_model, test_text): + def test_bridge_full_compatibility( + self, distilgpt2_bridge_compat, distilgpt2_hooked_processed, test_text + ): """Test bridge with full compatibility mode processing.""" - bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") - bridge.enable_compatibility_mode() + bridge = distilgpt2_bridge_compat - ref_loss = reference_model(test_text, return_type="loss") + ref_loss = distilgpt2_hooked_processed(test_text, return_type="loss") bridge_loss = bridge(test_text, return_type="loss") # With full processing, losses should be very close @@ -47,9 +43,9 @@ def test_bridge_full_compatibility(self, reference_model, test_text): assert diff < 0.01, f"Processed bridge should match reference closely: {diff}" assert 3.0 < bridge_loss < 8.0, f"Bridge loss should be reasonable: {bridge_loss}" - def test_bridge_component_inspection(self): + def test_bridge_component_inspection(self, distilgpt2_bridge): """Test that bridge components can be inspected.""" - bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") + bridge = distilgpt2_bridge # Check that we can access the original model components assert hasattr(bridge.original_model, "transformer"), "Should have transformer" @@ -68,20 +64,24 @@ def test_bridge_component_inspection(self): assert hasattr(bridge.original_model.transformer, "wpe"), "Should have position embedding" assert hasattr(bridge.original_model, "lm_head"), "Should have language model head" - def test_bridge_tokenizer_compatibility(self, reference_model): + def test_bridge_tokenizer_compatibility(self, distilgpt2_bridge, distilgpt2_hooked_processed): """Test that bridge tokenizer works like reference.""" - bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") test_text = "Hello world test" # Tokenize with both - ref_tokens = reference_model.to_tokens(test_text) - bridge_tokens = bridge.to_tokens(test_text) + ref_tokens = distilgpt2_hooked_processed.to_tokens(test_text) + bridge_tokens = distilgpt2_bridge.to_tokens(test_text) # Should produce identical tokens assert torch.equal(ref_tokens, bridge_tokens), "Tokenizers should produce identical results" def test_bridge_configuration_persistence(self): - """Test that bridge configuration persists correctly.""" + """Test that bridge configuration persists across the boot → enable_compat transition. + + Intentional fresh boot: the assertion is that enable_compatibility_mode() does + not destroy ``cfg`` mid-flight. A session-scoped fixture would already have + compat enabled and lose the before/after semantic. + """ bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") # Test configuration before compatibility mode @@ -94,17 +94,15 @@ def test_bridge_configuration_persistence(self): assert hasattr(bridge, "cfg"), "Configuration should persist after compatibility mode" assert bridge.cfg is not None, "Configuration should not be None" - def test_bridge_device_handling(self): + def test_bridge_device_handling(self, gpt2_bridge): """Test that bridge handles device specification correctly.""" - # Test CPU device - bridge_cpu = TransformerBridge.boot_transformers("gpt2", device="cpu") assert ( - next(bridge_cpu.original_model.parameters()).device.type == "cpu" + next(gpt2_bridge.original_model.parameters()).device.type == "cpu" ), "Model should be on CPU device" # Test that bridge can process text on correct device test_text = "Device test" - loss = bridge_cpu(test_text, return_type="loss") + loss = gpt2_bridge(test_text, return_type="loss") assert isinstance(loss, torch.Tensor), "Should return tensor" assert loss.device.type == "cpu", "Loss should be on CPU" diff --git a/tests/unit/model_bridge/test_key_analysis.py b/tests/unit/model_bridge/test_key_analysis.py index 6d9a36db3..e2353cdfb 100644 --- a/tests/unit/model_bridge/test_key_analysis.py +++ b/tests/unit/model_bridge/test_key_analysis.py @@ -137,12 +137,12 @@ def copy(self): return self.state_dict.copy() -def test_key_analysis(): +def test_key_analysis(gpt2_hooked_processed): """Analyze what keys ProcessWeights tries to access.""" print("=== ANALYZING PROCESSWEIGHTS KEY ACCESS ===") print("\n1. Loading models...") - hooked_model = HookedTransformer.from_pretrained("gpt2", device="cpu") + hooked_model = gpt2_hooked_processed hf_model = GPT2LMHeadModel.from_pretrained("gpt2") print("\n2. Getting state dicts...") diff --git a/tests/unit/model_bridge/test_weight_processing_adapter_paths.py b/tests/unit/model_bridge/test_weight_processing_adapter_paths.py index 34cd5bce4..28c64fa60 100644 --- a/tests/unit/model_bridge/test_weight_processing_adapter_paths.py +++ b/tests/unit/model_bridge/test_weight_processing_adapter_paths.py @@ -16,7 +16,7 @@ @pytest.mark.filterwarnings("ignore::pytest.PytestReturnNotNoneWarning") -def test_processweights_with_adapter(): +def test_processweights_with_adapter(gpt2_hooked_processed): """Test ProcessWeights with architecture adapter for path translation.""" print("=== TESTING PROCESSWEIGHTS WITH ARCHITECTURE ADAPTER ===") @@ -24,7 +24,7 @@ def test_processweights_with_adapter(): gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets." print("\n1. Loading reference HookedTransformer...") - hooked_processed = HookedTransformer.from_pretrained("gpt2", device="cpu") + hooked_processed = gpt2_hooked_processed tokens = hooked_processed.to_tokens(gpt2_text) print("\n2. Loading raw HuggingFace model...") diff --git a/tests/unit/test_n_params_total.py b/tests/unit/test_n_params_total.py index 14e5ced6a..ed76c2630 100644 --- a/tests/unit/test_n_params_total.py +++ b/tests/unit/test_n_params_total.py @@ -96,7 +96,7 @@ def test_n_params_total_returns_int(): assert isinstance(model.n_params_total, int) -def test_n_params_total_real_model_gpt2(): +def test_n_params_total_real_model_gpt2(gpt2_hooked_processed): """End-to-end sanity check on a real loaded model (GPT-2, cached by CI). Note: TL's GPT-2 reports more parameters than HuggingFace's because HF ties @@ -109,7 +109,7 @@ def test_n_params_total_real_model_gpt2(): iterating ``model.parameters()`` on the loaded model — i.e. the property correctly reflects what's actually stored. """ - tl = HookedTransformer.from_pretrained("gpt2", device="cpu") + tl = gpt2_hooked_processed expected = sum(p.numel() for p in tl.parameters()) assert tl.n_params_total == expected # Sanity: GPT-2 is ~124M-163M params depending on tying; ours falls in this band. diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 78f7db202..523f43373 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -538,7 +538,7 @@ def test_single_document_batch_does_not_crash(self): n = len(output) assert (output == clean[:n]).all() - def test_iterable_dataset_with_set_format_false(self): + def test_iterable_dataset_with_set_format_false(self, gpt2_tokenizer): """``IterableDataset`` input + ``set_format=False`` returns a usable iterable. Regression test for the path requested in #473: ``set_format(type="torch")`` @@ -547,9 +547,8 @@ def test_iterable_dataset_with_set_format_false(self): """ from datasets import Dataset from datasets.iterable_dataset import IterableDataset - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = gpt2_tokenizer source = Dataset.from_dict({"text": ["hello world"] * 4}) iterable = source.to_iterable_dataset() assert isinstance(iterable, IterableDataset) @@ -612,6 +611,12 @@ def test_tokenize_and_concatenate_no_spurious_sequence_length_warning(): def test_tokenize_and_concatenate_short_sequence_no_invalid_tokens(): """ When the tokenizer has no pad token, output should only contain token IDs in the model's vocab. + + Loads a fresh tokenizer rather than using the session-scoped ``gpt2_tokenizer`` + fixture: this test asserts ``pad_token is None`` at start, but + ``utils.tokenize_and_concatenate(..., add_bos_token=True)`` mutates the + tokenizer's pad_token in-place, so a shared session fixture would carry + state from earlier tests in this file. """ from datasets import Dataset from transformers import AutoTokenizer From b22165f532c2e9fb154db83b37dcb21232fc8389 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Thu, 21 May 2026 11:34:33 -0500 Subject: [PATCH 3/4] Comment cleanup --- .github/workflows/checks.yml | 8 ++---- tests/acceptance/conftest.py | 17 +++--------- tests/acceptance/model_bridge/conftest.py | 12 +++------ tests/conftest.py | 26 ++----------------- tests/integration/model_bridge/conftest.py | 12 +++------ .../test_bridge_creation_modes.py | 7 +---- tests/unit/test_utils.py | 12 +++------ transformer_lens/__init__.py | 4 --- transformer_lens/utilities/hf_utils.py | 23 ++++------------ 9 files changed, 22 insertions(+), 99 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 0749bba3f..122cc9a6a 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -43,16 +43,12 @@ permissions: actions: write contents: write -# Cancel in-progress runs on the same PR when a new push arrives. -# Push-to-main, tag, and workflow_call events are not cancelled so that -# release and deploy jobs always run to completion. +# Cancel in-progress PR runs on new push; non-PR events (release, tags) are exempt. concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ github.event_name == 'pull_request' }} -# Enable retry-on-429 for any code path that loads from HuggingFace. -# Pytest also enables retry via tests/conftest.py; this env var covers any -# non-pytest invocations (scripts, notebooks, docs builds, etc.). +# Retry HF 429s in non-pytest invocations; pytest enables via tests/conftest.py. env: TRANSFORMERLENS_HF_RETRY: "1" diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py index a56bfe356..2c40a96f7 100644 --- a/tests/acceptance/conftest.py +++ b/tests/acceptance/conftest.py @@ -1,11 +1,7 @@ -"""Shared fixtures for acceptance tests. +"""Session fixtures for acceptance tests. -Session-scoped fixtures avoid redundant model loads across test files. -All models used here must be in the CI cache (see .github/workflows/checks.yml). - -NB: imports of ``transformer_lens`` are deferred into fixture bodies so that -jaxtyping's pytest_configure import hook can install before the package is -first imported. +transformer_lens imports stay inside fixture bodies — jaxtyping's pytest_configure +hook must install before the package is first imported. """ import pytest @@ -21,11 +17,6 @@ def gpt2_model(): @pytest.fixture(scope="session") def bloom_560m_hooked(): - """Session-scoped HookedTransformer for bigscience/bloom-560m. - - Loaded with ``default_prepend_bos=False`` to match what the bloom-similarity - tests expect. Bloom-560m is ~1.2 GB so sharing is meaningful. - """ from transformer_lens import HookedTransformer return HookedTransformer.from_pretrained( @@ -35,7 +26,6 @@ def bloom_560m_hooked(): @pytest.fixture(scope="session") def bloom_560m_hf_model(): - """Session-scoped raw HuggingFace bloom-560m model (for parity checks).""" from transformers import AutoModelForCausalLM return AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") @@ -43,7 +33,6 @@ def bloom_560m_hf_model(): @pytest.fixture(scope="session") def bloom_560m_hf_tokenizer(): - """Session-scoped bloom-560m tokenizer.""" from transformers import AutoTokenizer return AutoTokenizer.from_pretrained("bigscience/bloom-560m") diff --git a/tests/acceptance/model_bridge/conftest.py b/tests/acceptance/model_bridge/conftest.py index e99e26e0b..870b38a7a 100644 --- a/tests/acceptance/model_bridge/conftest.py +++ b/tests/acceptance/model_bridge/conftest.py @@ -1,13 +1,7 @@ -"""Shared fixtures for model_bridge acceptance tests. +"""Session fixtures for model_bridge acceptance tests. -Session-scoped fixtures avoid redundant model loads across test files. -All models used here must be in the CI cache (see .github/workflows/checks.yml). - -NB: imports of ``transformer_lens`` are deferred into fixture bodies so that -jaxtyping's pytest_configure import hook can install before the package is -first imported. Module-level imports here break running these tests in -isolation (RuntimeError: "jaxtyping cannot check these packages because they -are already imported"). +transformer_lens imports stay inside fixture bodies — jaxtyping's pytest_configure +hook must install before the package is first imported. """ import pytest diff --git a/tests/conftest.py b/tests/conftest.py index 4e67489fa..36800eed3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,12 +52,7 @@ def pytest_configure(config): @pytest.fixture(autouse=True, scope="session") def _enable_hf_retry_for_tests(): - """Wrap HuggingFace Auto*.from_pretrained with retry-on-429 for the entire - test session. - - Deferred to fixture (rather than pytest_configure) so jaxtyping's import - hook can instrument transformer_lens before we import the helper. - """ + """Deferred to fixture (not pytest_configure) so jaxtyping installs first.""" from transformer_lens.utilities.hf_utils import enable_hf_retry enable_hf_retry() @@ -66,13 +61,6 @@ def _enable_hf_retry_for_tests(): @pytest.fixture(scope="session") def gpt2_tokenizer(): - """Session-scoped GPT-2 tokenizer (no add_bos_token). - - Shared across the unit, integration, and acceptance test trees to avoid - repeated AutoTokenizer.from_pretrained("gpt2") calls — each one triggers - a HuggingFace Hub freshness check even when the tokenizer is cached. - Tokenizers are immutable for read-only use, so session scope is safe. - """ from transformers import AutoTokenizer return AutoTokenizer.from_pretrained("gpt2") @@ -80,17 +68,7 @@ def gpt2_tokenizer(): @pytest.fixture(scope="session") def gpt2_hooked_processed(): - """Session-scoped HookedTransformer gpt2 with default weight processing. - - Top-level fixture for unit tests and any other consumer without a closer - fixture in scope. Sub-conftests in tests/acceptance/model_bridge/ and - tests/integration/model_bridge/ define their own same-named fixture, - which shadows this one within those subtrees. - - Safe for read-only use: ``.parameters()``, ``.state_dict()``, - ``.to_tokens()``, ``.cfg``. Do NOT mutate (no ``.process_weights_()``, - no permanent hooks, no ``.train()``/``.eval()`` that you don't restore). - """ + """Read-only use only — mutations leak across the session.""" from transformer_lens import HookedTransformer return HookedTransformer.from_pretrained("gpt2", device="cpu") diff --git a/tests/integration/model_bridge/conftest.py b/tests/integration/model_bridge/conftest.py index 8d9bcfe7e..947a1e7cc 100644 --- a/tests/integration/model_bridge/conftest.py +++ b/tests/integration/model_bridge/conftest.py @@ -1,13 +1,7 @@ -"""Shared fixtures for model_bridge integration tests. +"""Session fixtures for model_bridge integration tests. -Session-scoped fixtures avoid redundant model loads across test files. -All models used here must be in the CI cache (see .github/workflows/checks.yml). - -NB: imports of ``transformer_lens`` are deferred into fixture bodies so that -jaxtyping's pytest_configure import hook can install before the package is -first imported. Module-level imports here break running these tests in -isolation (RuntimeError: "jaxtyping cannot check these packages because they -are already imported"). +transformer_lens imports stay inside fixture bodies — jaxtyping's pytest_configure +hook must install before the package is first imported. """ import pytest diff --git a/tests/integration/model_bridge/test_bridge_creation_modes.py b/tests/integration/model_bridge/test_bridge_creation_modes.py index fbf0b969a..9ec45f5e7 100644 --- a/tests/integration/model_bridge/test_bridge_creation_modes.py +++ b/tests/integration/model_bridge/test_bridge_creation_modes.py @@ -76,12 +76,7 @@ def test_bridge_tokenizer_compatibility(self, distilgpt2_bridge, distilgpt2_hook assert torch.equal(ref_tokens, bridge_tokens), "Tokenizers should produce identical results" def test_bridge_configuration_persistence(self): - """Test that bridge configuration persists across the boot → enable_compat transition. - - Intentional fresh boot: the assertion is that enable_compatibility_mode() does - not destroy ``cfg`` mid-flight. A session-scoped fixture would already have - compat enabled and lose the before/after semantic. - """ + # Fresh boot: tests the boot → enable_compat transition. bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") # Test configuration before compatibility mode diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 523f43373..86e685c69 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -609,15 +609,9 @@ def test_tokenize_and_concatenate_no_spurious_sequence_length_warning(): def test_tokenize_and_concatenate_short_sequence_no_invalid_tokens(): - """ - When the tokenizer has no pad token, output should only contain token IDs in the model's vocab. - - Loads a fresh tokenizer rather than using the session-scoped ``gpt2_tokenizer`` - fixture: this test asserts ``pad_token is None`` at start, but - ``utils.tokenize_and_concatenate(..., add_bos_token=True)`` mutates the - tokenizer's pad_token in-place, so a shared session fixture would carry - state from earlier tests in this file. - """ + """When the tokenizer has no pad token, output should only contain token IDs in the model's vocab.""" + # Fresh tokenizer (not gpt2_tokenizer fixture): the function mutates pad_token + # and we assert pad_token is None at the start. from datasets import Dataset from transformers import AutoTokenizer diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 118f0b60c..ccb3f2ebb 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -31,10 +31,6 @@ from .SVDInterpreter import SVDInterpreter -# Opt-in: wrap transformers Auto*.from_pretrained with retry-on-429. -# Set TRANSFORMERLENS_HF_RETRY=1 in environments that hit HuggingFace rate limits -# (typically CI). Off by default so normal users see unmodified HF behavior. -# See transformer_lens.utilities.hf_utils.enable_hf_retry for details. import os as _os # noqa: E402 if _os.environ.get("TRANSFORMERLENS_HF_RETRY") == "1": diff --git a/transformer_lens/utilities/hf_utils.py b/transformer_lens/utilities/hf_utils.py index 69cd9c063..be515cb72 100644 --- a/transformer_lens/utilities/hf_utils.py +++ b/transformer_lens/utilities/hf_utils.py @@ -60,18 +60,8 @@ def _retry_after_seconds(exc: BaseException) -> float | None: def enable_hf_retry() -> None: """Globally wrap transformers ``Auto*.from_pretrained`` with retry-on-429. - After calling this, every load through ``AutoConfig.from_pretrained``, - ``AutoModel.from_pretrained``, ``AutoTokenizer.from_pretrained``, - ``AutoProcessor.from_pretrained``, or ``AutoFeatureExtractor.from_pretrained`` - will go through :func:`call_hf_with_retry`, retrying on HTTP 429 with - exponential backoff (honoring the ``Retry-After`` header when present). - - Intended for CI / test environments that hit HF rate limits during parallel - workflow runs. Opt-in via the ``TRANSFORMERLENS_HF_RETRY=1`` environment - variable or by calling this function explicitly. Not enabled by default so - that production callers see unmodified ``transformers`` behavior. - - Idempotent: safe to call multiple times; subsequent calls are no-ops. + Opt-in via ``TRANSFORMERLENS_HF_RETRY=1`` or by calling this function. + Idempotent. See :func:`call_hf_with_retry`. """ from transformers import ( AutoConfig, @@ -101,13 +91,10 @@ def call_hf_with_retry( base_delay: float = _HF_RETRY_BASE_DELAY_SECONDS, **kwargs: Any, ) -> T: - """Call ``func(*args, **kwargs)``, retrying on HTTP 429 from HuggingFace Hub. + """Retry ``func(*args, **kwargs)`` on HTTP 429, honoring ``Retry-After``. - Behavior: - - Honors the ``Retry-After`` response header when present. - - Otherwise uses exponential backoff with ±20% jitter, capped at - ``_HF_RETRY_MAX_DELAY_SECONDS``. - - Only retries on 429; all other exceptions propagate immediately. + Exponential backoff with ±20% jitter, capped at ``_HF_RETRY_MAX_DELAY_SECONDS``. + Non-429 exceptions propagate immediately. """ for attempt in range(max_attempts): try: From 4f6c89258fc8e248257e1e6baa62e5a04fe47c3f Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Thu, 21 May 2026 11:43:07 -0500 Subject: [PATCH 4/4] Format check cleanup --- tests/unit/model_bridge/test_key_analysis.py | 1 - .../unit/model_bridge/test_weight_processing_adapter_paths.py | 1 - tests/unit/utilities/test_hf_utils.py | 4 +--- transformer_lens/utilities/hf_utils.py | 2 +- 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/unit/model_bridge/test_key_analysis.py b/tests/unit/model_bridge/test_key_analysis.py index e2353cdfb..37ac1bacb 100644 --- a/tests/unit/model_bridge/test_key_analysis.py +++ b/tests/unit/model_bridge/test_key_analysis.py @@ -6,7 +6,6 @@ from transformers import GPT2LMHeadModel -from transformer_lens import HookedTransformer from transformer_lens.weight_processing import ProcessWeights diff --git a/tests/unit/model_bridge/test_weight_processing_adapter_paths.py b/tests/unit/model_bridge/test_weight_processing_adapter_paths.py index 28c64fa60..b0d72fc50 100644 --- a/tests/unit/model_bridge/test_weight_processing_adapter_paths.py +++ b/tests/unit/model_bridge/test_weight_processing_adapter_paths.py @@ -8,7 +8,6 @@ import torch from transformers import GPT2LMHeadModel -from transformer_lens import HookedTransformer from transformer_lens import utilities as utils from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter diff --git a/tests/unit/utilities/test_hf_utils.py b/tests/unit/utilities/test_hf_utils.py index 7d506805f..157af24fc 100644 --- a/tests/unit/utilities/test_hf_utils.py +++ b/tests/unit/utilities/test_hf_utils.py @@ -137,9 +137,7 @@ def test_session_fixture_wraps_autoconfig(self) -> None: def test_session_fixture_wraps_autotokenizer(self) -> None: from transformers import AutoTokenizer - assert getattr( - AutoTokenizer.from_pretrained, hf_utils._TL_RETRY_WRAPPED_ATTR, False - ) + assert getattr(AutoTokenizer.from_pretrained, hf_utils._TL_RETRY_WRAPPED_ATTR, False) def test_idempotent(self) -> None: """A second enable_hf_retry call must not re-wrap (or otherwise break) the classes.""" diff --git a/transformer_lens/utilities/hf_utils.py b/transformer_lens/utilities/hf_utils.py index be515cb72..4b405770d 100644 --- a/transformer_lens/utilities/hf_utils.py +++ b/transformer_lens/utilities/hf_utils.py @@ -81,7 +81,7 @@ def _wrapped(klass, *args: Any, _orig: Any = underlying, **kwargs: Any) -> Any: return call_hf_with_retry(_orig, klass, *args, **kwargs) setattr(_wrapped, _TL_RETRY_WRAPPED_ATTR, True) - cls.from_pretrained = classmethod(_wrapped) + cls.from_pretrained = classmethod(_wrapped) # type: ignore[method-assign,assignment] def call_hf_with_retry(