diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index aa56e78f3..555d84c5a 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable from dataclasses import asdict, dataclass -from typing import Final, Literal, ParamSpec, TypeVar +from typing import Any, Final, Literal, ParamSpec, TypeVar from pyrecest.backend_support._pytorch_allclose_device_contract import ( patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract, @@ -101,6 +101,8 @@ def decorator(obj: Callable[P, R]) -> Callable[P, R]: def get_public_api_status(name: str) -> PublicAPIStatus | None: """Return registered stability metadata for a public API name.""" + if not isinstance(name, str): + return None return _PUBLIC_API_STATUS.get(name) diff --git a/tests/test_stability.py b/tests/test_stability.py index 5aeb2bf81..b8963bf06 100644 --- a/tests/test_stability.py +++ b/tests/test_stability.py @@ -42,6 +42,10 @@ def example_function(): assert status.notes == "example" +def test_get_public_api_status_returns_none_for_non_string_names(): + assert get_public_api_status(["KalmanFilter"]) is None + + def test_registered_public_api_status_rows_have_valid_levels(): rows = list(iter_public_api_status())